use std::{path::PathBuf, sync::Arc}; use mio::net::{TcpListener, TcpStream}; use std::collections::HashMap; use std::fs; use std::io; use std::io::{BufReader, Read, Write}; use std::net; use rustls::{NoClientAuth, Session}; // Token for our listening socket. const LISTENER: mio::Token = mio::Token(0); // Which mode the server operates in. #[derive(Clone)] enum ServerMode { /// Write back received bytes Echo, } /// This binds together a TCP listening socket, some outstanding /// connections, and a TLS server configuration. struct TlsServer { server: TcpListener, connections: HashMap, next_id: usize, tls_config: Arc, mode: ServerMode, } impl TlsServer { fn new(server: TcpListener, mode: ServerMode, cfg: Arc) -> TlsServer { TlsServer { server, connections: HashMap::new(), next_id: 2, tls_config: cfg, mode, } } fn accept(&mut self, registry: &mio::Registry) -> Result<(), io::Error> { loop { match self.server.accept() { Ok((socket, addr)) => { log::debug!("Accepting new connection from {:?}", addr); let tls_session = rustls::ServerSession::new(&self.tls_config); let mode = self.mode.clone(); let token = mio::Token(self.next_id); self.next_id += 1; let mut connection = Connection::new(socket, token, mode, tls_session); connection.register(registry); self.connections.insert(token, connection); } Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(()), Err(err) => { println!( "encountered error while accepting connection; err={:?}", err ); return Err(err); } } } } fn conn_event(&mut self, registry: &mio::Registry, event: &mio::event::Event) { let token = event.token(); if self.connections.contains_key(&token) { self.connections .get_mut(&token) .unwrap() .ready(registry, event); if self.connections[&token].is_closed() { self.connections.remove(&token); } } } } /// This is a connection which has been accepted by the server, /// and is currently being served. /// /// It has a TCP-level stream, a TLS-level session, and some /// other state/metadata. struct Connection { socket: TcpStream, token: mio::Token, closing: bool, closed: bool, mode: ServerMode, tls_session: rustls::ServerSession, back: Option, } /// Open a plaintext TCP-level connection for forwarded connections. fn open_back(mode: &ServerMode) -> Option { match *mode { _ => None, } } /// This used to be conveniently exposed by mio: map EWOULDBLOCK /// errors to something less-errory. fn try_read(r: io::Result) -> io::Result> { match r { Ok(len) => Ok(Some(len)), Err(e) => { if e.kind() == io::ErrorKind::WouldBlock { Ok(None) } else { Err(e) } } } } impl Connection { fn new( socket: TcpStream, token: mio::Token, mode: ServerMode, tls_session: rustls::ServerSession, ) -> Connection { let back = open_back(&mode); Connection { socket, token, closing: false, closed: false, mode, tls_session, back, } } /// We're a connection, and we have something to do. fn ready(&mut self, registry: &mio::Registry, ev: &mio::event::Event) { // If we're readable: read some TLS. Then // see if that yielded new plaintext. Then // see if the backend is readable too. if ev.is_readable() { self.do_tls_read(); self.try_plain_read(); self.try_back_read(); } if ev.is_writable() { self.do_tls_write_and_handle_error(); } if self.closing { let _ = self.socket.shutdown(net::Shutdown::Both); self.close_back(); self.closed = true; self.deregister(registry); } else { self.reregister(registry); } } /// Close the backend connection for forwarded sessions. fn close_back(&mut self) { if self.back.is_some() { let back = self.back.as_mut().unwrap(); back.shutdown(net::Shutdown::Both).unwrap(); } self.back = None; } fn do_tls_read(&mut self) { // Read some TLS data. let rc = self.tls_session.read_tls(&mut self.socket); if rc.is_err() { let err = rc.unwrap_err(); if let io::ErrorKind::WouldBlock = err.kind() { return; } log::warn!("read error {:?}", err); self.closing = true; return; } if rc.unwrap() == 0 { log::debug!("eof"); self.closing = true; return; } // Process newly-received TLS messages. let processed = self.tls_session.process_new_packets(); if processed.is_err() { log::warn!("cannot process packet: {:?}", processed); // last gasp write to send any alerts self.do_tls_write_and_handle_error(); self.closing = true; return; } } fn try_plain_read(&mut self) { // Read and process all available plaintext. let mut buf = Vec::new(); let rc = self.tls_session.read_to_end(&mut buf); if rc.is_err() { log::warn!("plaintext read failed: {:?}", rc); self.closing = true; return; } if !buf.is_empty() { log::debug!("plaintext read {:?}", buf.len()); self.incoming_plaintext(&buf); } } fn try_back_read(&mut self) { if self.back.is_none() { return; } // Try a non-blocking read. let mut buf = [0u8; 1024]; let back = self.back.as_mut().unwrap(); let rc = try_read(back.read(&mut buf)); if rc.is_err() { log::warn!("backend read failed: {:?}", rc); self.closing = true; return; } let maybe_len = rc.unwrap(); // If we have a successful but empty read, that's an EOF. // Otherwise, we shove the data into the TLS session. match maybe_len { Some(len) if len == 0 => { log::debug!("back eof"); self.closing = true; } Some(len) => { self.tls_session.write_all(&buf[..len]).unwrap(); } None => {} }; } /// Process some amount of received plaintext. fn incoming_plaintext(&mut self, buf: &[u8]) { match self.mode { ServerMode::Echo => { self.tls_session.write_all(buf).unwrap(); } } } fn tls_write(&mut self) -> io::Result { self.tls_session.write_tls(&mut self.socket) } fn do_tls_write_and_handle_error(&mut self) { let rc = self.tls_write(); if rc.is_err() { log::warn!("write failed {:?}", rc); self.closing = true; return; } } fn register(&mut self, registry: &mio::Registry) { let event_set = self.event_set(); registry .register(&mut self.socket, self.token, event_set) .unwrap(); if self.back.is_some() { registry .register( self.back.as_mut().unwrap(), self.token, mio::Interest::READABLE, ) .unwrap(); } } fn reregister(&mut self, registry: &mio::Registry) { let event_set = self.event_set(); registry .reregister(&mut self.socket, self.token, event_set) .unwrap(); } fn deregister(&mut self, registry: &mio::Registry) { registry.deregister(&mut self.socket).unwrap(); if self.back.is_some() { registry.deregister(self.back.as_mut().unwrap()).unwrap(); } } /// What IO events we're currently waiting for, /// based on wants_read/wants_write. fn event_set(&self) -> mio::Interest { let rd = self.tls_session.wants_read(); let wr = self.tls_session.wants_write(); if rd && wr { mio::Interest::READABLE | mio::Interest::WRITABLE } else if wr { mio::Interest::WRITABLE } else { mio::Interest::READABLE } } fn is_closed(&self) -> bool { self.closed } } fn load_certs(filename: &PathBuf) -> Vec { let certfile = fs::File::open(filename).expect("cannot open certificate file"); let mut reader = BufReader::new(certfile); rustls_pemfile::certs(&mut reader) .unwrap() .iter() .map(|v| rustls::Certificate(v.clone())) .collect() } fn load_private_key(filename: &PathBuf) -> rustls::PrivateKey { let keyfile = fs::File::open(filename).expect("cannot open private key file"); let mut reader = BufReader::new(keyfile); loop { match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") { Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key), Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key), None => break, _ => {} } } panic!( "no keys found in {:?} (encrypted keys not supported)", filename ); } pub fn run(mut listener: TcpListener) { let client_auth = NoClientAuth::new(); let suites = rustls::ALL_CIPHERSUITES.to_vec(); let versions = vec![rustls::ProtocolVersion::TLSv1_3]; let test_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests"); let certs = load_certs(&test_dir.join("data").join("server-cert.pem")); let privkey = load_private_key(&test_dir.join("data").join("server-key.pem")); let mut config = rustls::ServerConfig::new(client_auth); config.ciphersuites = suites; config.versions = versions; config.set_single_cert(certs, privkey).unwrap(); let mut poll = mio::Poll::new().unwrap(); poll.registry() .register(&mut listener, LISTENER, mio::Interest::READABLE) .unwrap(); let mut tlsserv = TlsServer::new(listener, ServerMode::Echo, Arc::new(config)); let mut events = mio::Events::with_capacity(256); loop { poll.poll(&mut events, None).unwrap(); for event in events.iter() { match event.token() { LISTENER => { tlsserv .accept(poll.registry()) .expect("error accepting socket"); } _ => tlsserv.conn_event(poll.registry(), &event), } } } }