// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use s2n_tls::error;
use s2n_tls_tokio::{TlsAcceptor, TlsConnector, TlsStream};
use std::{
    convert::TryFrom,
    io,
    sync::Arc,
    task::Poll::{Pending, Ready},
};
use tokio::{
    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
    time,
};

pub mod common;

async fn read_until_shutdown<S: AsyncRead + AsyncWrite + Unpin>(
    stream: &mut TlsStream<S>,
) -> Result<(), std::io::Error> {
    let mut received = [0; 1];
    // Zero bytes read indicates EOF
    while stream.read(&mut received).await? != 0 {}
    stream.shutdown().await
}

async fn write_until_shutdown<S: AsyncWrite + Unpin>(stream: &mut S) -> Result<(), std::io::Error> {
    let sent = [0; 1];
    loop {
        if let Err(err) = stream.write(&sent).await {
            let tls_err = error::Error::try_from(err).unwrap();
            assert_eq!(tls_err.kind(), error::ErrorType::ConnectionClosed);
            break;
        }
    }
    stream.shutdown().await
}

#[tokio::test]
async fn client_initiated_shutdown() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;

    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(common::server_config()?.build()?);

    let (mut client, mut server) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    tokio::try_join!(read_until_shutdown(&mut server), client.shutdown())?;

    Ok(())
}

#[tokio::test]
async fn server_initiated_shutdown() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;

    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(common::server_config()?.build()?);

    let (mut client, mut server) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    tokio::try_join!(read_until_shutdown(&mut client), server.shutdown())?;

    Ok(())
}

/// Reading and writing handles should both respond to a peer's "close notify"
/// appropriately. The read handle should immediately exit and writing should
/// fail with a "connection closed" error.
#[tokio::test]
async fn shutdown_after_split() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;

    let client = TlsConnector::new(common::client_config_tls12()?.build()?);
    let server = TlsAcceptor::new(common::server_config_tls12()?.build()?);

    let (client, mut server) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    let (mut client_reader, mut client_writer) = tokio::io::split(client);

    let mut received = [0; 1];

    // All tasks must cleanly exit. try_join will return as soon as an error
    // occurs, so if the result is any error then the test has failed.
    tokio::try_join!(
        server.shutdown(),
        client_reader.read(&mut received),
        write_until_shutdown(&mut client_writer),
    )?;
    Ok(())
}

/// Reading and writing handles should both respond to a peers "close notify"
/// appropriately. TLS1.3 connections have "half close behavior". The read
/// handle should immediately exit, but the write handle can continue to write
/// until explicitly shutdown. After both client handles have shutdown, the
/// server should cleanly exit.
#[tokio::test]
async fn shutdown_after_halfclose_split() -> Result<(), Box<dyn std::error::Error>> {
    let (server_stream, client_stream) = common::get_streams().await?;

    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(common::server_config()?.build()?);

    let (client, mut server) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    let (mut client_reader, mut client_writer) = tokio::io::split(client);

    let close_notify_recvd = Arc::new(tokio::sync::Notify::new());
    let close_notify_recvd_clone = close_notify_recvd.clone();

    let mut received = [0; 1];

    // all tasks must complete, and must complete successfully
    // the client tasks will panic if an error is encountered, so those don't
    // need to be checked.
    let (server, _, _) = tokio::join!(
        server.shutdown(),
        async {
            let bytes_read = client_reader.read(&mut received).await.unwrap();
            // 0 bytes read indicate that we returned because of close notify
            assert_eq!(bytes_read, 0);
            // signal the writer task that close notify received
            close_notify_recvd.notify_one();
        },
        async {
            // wait for the connection to receive "close notify" from peer
            close_notify_recvd_clone.notified().await;
            // confirm that we can write even after receiving the shutdown from
            // the server
            client_writer
                .write_all("random bytes".as_bytes())
                .await
                .unwrap();
            client_writer.flush().await.unwrap();
            // shutdown
            client_writer.shutdown().await.unwrap()
        }
    );
    // make sure the server shutdown cleanly
    assert!(server.is_ok());
    Ok(())
}

#[tokio::test(start_paused = true)]
async fn shutdown_with_blinding() -> Result<(), Box<dyn std::error::Error>> {
    let clock = common::TokioTime::default();
    let mut server_config = common::server_config()?;
    server_config.set_monotonic_clock(clock)?;

    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(server_config.build()?);

    let (server_stream, client_stream) = common::get_streams().await?;
    let server_stream = common::TestStream::new(server_stream);
    let overrides = server_stream.overrides();
    let (mut client, mut server) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    // Setup a bad record for the next read
    overrides.next_read(Some(Box::new(|_, _, buf| {
        // Parsing the header is one of the blinded operations
        // in s2n_recv, so provide a malformed header.
        let zeroed_header = [23, 0, 0, 0, 0];
        buf.put_slice(&zeroed_header);
        Ok(()).into()
    })));

    // Trigger the blinded error
    let mut received = [0; 1];
    let result = server.read_exact(&mut received).await;
    assert!(result.is_err());

    let time_start = time::Instant::now();
    let result = server.shutdown().await;
    let time_elapsed = time_start.elapsed();

    // Shutdown MUST NOT complete faster than minimal blinding time.
    assert!(time_elapsed > common::MIN_BLINDING_SECS);

    // Server MUST eventually successfully shutdown
    assert!(result.is_ok());

    // Shutdown MUST have sent the close_notify message needed for EOF.
    let mut received = [0; 1];
    assert!(client.read(&mut received).await? == 0);

    Ok(())
}

