// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use rand::Rng;
use s2n_tls::{
    config::Config,
    connection::{Connection, ModifiedBuilder},
    enums::{ClientAuthType, Mode, Version},
    error::{Error, ErrorType},
    pool::ConfigPoolBuilder,
    security::DEFAULT_TLS13,
};
use s2n_tls_tokio::{TlsAcceptor, TlsConnector};
use std::{collections::VecDeque, time::Duration};
use tokio::time;

pub mod common;

#[tokio::test]
async fn handshake_basic() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;

    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(common::server_config()?.build()?);

    let (client_result, server_result) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    for tls in [client_result, server_result] {
        // Security policy ensures TLS1.3.
        assert_eq!(tls.as_ref().actual_protocol_version()?, Version::TLS13);
        // Handshake types may change, but will at least be negotiated.
        assert!(tls.as_ref().handshake_type()?.contains("NEGOTIATED"));
        // Cipher suite may change, so just makes sure we can retrieve it.
        assert!(tls.as_ref().cipher_suite().is_ok());
        assert!(tls.as_ref().selected_curve().is_ok());
    }

    Ok(())
}

/// Runs many concurrent handshakes through shared connection pools on a
/// multi-threaded runtime, with the client pool's maximum size set below the
/// number of handshakes.
#[tokio::test(flavor = "multi_thread")]
async fn handshake_with_pool_multithread() -> Result<(), Box<dyn std::error::Error>> {
    // Total number of concurrent handshakes.
    const COUNT: usize = 20;
    // Maximum size of the client-side pool; deliberately smaller than COUNT.
    const CLIENT_LIMIT: usize = 3;

    let client_config = common::client_config()?.build()?;
    let server_config = common::server_config()?.build()?;

    let mut client_pool_builder = ConfigPoolBuilder::new(Mode::Client, client_config);
    client_pool_builder.set_max_pool_size(CLIENT_LIMIT);
    let client_pool = client_pool_builder.build();

    let server_pool = ConfigPoolBuilder::new(Mode::Server, server_config).build();

    let client = TlsConnector::new(client_pool.clone());
    let server = TlsAcceptor::new(server_pool.clone());

    // Spawn every handshake up front; each task owns its own connector/acceptor handle.
    let tasks: VecDeque<_> = (0..COUNT)
        .map(|_| {
            let client = client.clone();
            let server = server.clone();
            tokio::spawn(async move {
                // Stagger the start of each handshake by a random delay. The RNG is
                // used and dropped within this statement, before any await point.
                let delay_ms = rand::rng().random_range(0..50);
                time::sleep(Duration::from_millis(delay_ms)).await;

                let (server_stream, client_stream) = common::get_streams().await.unwrap();
                common::run_negotiate(&client, client_stream, &server, server_stream).await
            })
        })
        .collect();

    // The outer `?` surfaces task join failures; the inner one surfaces handshake errors.
    for handle in tasks {
        handle.await??;
    }
    Ok(())
}

#[tokio::test]
async fn handshake_with_connection_config() -> Result<(), Box<dyn std::error::Error>> {
    // Setup the client with a method
    fn with_client_auth(conn: &mut Connection) -> Result<&mut Connection, Error> {
        conn.set_client_auth_type(ClientAuthType::Optional)
    }
    let client_builder = ModifiedBuilder::new(common::client_config()?.build()?, with_client_auth);

    // Setup the server with a closure
    let server_builder = ModifiedBuilder::new(common::server_config()?.build()?, |conn| {
        conn.set_client_auth_type(ClientAuthType::Optional)
    });

    let client = TlsConnector::new(client_builder);
    let server = TlsAcceptor::new(server_builder);

    let (server_stream, client_stream) = common::get_streams().await?;
    let (client_result, server_result) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    for tls in [client_result, server_result] {
        assert!(tls.as_ref().handshake_type()?.contains("CLIENT_AUTH"));
    }

    Ok(())
}

#[tokio::test]
async fn handshake_with_connection_config_with_pool() -> Result<(), Box<dyn std::error::Error>> {
    fn with_client_auth(conn: &mut Connection) -> Result<&mut Connection, Error> {
        conn.set_client_auth_type(ClientAuthType::Optional)
    }
    let client_builder = ModifiedBuilder::new(common::client_config()?.build()?, with_client_auth);
    let server_pool =
        ConfigPoolBuilder::new(Mode::Server, common::server_config()?.build()?).build();
    let server_builder = ModifiedBuilder::new(server_pool, with_client_auth);

    let client = TlsConnector::new(client_builder);
    let server = TlsAcceptor::new(server_builder);

    for _ in 0..5 {
        let (server_stream, client_stream) = common::get_streams().await?;
        let (_, server_result) =
            common::run_negotiate(&client, client_stream, &server, server_stream).await?;
        assert!(server_result
            .as_ref()
            .handshake_type()?
            .contains("CLIENT_AUTH"));
    }

    Ok(())
}

