// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 use s2n_tls_tokio::{TlsAcceptor, TlsConnector}; use std::{io, task::Poll::*}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; pub mod common; const TEST_DATA: &[u8] = "hello world".as_bytes(); // The maximum TLS record payload is 2^14 bytes. // Send more to ensure multiple records. const LARGE_TEST_DATA: &[u8] = &[5; (1 << 15)]; #[tokio::test] async fn send_and_recv_basic() -> Result<(), Box<dyn std::error::Error>> { let (server_stream, client_stream) = common::get_streams().await?; let connector = TlsConnector::new(common::client_config()?.build()?); let acceptor = TlsAcceptor::new(common::server_config()?.build()?); let (mut client, mut server) = common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?; client.write_all(TEST_DATA).await?; let mut received = [0; TEST_DATA.len()]; assert_eq!(server.read_exact(&mut received).await?, TEST_DATA.len()); assert_eq!(TEST_DATA, received); Ok(()) } #[tokio::test] async fn send_and_recv_into_vec() -> Result<(), Box<dyn std::error::Error>> { let (server_stream, client_stream) = common::get_streams().await?; let connector = TlsConnector::new(common::client_config()?.build()?); let acceptor = TlsAcceptor::new(common::server_config()?.build()?); let (mut client, mut server) = common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?; client.write_all(TEST_DATA).await?; let mut received = vec![]; while received.len() < TEST_DATA.len() { let bytes_read = server.read_buf(&mut received).await?; assert!(bytes_read > 0); } assert_eq!(TEST_DATA, received); Ok(()) } #[tokio::test] async fn send_and_recv_multiple_records() -> Result<(), Box<dyn std::error::Error>> { let (server_stream, client_stream) = common::get_streams().await?; let connector = TlsConnector::new(common::client_config()?.build()?); let acceptor = TlsAcceptor::new(common::server_config()?.build()?); let (mut client, mut server) = common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?; let mut received = [0; LARGE_TEST_DATA.len()]; let (_, read_size) = tokio::try_join!( client.write_all(LARGE_TEST_DATA), server.read_exact(&mut received) )?; assert_eq!(LARGE_TEST_DATA.len(), read_size); assert_eq!(LARGE_TEST_DATA, received); Ok(()) } #[tokio::test] async fn send_and_recv_split() -> Result<(), Box<dyn std::error::Error>> { let (server_stream, client_stream) = common::get_streams().await?; let connector = TlsConnector::new(common::client_config()?.build()?); let acceptor = TlsAcceptor::new(common::server_config()?.build()?); let (client, server) = common::run_negotiate(&connector, client_stream, &acceptor, server_stream).await?; let (mut client_read, mut client_write) = tokio::io::split(client); let (mut server_read, mut server_write) = tokio::io::split(server); let mut client_received = [0; LARGE_TEST_DATA.len()]; let mut server_received = [0; LARGE_TEST_DATA.len()]; let (_, _, client_bytes, server_bytes) = tokio::try_join!( client_write.write_all(LARGE_TEST_DATA), server_write.write_all(LARGE_TEST_DATA), client_read.read_exact(&mut client_received), server_read.read_exact(&mut server_received) )?; assert_eq!(client_bytes, LARGE_TEST_DATA.len()); assert_eq!(server_bytes, LARGE_TEST_DATA.len()); assert_eq!(LARGE_TEST_DATA, client_received); assert_eq!(LARGE_TEST_DATA, server_received); Ok(()) } #[tokio::test] async fn send_error() -> Result<(), Box<dyn std::error::Error>> { let client = TlsConnector::new(common::client_config()?.build()?); let server = TlsAcceptor::new(common::server_config()?.build()?); let (server_stream, client_stream) = common::get_streams().await?; let client_stream = common::TestStream::new(client_stream); let overrides = client_stream.overrides(); let (mut client, _) = common::run_negotiate(&client, client_stream, &server, server_stream).await?; // Setup write to fail overrides.next_write(Some(Box::new(|_, _, _| { Ready(Err(io::Error::from(io::ErrorKind::ConnectionReset))) }))); // Verify write fails let result = client.write_all(TEST_DATA).await; assert!(result.is_err()); Ok(()) } #[tokio::test] async fn recv_error() -> Result<(), Box<dyn std::error::Error>> { let client = TlsConnector::new(common::client_config()?.build()?); let server = TlsAcceptor::new(common::server_config()?.build()?); let (server_stream, client_stream) = common::get_streams().await?; let client_stream = common::TestStream::new(client_stream); let overrides = client_stream.overrides(); let (mut client, _) = common::run_negotiate(&client, client_stream, &server, server_stream).await?; // Setup read to fail overrides.next_read(Some(Box::new(|_, _, _| { Ready(Err(io::Error::from(io::ErrorKind::ConnectionReset))) }))); // Verify read fails let mut received = [0; 1]; let result = client.read_exact(&mut received).await; assert!(result.is_err()); Ok(()) }