#[tokio::test(start_paused = true)]
async fn shutdown_with_poll_blinding() -> Result<(), Box<dyn std::error::Error>> {
    let clock = common::TokioTime::default();
    let mut server_config = common::server_config()?;
    server_config.set_monotonic_clock(clock)?;

    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(server_config.build()?);

    let (server_stream, client_stream) = common::get_streams().await?;
    let server_stream = common::TestStream::new(server_stream);
    let overrides = server_stream.overrides();
    let (_, mut server) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    // Setup a bad record for the next read
    overrides.next_read(Some(Box::new(|_, _, buf| {
        // Parsing the header is one of the blinded operations
        // in s2n_recv, so provide a malformed header.
        let zeroed_header = [23, 0, 0, 0, 0];
        buf.put_slice(&zeroed_header);
        Ok(()).into()
    })));

    // Trigger the blinded error
    let mut received = [0; 1];
    let result = server.read_exact(&mut received).await;
    assert!(result.is_err());

    let time_start = time::Instant::now();
    let result = server.apply_blinding().await;
    let time_elapsed = time_start.elapsed();

    // poll_blinding MUST NOT complete faster than minimal blinding time.
    assert!(time_elapsed > common::MIN_BLINDING_SECS);

    // poll_blinding MUST eventually complete
    assert!(result.is_ok());

    Ok(())
}

#[tokio::test]
async fn shutdown_with_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(common::server_config()?.build()?);

    let (server_stream, client_stream) = common::get_streams().await?;
    let server_stream = common::TestStream::new(server_stream);
    let overrides = server_stream.overrides();

    let (_, mut server) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    // The underlying stream should return a unique error on shutdown
    overrides.next_shutdown(Some(Box::new(|_, _| {
        Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
    })));

    // Shutdown should complete with the correct error from the underlying stream
    let result = server.shutdown().await;
    let error = result.unwrap_err().into_inner().unwrap();
    assert!(error.to_string() == common::TEST_STR);

    Ok(())
}

#[tokio::test]
async fn shutdown_with_tls_error_and_tcp_error() -> Result<(), Box<dyn std::error::Error>> {
    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(common::server_config()?.build()?);

    let (server_stream, client_stream) = common::get_streams().await?;
    let server_stream = common::TestStream::new(server_stream);
    let overrides = server_stream.overrides();

    let (_, mut server) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    // Both s2n_shutdown_send and the underlying stream should error on shutdown
    overrides.next_write(Some(Box::new(|_, _, _| {
        Ready(Err(io::Error::from(io::ErrorKind::Other)))
    })));
    overrides.next_shutdown(Some(Box::new(|_, _| {
        Ready(Err(io::Error::new(io::ErrorKind::Other, common::TEST_STR)))
    })));

    // Shutdown should complete with the correct error from s2n_shutdown_send
    let result = server.shutdown().await;
    let io_error = result.unwrap_err();
    let error: error::Error = io_error.try_into()?;
    // Any non-blocking read error is translated as "IOError"
    assert!(error.kind() == error::ErrorType::IOError);

    // Even if s2n_shutdown_send fails, we need to close the underlying stream.
    // Make sure we called our mock shutdown, consuming it.
    assert!(overrides.is_consumed());

    Ok(())
}

#[tokio::test]
async fn shutdown_with_tls_error_and_tcp_delay() -> Result<(), Box<dyn std::error::Error>> {
    let client = TlsConnector::new(common::client_config()?.build()?);
    let server = TlsAcceptor::new(common::server_config()?.build()?);

    let (server_stream, client_stream) = common::get_streams().await?;
    let server_stream = common::TestStream::new(server_stream);
    let overrides = server_stream.overrides();

    let (mut client, mut server) =
        common::run_negotiate(&client, client_stream, &server, server_stream).await?;

    // We want s2n_shutdown_send to produce an error on write
    overrides.next_write(Some(Box::new(|_, _, _| {
        Ready(Err(io::Error::from(io::ErrorKind::Other)))
    })));

    // The underlying stream should initially return Pending, delaying shutdown
    overrides.next_shutdown(Some(Box::new(|_, ctx| {
        ctx.waker().wake_by_ref();
        Pending
    })));

    // Shutdown should complete with the correct error from s2n_shutdown_send
    let result = server.shutdown().await;
    let io_error = result.unwrap_err();
    let error: error::Error = io_error.try_into()?;
    // Any non-blocking read error is translated as "IOError"
    assert!(error.kind() == error::ErrorType::IOError);

    // Even if s2n_shutdown_send fails, we need to close the underlying stream.
    // Make sure we at least called our mock shutdown, consuming it.
    assert!(overrides.is_consumed());

    // Since s2n_shutdown_send failed, we should NOT have sent a close_notify.
    // Make sure the peer doesn't receive a close_notify.
    // If this is not true, then we're incorrectly calling s2n_shutdown_send
    // again after an error.
    let mut received = [0; 1];
    let io_error = client.read(&mut received).await.unwrap_err();
    let error: error::Error = io_error.try_into()?;
    assert!(error.kind() == error::ErrorType::ConnectionClosed);

    Ok(())
}