use std::{
io,
thread,
str::FromStr,
net::SocketAddr,
sync::mpsc,
};
use futures::TryFutureExt;
use tokio::net::TcpListener;
use hyper::{
Body,
Method,
Request,
Response,
StatusCode,
body::HttpBody,
server::{
conn::Http,
accept::from_stream,
},
service::{make_service_fn, service_fn},
};
use ciph::salsa::{
Psk, Connector, Acceptor, Randomness, TcpListenAcceptor, HyperSalsaConnector
};
/// Pre-shared key text, embedded at compile time from `test.psk` (path resolved
/// relative to this source file). The test parses it with `Psk::from_str`.
/// NOTE(review): presumably base64 (per the `_B64` suffix) — confirm against `Psk`'s parser.
const PSK_B64: &str = include_str!("test.psk");
async fn echo(req: Request
) -> Result, hyper::Error> {
match (req.method(), req.uri().path()) {
(&Method::POST, "/") => Ok(Response::new(req.into_body())),
_ => {
let mut not_found = Response::default();
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)
},
}
}
fn spawn_localhost_server(psk: Psk, randomness: Randomness) -> SocketAddr {
let (addr_send, addr_recv) = mpsc::channel();
thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_io()
.build()
.unwrap();
let acceptor = Acceptor::new(psk, randomness);
let server_fut = async move {
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(&addr).await?;
let bound_addr = listener
.local_addr()
.unwrap();
addr_send
.send(bound_addr)
.unwrap();
let acceptor = TcpListenAcceptor::new(acceptor, listener);
let incoming = from_stream(acceptor);
let service = make_service_fn(|_| async {
Ok::<_, hyper::Error>(service_fn(echo))
});
let server = hyper::server::Builder::new(incoming, Http::new())
.serve(service);
server
.await
.map_err(|e| io::Error::new(io::ErrorKind::Other, Box::new(e)))?;
Ok(())
}
.unwrap_or_else(|err: io::Error| eprintln!("Server: {:?}", err));
rt.block_on(server_fut);
});
addr_recv
.recv()
.unwrap()
}
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn salsa_hyper_cipher_works() {
const FILE: &'static [u8] = include_bytes!("../README.md");
let psk = Psk::from_str(PSK_B64).unwrap();
let addr = spawn_localhost_server(psk.clone(), Randomness::Entropy);
let connector = Connector::new(psk, Randomness::Entropy);
let uri_str = format!("http://localhost:{}/", addr.port());
let uri: hyper::Uri = uri_str
.parse()
.unwrap();
let body = hyper::Body::from(FILE);
let request = hyper::Request::builder()
.uri(uri)
.method("POST")
.body(body)
.unwrap();
let connector = HyperSalsaConnector::new_http(connector);
let client = hyper::Client::builder()
.build::<_, hyper::Body>(connector);
let mut response = client
.request(request)
.await
.unwrap();
let mut buffer: Vec = Vec::new();
while let Some(chunk) = response.data().await {
buffer.extend(chunk.unwrap());
}
assert_eq!(buffer.as_slice(), FILE);
}