use anyhow::{Context, Result}; use bytes::{BufMut, BytesMut}; use futures::{SinkExt, StreamExt}; use rand::Rng; use snowflakes::{SnowFlakes, WinterFramed}; use std::net::Ipv4Addr as addr; use tokio::{ net::{TcpListener, TcpStream}, sync::oneshot, time::Instant, }; const PATTERN: &str = "Noise_NK_25519_AESGCM_SHA256"; const FLAKE_SIZE: usize = 1 << 24; async fn get_ready_server( port: u16, done: oneshot::Sender<()>, private_key: Vec, ) -> Result> { let listener = TcpListener::bind((addr::LOCALHOST, port)).await?; done.send(()).unwrap(); let (stream, _addr) = listener.accept().await?; let handshake = snow::Builder::new(PATTERN.parse()?) .local_private_key(&private_key) .build_responder()?; let mut winter_framed = WinterFramed::new(stream, handshake); let msg = winter_framed .next() .await .context("Read initial message failed.")? .context("Parse initial message failed")?; assert!(msg.is_empty()); winter_framed .send("".into()) .await .context("Send initial message failed.")?; Ok(winter_framed .into_snow_framed()? .into_snow_flakes(FLAKE_SIZE)) } async fn echo_server(port: u16, done: oneshot::Sender<()>, private_key: Vec) -> Result<()> { let mut snowflakes = get_ready_server(port, done, private_key).await?; loop { let msg = match snowflakes.next().await { Some(x) => x?, None => break, }; snowflakes.send(msg.freeze()).await?; } Ok(()) } async fn bench_server(port: u16, done: oneshot::Sender<()>, private_key: Vec) -> Result<()> { let mut snowflakes = get_ready_server(port, done, private_key).await?; loop { let _msg = match snowflakes.next().await { Some(x) => x?, None => break, }; } Ok(()) } async fn get_ready_client(port: u16, public_key: Vec) -> Result> { let stream = TcpStream::connect((addr::LOCALHOST, port)).await?; let handshake = snow::Builder::new(PATTERN.parse()?) .remote_public_key(&public_key) .build_initiator()?; let mut winter_framed = WinterFramed::new(stream, handshake); winter_framed .send("".into()) .await .context("Send initial message failed.")?; let msg = winter_framed .next() .await .context("Remote shutdown unexpectedly.")? .context("Get respond message failed.")?; assert_eq!(msg, b""); Ok(winter_framed .into_snow_framed()? .into_snow_flakes(FLAKE_SIZE)) } async fn client(port: u16, public_key: Vec) -> Result<()> { let mut snow_flakes = get_ready_client(port, public_key).await?; for _ in 0..100 { let number: u64 = rand::random(); let mut bytes = BytesMut::new(); bytes.put_slice(number.to_string().as_bytes()); snow_flakes.send(bytes.freeze()).await.unwrap(); let packet = snow_flakes .next() .await .context("Unexpected server shutdown") .unwrap() .context("Echo msg error") .unwrap(); assert_eq!(packet, number.to_string().into_bytes()); } Ok(()) } async fn laggy_client(port: u16, public_key: Vec) -> Result<()> { let mut snow_flakes = get_ready_client(port, public_key).await?; let mut expected_nums = vec![]; for _ in 0..100 { let number: u64 = rand::random(); let mut bytes = BytesMut::new(); bytes.put_slice(number.to_string().as_bytes()); snow_flakes.send(bytes.freeze()).await?; expected_nums.push(number); if rand::random() { for number in expected_nums.drain(..) { let packet = snow_flakes .next() .await .context("Unexpected server shutdown")? .context("Echo msg error")?; assert_eq!(packet, number.to_string().into_bytes()); } } } for number in expected_nums.into_iter() { let packet = snow_flakes .next() .await .context("Unexpected server shutdown")? .context("Echo msg error")?; assert_eq!(packet, number.to_string().into_bytes()); } Ok(()) } async fn strange_client(port: u16, public_key: Vec) -> Result<()> { let mut snow_flakes = get_ready_client(port, public_key).await?; let mut rng = rand::thread_rng(); let random: Vec = (0..FLAKE_SIZE).map(|_| rng.gen::()).collect(); for _ in 0..3 { snow_flakes.send(random.clone().into()).await.unwrap(); assert_eq!(snow_flakes.next().await.unwrap().unwrap(), random); } let random: Vec = (0..65535).map(|_| rng.gen::()).collect(); for _ in 0..100 { snow_flakes.send(random.clone().into()).await.unwrap(); assert_eq!(snow_flakes.next().await.unwrap().unwrap(), random); } for i in 0..100 { let random: Vec = (0..(65535 - i)).map(|_| rng.gen::()).collect(); snow_flakes.send(random.clone().into()).await.unwrap(); assert_eq!(snow_flakes.next().await.unwrap().unwrap(), random); } for i in 0..100 { let random: Vec = (0..i).map(|_| rng.gen::()).collect(); snow_flakes.send(random.clone().into()).await.unwrap(); assert_eq!(snow_flakes.next().await.unwrap().unwrap(), random); } Ok(()) } async fn bench_client(port: u16, public_key: Vec) -> Result<()> { const ROUND: usize = 100; let mut snow_flakes = get_ready_client(port, public_key).await?; let mut rng = rand::thread_rng(); let random: Vec = (0..FLAKE_SIZE).map(|_| rng.gen::()).collect(); let random = { let mut x = BytesMut::with_capacity(FLAKE_SIZE); x.put_slice(&random); x }; let random = random.freeze(); println!("begins"); let time = Instant::now(); for _ in 0..ROUND { snow_flakes.send(random.clone()).await.unwrap(); } let elapsed = time.elapsed(); let time = elapsed.as_secs_f32(); let total = random.len() * ROUND; println!( "time: {}, size_bytes: {}, {}MB/s", time, total, total as f32 / time / 1024. / 1024. ); Ok(()) } fn generate() -> snow::Keypair { snow::Builder::new(PATTERN.parse().unwrap()) .generate_keypair() .unwrap() } #[tokio::test] async fn normal_echo() -> Result<()> { let port = rand::random(); let (sender, receiver) = oneshot::channel(); let snow::Keypair { public, private } = generate(); let server_handle = tokio::spawn(async move { echo_server(port, sender, private).await.unwrap() }); receiver.await?; client(port, public).await?; server_handle.await?; Ok(()) } #[tokio::test] async fn laggy_echo() -> Result<()> { let port = rand::random(); let (sender, receiver) = oneshot::channel(); let snow::Keypair { public, private } = generate(); let server_handle = tokio::spawn(async move { echo_server(port, sender, private).await.unwrap() }); receiver.await?; laggy_client(port, public).await?; server_handle.await?; Ok(()) } #[tokio::test] async fn strange_echo() -> Result<()> { let port = rand::random(); let (sender, receiver) = oneshot::channel(); let snow::Keypair { public, private } = generate(); let server_handle = tokio::spawn(async move { echo_server(port, sender, private).await.unwrap() }); receiver.await?; strange_client(port, public).await?; server_handle.await?; Ok(()) } #[ignore] #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn benchmark() -> Result<()> { let port = rand::random(); let (sender, receiver) = oneshot::channel(); let snow::Keypair { public, private } = generate(); let server_handle = tokio::spawn(async move { bench_server(port, sender, private).await.unwrap() }); receiver.await?; bench_client(port, public).await?; server_handle.await?; Ok(()) }