/// Checks that an unrecoverable, unblinded handshake failure is reported as a
/// non-retryable error.
#[tokio::test]
async fn handshake_error() -> Result<(), Box<dyn std::error::Error>> {
    // A default Config carries no RSA certificate, yet it only offers TLS1.2
    // cipher suites that require RSA authentication, so the server cannot
    // choose a cipher suite. S2N_ERR_CIPHER_NOT_SUPPORTED is explicitly
    // excluded from blinding, so no blinding delay applies to this failure.
    let server = TlsAcceptor::new(Config::default());
    let client = TlsConnector::new(common::client_config()?.build()?);

    let (server_stream, client_stream) = common::get_streams().await?;
    let outcome = common::run_negotiate(&client, client_stream, &server, server_stream).await;

    // The handshake must fail, and the failure must not be retryable.
    assert!(matches!(outcome, Err(err) if !err.is_retryable()));

    Ok(())
}

/// Checks that a handshake failure subject to blinding is only reported after
/// at least the minimum blinding time, and that it then fails with a protocol error.
///
/// Runs with tokio's clock paused; the client config is given
/// `common::TokioTime` as its monotonic clock (presumably backed by tokio's
/// clock, so blinding sleeps advance the paused clock — confirm in `common`).
#[tokio::test(start_paused = true)]
async fn handshake_error_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
    let clock = common::TokioTime::default();

    // Config::builder() does not include a trust store.
    // The client will reject the server certificate as untrusted.
    let mut bad_config = Config::builder();
    bad_config.set_security_policy(&DEFAULT_TLS13)?;
    bad_config.set_monotonic_clock(clock)?;
    let client_config = bad_config.build()?;
    let server_config = common::server_config()?.build()?;

    // Each config is used exactly once, so move it instead of cloning.
    let client = TlsConnector::new(client_config);
    let server = TlsAcceptor::new(server_config);
    let (server_stream, client_stream) = common::get_streams().await?;

    let time_start = time::Instant::now();
    let result = common::run_negotiate(&client, client_stream, &server, server_stream).await;
    let time_elapsed = time_start.elapsed();

    // Handshake MUST NOT finish faster than minimal blinding time.
    assert!(
        time_elapsed > common::MIN_BLINDING_SECS,
        "handshake failed after {time_elapsed:?}, before the minimum blinding delay"
    );

    // Handshake MUST eventually gracefully fail after blinding
    let error = result.unwrap_err();
    assert_eq!(error.kind(), ErrorType::ProtocolError);

    Ok(())
}

/// Checks that after a handshake the underlying I/O stream is still reachable
/// through the TLS stream, via both shared (`get_ref`) and mutable (`get_mut`) access.
#[tokio::test]
async fn io_stream_access() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;

    // Record the client stream's local address before the stream is moved into
    // the handshake, so it can be compared with what the accessors return.
    // I/O errors propagate via `?` like every other fallible call in this test.
    let client_addr = client_stream.local_addr()?;
    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(common::server_config()?.build()?);

    let (mut client_result, _server_result) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    // Both accessors must expose the same underlying stream that was passed in.
    assert_eq!(client_result.get_ref().local_addr()?, client_addr);
    assert_eq!(client_result.get_mut().local_addr()?, client_addr);

    Ok(())
}

#[tokio::test]
async fn handshake_with_async_callback() -> Result<(), Box<dyn std::error::Error>> {
    use core::{future::Future, pin::Pin, task::Poll};
    use s2n_tls::{callbacks::*, connection, error};
    use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};

    let client_config = common::client_config()?.build()?;

    let mut server_config = common::server_config()?;
    server_config.set_client_hello_callback(DelayClientHelloHandler {
        amount: Duration::from_secs(1),
    })?;
    let server_config = server_config.build()?;

    let client = TlsConnector::new(client_config.clone());
    let server = TlsAcceptor::new(server_config.clone());
    let (server_stream, client_stream) = common::get_streams().await?;

    let mut tasks = tokio::task::JoinSet::new();

    tasks.spawn(async move {
        let mut stream = client.connect("localhost", client_stream).await.unwrap();
        stream.shutdown().await.unwrap();
        let len = stream.read(&mut [0]).await.unwrap();
        assert_eq!(len, 0);
    });

    tasks.spawn(async move {
        let mut stream = server.accept(server_stream).await.unwrap();
        stream.shutdown().await.unwrap();
        let len = stream.read(&mut [0]).await.unwrap();
        assert_eq!(len, 0);
    });

    // make sure the tasks completed
    while let Some(res) = tasks.join_next().await {
        res.unwrap();
    }

    /// Adds an artificial delay to a ClientHello callback
    #[derive(Clone)]
    pub struct DelayClientHelloHandler {
        amount: Duration,
    }

    impl ClientHelloCallback for DelayClientHelloHandler {
        fn on_client_hello(
            &self,
            _connection: &mut connection::Connection,
        ) -> Result<Option<Pin<Box<dyn ConnectionFuture>>>, error::Error> {
            Ok(Some(Box::pin(DelayClientHelloFuture {
                timer: Box::pin(tokio::time::sleep(self.amount)),
            })))
        }
    }

    pub struct DelayClientHelloFuture {
        timer: Pin<Box<tokio::time::Sleep>>,
    }

    impl ConnectionFuture for DelayClientHelloFuture {
        fn poll(
            mut self: Pin<&mut Self>,
            _connection: &mut connection::Connection,
            ctx: &mut core::task::Context,
        ) -> Poll<Result<(), error::Error>> {
            self.timer.as_mut().poll(ctx).map(Ok)
        }
    }

    Ok(())
}