use futures::join; use lazy_static::lazy_static; use native_tls_gmssl::{Certificate, Identity}; use std::{fs, io::Error, path::PathBuf, process::Command}; use tokio::{ io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}, net::{TcpListener, TcpStream}, }; use tokio_tls_gmssl::{TlsAcceptor, TlsConnector}; lazy_static! { static ref CERT_DIR: PathBuf = { if cfg!(unix) { let dir = tempfile::TempDir::new().unwrap(); let path = dir.path().to_str().unwrap(); Command::new("sh") .arg("-c") .arg(format!("./scripts/generate-certificate.sh {}", path)) .output() .expect("failed to execute process"); dir.into_path() } else { PathBuf::from("tests") } }; } #[tokio::test] async fn client_to_server() { let srv = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = srv.local_addr().unwrap(); let (server_tls, client_tls) = context(); // Create a future to accept one socket, connect the ssl stream, and then // read all the data from it. let server = async move { let (socket, _) = srv.accept().await.unwrap(); let mut socket = server_tls.accept(socket).await.unwrap(); // Verify access to all of the nested inner streams (e.g. so that peer // certificates can be accessed). This is just a compile check. let native_tls_stream: &native_tls_gmssl::TlsStream<_> = socket.get_ref(); let _peer_cert = native_tls_stream.peer_certificate().unwrap(); let allow_std_stream: &tokio_tls_gmssl::AllowStd<_> = native_tls_stream.get_ref(); let _tokio_tcp_stream: &tokio::net::TcpStream = allow_std_stream.get_ref(); let mut data = Vec::new(); socket.read_to_end(&mut data).await.unwrap(); data }; // Create a future to connect to our server, connect the ssl stream, and // then write a bunch of data to it. let client = async move { let socket = TcpStream::connect(&addr).await.unwrap(); let socket = client_tls.connect("foobar.com", socket).await.unwrap(); copy_data(socket).await }; // Finally, run everything! let (data, _) = join!(server, client); // assert_eq!(amt, AMT); assert!(data == vec![9; AMT]); } #[tokio::test] async fn server_to_client() { // Create a server listening on a port, then figure out what that port is let srv = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = srv.local_addr().unwrap(); let (server_tls, client_tls) = context(); let server = async move { let (socket, _) = srv.accept().await.unwrap(); let socket = server_tls.accept(socket).await.unwrap(); copy_data(socket).await }; let client = async move { let socket = TcpStream::connect(&addr).await.unwrap(); let mut socket = client_tls.connect("foobar.com", socket).await.unwrap(); let mut data = Vec::new(); socket.read_to_end(&mut data).await.unwrap(); data }; // Finally, run everything! let (_, data) = join!(server, client); assert!(data == vec![9; AMT]); } #[tokio::test] async fn one_byte_at_a_time() { const AMT: usize = 1024; let srv = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = srv.local_addr().unwrap(); let (server_tls, client_tls) = context(); let server = async move { let (socket, _) = srv.accept().await.unwrap(); let mut socket = server_tls.accept(socket).await.unwrap(); let mut amt = 0; for b in std::iter::repeat(9).take(AMT) { let data = [b as u8]; socket.write_all(&data).await.unwrap(); amt += 1; } amt }; let client = async move { let socket = TcpStream::connect(&addr).await.unwrap(); let mut socket = client_tls.connect("foobar.com", socket).await.unwrap(); let mut data = Vec::new(); loop { let mut buf = [0; 1]; match socket.read_exact(&mut buf).await { Ok(_) => data.extend_from_slice(&buf), Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => break, Err(err) => panic!("{}", err), } } data }; let (amt, data) = join!(server, client); assert_eq!(amt, AMT); assert!(data == vec![9; AMT as usize]); } fn context() -> (TlsAcceptor, TlsConnector) { let pkcs12 = fs::read(CERT_DIR.join("identity.p12")).unwrap(); let der = fs::read(CERT_DIR.join("root-ca.der")).unwrap(); let identity = Identity::from_pkcs12(&pkcs12, "mypass").unwrap(); let acceptor = native_tls_gmssl::TlsAcceptor::builder(identity).build().unwrap(); let cert = Certificate::from_der(&der).unwrap(); let connector = native_tls_gmssl::TlsConnector::builder() .add_root_certificate(cert) .build() .unwrap(); (acceptor.into(), connector.into()) } const AMT: usize = 128 * 1024; async fn copy_data(mut w: W) -> Result { let mut data = vec![9; AMT as usize]; let mut amt = 0; while !data.is_empty() { let written = w.write(&data).await?; if written <= data.len() { amt += written; data.resize(data.len() - written, 0); } else { w.write_all(&data).await?; amt += data.len(); break; } println!("remaining: {}", data.len()); } Ok(amt) }