use std::{ io, thread, str::FromStr, net::SocketAddr, sync::mpsc, }; use futures::TryFutureExt; use tokio::net::TcpListener; use hyper::{ Body, Method, Request, Response, StatusCode, body::HttpBody, server::{ conn::Http, accept::from_stream, }, service::{make_service_fn, service_fn}, }; use ciph::salsa::{ Psk, Connector, Acceptor, Randomness, TcpListenAcceptor, HyperSalsaConnector }; const PSK_B64: &str = include_str!("test.psk"); async fn echo(req: Request) -> Result, hyper::Error> { match (req.method(), req.uri().path()) { (&Method::POST, "/") => Ok(Response::new(req.into_body())), _ => { let mut not_found = Response::default(); *not_found.status_mut() = StatusCode::NOT_FOUND; Ok(not_found) }, } } fn spawn_localhost_server(psk: Psk, randomness: Randomness) -> SocketAddr { let (addr_send, addr_recv) = mpsc::channel(); thread::spawn(move || { let rt = tokio::runtime::Builder::new_current_thread() .enable_io() .build() .unwrap(); let acceptor = Acceptor::new(psk, randomness); let server_fut = async move { let addr = SocketAddr::from(([127, 0, 0, 1], 0)); let listener = TcpListener::bind(&addr).await?; let bound_addr = listener .local_addr() .unwrap(); addr_send .send(bound_addr) .unwrap(); let acceptor = TcpListenAcceptor::new(acceptor, listener); let incoming = from_stream(acceptor); let service = make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(echo)) }); let server = hyper::server::Builder::new(incoming, Http::new()) .serve(service); server .await .map_err(|e| io::Error::new(io::ErrorKind::Other, Box::new(e)))?; Ok(()) } .unwrap_or_else(|err: io::Error| eprintln!("Server: {:?}", err)); rt.block_on(server_fut); }); addr_recv .recv() .unwrap() } #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn salsa_hyper_cipher_works() { const FILE: &'static [u8] = include_bytes!("../README.md"); let psk = Psk::from_str(PSK_B64).unwrap(); let addr = spawn_localhost_server(psk.clone(), Randomness::Entropy); let connector = Connector::new(psk, Randomness::Entropy); let uri_str = format!("http://localhost:{}/", addr.port()); let uri: hyper::Uri = uri_str .parse() .unwrap(); let body = hyper::Body::from(FILE); let request = hyper::Request::builder() .uri(uri) .method("POST") .body(body) .unwrap(); let connector = HyperSalsaConnector::new_http(connector); let client = hyper::Client::builder() .build::<_, hyper::Body>(connector); let mut response = client .request(request) .await .unwrap(); let mut buffer: Vec = Vec::new(); while let Some(chunk) = response.data().await { buffer.extend(chunk.unwrap()); } assert_eq!(buffer.as_slice(), FILE); }