use std::marker::PhantomData; use std::task::{Context, Poll}; use scrappy_codec::{AsyncRead, AsyncWrite}; use scrappy_service::{Service, ServiceFactory}; use scrappy_utils::counter::Counter; use futures::future::{self, FutureExt, LocalBoxFuture, TryFutureExt}; pub use native_tls::Error; pub use tokio_tls::{TlsAcceptor, TlsStream}; use crate::MAX_CONN_COUNTER; /// Support `SSL` connections via native-tls package /// /// `tls` feature enables `NativeTlsAcceptor` type pub struct NativeTlsAcceptor { acceptor: TlsAcceptor, io: PhantomData, } impl NativeTlsAcceptor where T: AsyncRead + AsyncWrite + Unpin, { /// Create `NativeTlsAcceptor` instance #[inline] pub fn new(acceptor: TlsAcceptor) -> Self { NativeTlsAcceptor { acceptor, io: PhantomData, } } } impl Clone for NativeTlsAcceptor { #[inline] fn clone(&self) -> Self { Self { acceptor: self.acceptor.clone(), io: PhantomData, } } } impl ServiceFactory for NativeTlsAcceptor where T: AsyncRead + AsyncWrite + Unpin + 'static, { type Request = T; type Response = TlsStream; type Error = Error; type Service = NativeTlsAcceptorService; type Config = (); type InitError = (); type Future = future::Ready>; fn new_service(&self, _: ()) -> Self::Future { MAX_CONN_COUNTER.with(|conns| { future::ok(NativeTlsAcceptorService { acceptor: self.acceptor.clone(), conns: conns.clone(), io: PhantomData, }) }) } } pub struct NativeTlsAcceptorService { acceptor: TlsAcceptor, io: PhantomData, conns: Counter, } impl Clone for NativeTlsAcceptorService { fn clone(&self) -> Self { Self { acceptor: self.acceptor.clone(), io: PhantomData, conns: self.conns.clone(), } } } impl Service for NativeTlsAcceptorService where T: AsyncRead + AsyncWrite + Unpin + 'static, { type Request = T; type Response = TlsStream; type Error = Error; type Future = LocalBoxFuture<'static, Result, Error>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { if self.conns.available(cx) { Poll::Ready(Ok(())) } else { Poll::Pending } } fn call(&mut self, req: Self::Request) -> Self::Future { let guard = self.conns.get(); let this = self.clone(); async move { this.acceptor.accept(req).await } .map_ok(move |io| { // Required to preserve `CounterGuard` until `Self::Future` // is completely resolved. let _ = guard; io }) .boxed_local() } }