// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 use crate::common::{echo::echo, InsecureAcceptAllCertificatesHandler}; use bytes::Bytes; use common::echo::serve_echo; use http::{Method, Request, Uri, Version}; use http_body_util::{BodyExt, Empty, Full}; use hyper::service::service_fn; use hyper_util::{ client::legacy::Client, rt::{TokioExecutor, TokioIo}, }; use s2n_tls::{ callbacks::{ClientHelloCallback, ConnectionFuture}, config, connection::Connection, security::DEFAULT_TLS13, }; use s2n_tls_hyper::{connector::HttpsConnector, error}; use std::{error::Error, pin::Pin, str::FromStr}; use tokio::{ net::TcpListener, task::{JoinHandle, JoinSet}, }; pub mod common; const TEST_DATA: &[u8] = "hello world".as_bytes(); // The maximum TLS record payload is 2^14 bytes. // Send more to ensure multiple records. const LARGE_TEST_DATA: &[u8] = &[5; (1 << 15)]; #[tokio::test] async fn test_get_request() -> Result<(), Box<dyn Error + Send + Sync>> { let config = common::config()?.build()?; common::echo::make_echo_request(config.clone(), |port| async move { let connector = HttpsConnector::new(config.clone()); let client: Client<_, Empty<Bytes>> = Client::builder(TokioExecutor::new()).build(connector); let uri = Uri::from_str(format!("https://localhost:{}", port).as_str())?; let response = client.get(uri).await?; assert_eq!(response.status(), 200); Ok(()) }) .await?; Ok(()) } #[tokio::test] async fn test_http_methods() -> Result<(), Box<dyn Error + Send + Sync>> { let methods = [Method::GET, Method::POST, Method::PUT, Method::DELETE]; for method in methods { let config = common::config()?.build()?; common::echo::make_echo_request(config.clone(), |port| async move { let connector = HttpsConnector::new(config.clone()); let client: Client<_, Full<Bytes>> = Client::builder(TokioExecutor::new()).build(connector); let request: Request<Full<Bytes>> = Request::builder() .method(method) .uri(Uri::from_str( format!("https://localhost:{}", port).as_str(), )?) .body(Full::from(TEST_DATA))?; let response = client.request(request).await?; assert_eq!(response.status(), 200); let body = response.into_body().collect().await?.to_bytes(); assert_eq!(body.to_vec().as_slice(), TEST_DATA); Ok(()) }) .await?; } Ok(()) } #[tokio::test] async fn test_large_request() -> Result<(), Box<dyn Error + Send + Sync>> { let config = common::config()?.build()?; common::echo::make_echo_request(config.clone(), |port| async move { let connector = HttpsConnector::new(config.clone()); let client: Client<_, Full<Bytes>> = Client::builder(TokioExecutor::new()).build(connector); let request: Request<Full<Bytes>> = Request::builder() .method(Method::POST) .uri(Uri::from_str( format!("https://localhost:{}", port).as_str(), )?) .body(Full::from(LARGE_TEST_DATA))?; let response = client.request(request).await?; assert_eq!(response.status(), 200); let body = response.into_body().collect().await?.to_bytes(); assert_eq!(body.to_vec().as_slice(), LARGE_TEST_DATA); Ok(()) }) .await?; Ok(()) } #[tokio::test] async fn test_sni() -> Result<(), Box<dyn Error + Send + Sync>> { struct TestClientHelloHandler { expected_server_name: &'static str, } impl ClientHelloCallback for TestClientHelloHandler { fn on_client_hello( &self, connection: &mut Connection, ) -> Result<Option<Pin<Box<dyn ConnectionFuture>>>, s2n_tls::error::Error> { let server_name = connection.server_name().unwrap(); assert_eq!(server_name, self.expected_server_name); Ok(None) } } for hostname in ["localhost", "127.0.0.1"] { let mut config = common::config()?; config.set_client_hello_callback(TestClientHelloHandler { // Ensure that the HttpsConnector correctly sets the SNI according to the hostname in // the URI. expected_server_name: hostname, })?; config.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})?; let config = config.build()?; common::echo::make_echo_request(config.clone(), |port| async move { let connector = HttpsConnector::new(config.clone()); let client: Client<_, Empty<Bytes>> = Client::builder(TokioExecutor::new()).build(connector); let uri = Uri::from_str(format!("https://{}:{}", hostname, port).as_str())?; let response = client.get(uri).await?; assert_eq!(response.status(), 200); Ok(()) }) .await?; } Ok(()) } /// This test covers the general customer TLS Error experience. We want to /// confirm that s2n-tls errors are correctly bubbled up and that details can be /// extracted/matched on. #[tokio::test] async fn error_matching() -> Result<(), Box<dyn Error + Send + Sync>> { let (server_task, addr) = { let listener = TcpListener::bind("127.0.0.1:0").await?; let addr = listener.local_addr()?; let server_task = tokio::spawn(serve_echo(listener, common::config()?.build()?)); (server_task, addr) }; let client_task: JoinHandle<Result<(), Box<dyn Error + Send + Sync>>> = tokio::spawn(async move { // the client config won't trust the self-signed cert that the server // uses. let client_config = { let mut builder = config::Config::builder(); builder.set_security_policy(&DEFAULT_TLS13)?; builder.set_max_blinding_delay(0)?; builder.build()? }; let connector = HttpsConnector::new(client_config); let client: Client<_, Empty<Bytes>> = Client::builder(TokioExecutor::new()).build(connector); let uri = Uri::from_str(format!("https://localhost:{}", addr.port()).as_str())?; client.get(uri).await?; panic!("the client request should fail"); }); // expected error: // hyper_util::client::legacy::Error( // Connect, // TlsError( // Error { // code: 335544366, // name: "S2N_ERR_CERT_UNTRUSTED", // message: "Certificate is untrusted", // kind: ProtocolError, // source: Library, // debug: "Error encountered in lib/tls/s2n_x509_validator.c:721", // errno: "No such file or directory", // }, // ), // ) let client_response = client_task.await?; let client_error = client_response.unwrap_err(); let hyper_error: &hyper_util::client::legacy::Error = client_error.downcast_ref().unwrap(); // the error happened when attempting to connect to the endpoint. assert!(hyper_error.is_connect()); let error_source = hyper_error.source().unwrap(); let s2n_tls_hyper_error: &s2n_tls_hyper::error::Error = error_source.downcast_ref().unwrap(); let s2n_tls_error = match s2n_tls_hyper_error { s2n_tls_hyper::error::Error::TlsError(s2n_tls_error) => s2n_tls_error, _ => panic!("unexpected error type"), }; assert_eq!( s2n_tls_error.kind(), s2n_tls::error::ErrorType::ProtocolError ); assert_eq!(s2n_tls_error.name(), "S2N_ERR_CERT_UNTRUSTED"); server_task.abort(); Ok(()) } #[tokio::test] async fn ipv6() -> Result<(), Box<dyn Error + Send + Sync>> { let config = { // The localhost IPv6 certificate contains ::1 in the SAN extension. s2n-tls will not // successfully validate the certificate unless the sever name is properly formatted, and // matches this identity. let localhost_ipv6_cert: &[u8] = include_bytes!(concat!( env!("CARGO_MANIFEST_DIR"), "/../certs/cert_localhost_ipv6.pem" )); let localhost_ipv6_key: &[u8] = include_bytes!(concat!( env!("CARGO_MANIFEST_DIR"), "/../certs/key_localhost_ipv6.pem" )); let mut builder = config::Config::builder(); builder.load_pem(localhost_ipv6_cert, localhost_ipv6_key)?; builder.trust_pem(localhost_ipv6_cert)?; builder.build()? }; // Listen for IPv6 connections. let listener = TcpListener::bind("[::1]:0").await?; let addr = listener.local_addr()?; let mut tasks = tokio::task::JoinSet::new(); tasks.spawn(serve_echo(listener, config.clone())); tasks.spawn(async move { let connector = HttpsConnector::new(config); let client: Client<_, Empty<Bytes>> = Client::builder(TokioExecutor::new()).build(connector); // Connect to the localhost IPv6 address. s2n-tls hostname verification should ensure that // the certificate contains the `::1` identity (without square brackets). let uri = Uri::from_str(format!("https://[::1]:{}", addr.port()).as_str())?; let response = client.get(uri).await?; assert_eq!(response.status(), 200); Ok(()) }); while let Some(res) = tasks.join_next().await { res.unwrap()?; } Ok(()) } #[tokio::test] async fn http2() -> Result<(), Box<dyn Error + Send + Sync>> { for expected_http_version in [Version::HTTP_11, Version::HTTP_2] { let server_config = { let mut builder = common::config()?; if expected_http_version == Version::HTTP_2 { builder.set_application_protocol_preference(["h2"])?; } builder.build()? }; common::echo::make_echo_request(server_config.clone(), |port| async move { let connector = HttpsConnector::new(common::config()?.build()?); let client: Client<_, Empty<Bytes>> = Client::builder(TokioExecutor::new()).build(connector); let uri = Uri::from_str(format!("https://localhost:{}", port).as_str())?; let response = client.get(uri).await?; assert_eq!(response.status(), 200); // Ensure that HTTP/2 is negotiated when supported by the server. assert_eq!(response.version(), expected_http_version); Ok(()) }) .await?; } Ok(()) } /// Ensure that HTTP/2 is negotiated, regardless of any pre-configured ALPN values. #[tokio::test] async fn config_alpn_ignored() -> Result<(), Box<dyn Error + Send + Sync>> { let server_config = { let mut builder = common::config()?; builder.set_application_protocol_preference(["h2"])?; builder.build()? }; common::echo::make_echo_request(server_config, |port| async move { let client_config = { let mut builder = common::config()?; // Set an arbitrary non-HTTP/2 ALPN value. builder.set_application_protocol_preference([b"http/1.1"])?; builder.build()? }; let connector = HttpsConnector::new(client_config); let client: Client<_, Empty<Bytes>> = Client::builder(TokioExecutor::new()).build(connector); let uri = Uri::from_str(format!("https://localhost:{}", port).as_str())?; let response = client.get(uri).await?; assert_eq!(response.status(), 200); // Ensure that HTTP/2 was negotiated. assert_eq!(response.version(), Version::HTTP_2); Ok(()) }) .await?; Ok(()) } #[tokio::test] async fn plaintext_http() -> Result<(), Box<dyn Error + Send + Sync>> { let listener = TcpListener::bind("127.0.0.1:0").await?; let addr = listener.local_addr()?; let mut tasks: JoinSet<Result<(), Box<dyn Error + Send + Sync>>> = JoinSet::new(); tasks.spawn(async move { // Listen for HTTP requests on a plain TCP stream. let (tcp_stream, _) = listener.accept().await.unwrap(); let server = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); server .serve_connection(TokioIo::new(tcp_stream), service_fn(echo)) .await?; Ok(()) }); tasks.spawn(async move { for enable_plaintext_http in [false, true] { let connector = { let config = common::config()?.build()?; let mut builder = HttpsConnector::builder(config); builder.with_plaintext_http(enable_plaintext_http); builder.build() }; let client: Client<_, Empty<Bytes>> = Client::builder(TokioExecutor::new()).build(connector); let uri = Uri::from_str(format!("http://127.0.0.1:{}", addr.port()).as_str())?; let response = client.get(uri).await; if enable_plaintext_http { // If plaintext HTTP is enabled, the request should succeed. let response = response.unwrap(); assert_eq!(response.status(), 200); } else { // If plaintext HTTP is disabled, the request should error. let error = response.unwrap_err(); // Ensure an InvalidScheme error is produced. let error = error .source() .unwrap() .downcast_ref::<error::Error>() .unwrap(); assert!(matches!(error, error::Error::InvalidScheme)); assert!(!error.to_string().is_empty()); } } Ok(()) }); while let Some(res) = tasks.join_next().await { res.unwrap()?; } Ok(()) }