use std::future::Future; use std::sync::RwLock; use std::time::Duration; use tonic::transport::{Channel, ClientTlsConfig, Uri}; use tonic::{Code, Status}; pub struct ChannelPool { channel: RwLock>, uri: Uri, grpc_timeout: Duration, connection_timeout: Duration, keep_alive_while_idle: bool, } impl ChannelPool { pub fn new( uri: Uri, grpc_timeout: Duration, connection_timeout: Duration, keep_alive_while_idle: bool, ) -> Self { Self { channel: RwLock::new(None), uri, grpc_timeout, connection_timeout, keep_alive_while_idle, } } async fn make_channel(&self) -> Result { let tls = match self.uri.scheme_str() { None => false, Some(schema) => match schema { "http" => false, "https" => true, _ => { return Err(Status::invalid_argument(format!( "Unsupported schema: {}", schema ))) } }, }; let endpoint = Channel::builder(self.uri.clone()) .timeout(self.grpc_timeout) .connect_timeout(self.connection_timeout) .keep_alive_while_idle(self.keep_alive_while_idle); let endpoint = if tls { let tls_config = ClientTlsConfig::new().with_native_roots(); endpoint .tls_config(tls_config) .map_err(|e| Status::internal(format!("Failed to create TLS config: {}", e)))? } else { endpoint }; let channel = endpoint .connect() .await .map_err(|e| Status::internal(format!("Failed to connect to {}: {:?}", self.uri, e)))?; let mut self_channel = self.channel.write().unwrap(); *self_channel = Some(channel.clone()); Ok(channel) } async fn get_channel(&self) -> Result { if let Some(channel) = &*self.channel.read().unwrap() { return Ok(channel.clone()); } let channel = self.make_channel().await?; Ok(channel) } pub async fn drop_channel(&self) { let mut channel = self.channel.write().unwrap(); *channel = None; } // Allow to retry request if channel is broken pub async fn with_channel>>( &self, f: impl Fn(Channel) -> O, allow_retry: bool, ) -> Result { let channel = self.get_channel().await?; let result: Result = f(channel).await; // Reconnect on failure to handle the case with domain name change. match result { Ok(res) => Ok(res), Err(err) => match err.code() { Code::Internal | Code::Unavailable | Code::Cancelled | Code::Unknown => { self.drop_channel().await; if allow_retry { let channel = self.get_channel().await?; Ok(f(channel).await?) } else { Err(err) } } _ => Err(err)?, }, } } } // The future returned by get_channel needs to be Send so that the client can be // used by external async functions. #[test] fn require_get_channel_fn_to_be_send() { fn require_send(_t: T) {} require_send(async { ChannelPool::new( Uri::from_static(""), Duration::from_millis(0), Duration::from_millis(0), false, ) .get_channel() .await .expect("get channel should not error"); }); }