// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{
    io,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};
use tokio::{
    io::{AsyncRead, AsyncWrite, ReadBuf},
    net::TcpStream,
};

/// Replaces a single `poll_read` call: receives the pinned inner stream together with the
/// caller's task context and read buffer.
///
/// The `Send` bound is what makes [`Overrides`] safely shareable across threads; without
/// it a closure capturing thread-local data (e.g. an `Rc`) could be moved to another thread.
type ReadFn = Box<
    dyn Fn(Pin<&mut TcpStream>, &mut Context, &mut ReadBuf) -> Poll<io::Result<()>> + Send,
>;
/// Replaces a single `poll_write` call: receives the pinned inner stream together with the
/// caller's task context and the bytes to write.
type WriteFn =
    Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &[u8]) -> Poll<io::Result<usize>> + Send>;
/// Replaces a single `poll_shutdown` call: receives the pinned inner stream together with
/// the caller's task context.
type ShutdownFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context) -> Poll<io::Result<()>> + Send>;

/// The pending one-shot overrides; `None` means "no override installed".
#[derive(Default)]
struct OverrideMethods {
    next_read: Option<ReadFn>,
    next_write: Option<WriteFn>,
    next_shutdown: Option<ShutdownFn>,
}

/// One-shot replacements for the read, write and shutdown methods of a [`TestStream`].
///
/// An installed override is taken out of this table by the next matching poll call on
/// the stream; subsequent calls go to the underlying `TcpStream` again until a new
/// override is installed.
#[derive(Default)]
pub struct Overrides(Mutex<OverrideMethods>);

impl Overrides {
    /// Locks the override table, recovering the guard if the mutex was poisoned.
    ///
    /// Every field is a plain `Option`, so a panic while the lock was held cannot leave
    /// the table half-updated. Recovering is preferable to silently ignoring a setter,
    /// which would make a test exercise the real stream instead of the injected behavior.
    fn methods(&self) -> std::sync::MutexGuard<'_, OverrideMethods> {
        self.0.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Installs (or, with `None`, clears) the override for the next `poll_read` call.
    pub fn next_read(&self, input: Option<ReadFn>) {
        self.methods().next_read = input;
    }

    /// Installs (or, with `None`, clears) the override for the next `poll_write` call.
    pub fn next_write(&self, input: Option<WriteFn>) {
        self.methods().next_write = input;
    }

    /// Installs (or, with `None`, clears) the override for the next `poll_shutdown` call.
    pub fn next_shutdown(&self, input: Option<ShutdownFn>) {
        self.methods().next_shutdown = input;
    }

    /// Returns `true` when no read, write or shutdown override is still pending.
    pub fn is_consumed(&self) -> bool {
        let methods = self.methods();
        methods.next_read.is_none()
            && methods.next_write.is_none()
            && methods.next_shutdown.is_none()
    }
}

// `Overrides` is `Send + Sync` automatically: `Mutex<T>` is both whenever `T: Send`, and
// every stored closure is `Send`. This assertion replaces the former `unsafe impl`s and
// fails to compile if that guarantee is ever lost.
const _: fn() = || {
    fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<Overrides>();
};

pub struct TestStream {
    stream: TcpStream,
    overrides: Arc<Overrides>,
}

impl TestStream {
    pub fn new(stream: TcpStream) -> Self {
        let overrides = Arc::new(Overrides::default());
        Self { stream, overrides }
    }

    pub fn overrides(&self) -> Arc<Overrides> {
        self.overrides.clone()
    }
}

impl AsyncRead for TestStream {
    /// Runs the pending read override, if any, in place of the real read.
    ///
    /// The override is removed before it runs, so it affects exactly one call. Without an
    /// override, the call is forwarded to the wrapped `TcpStream`.
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        // Take the override in its own statement so the lock guard is released before the
        // override (arbitrary test code) runs. A poisoned lock is recovered rather than
        // treated as "no override", so an installed override is never silently skipped.
        let action = this
            .overrides
            .0
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .next_read
            .take();
        let stream = Pin::new(&mut this.stream);
        match action {
            Some(f) => f(stream, ctx, buf),
            None => stream.poll_read(ctx, buf),
        }
    }
}

impl AsyncWrite for TestStream {
    /// Runs the pending write override, if any, in place of the real write.
    ///
    /// The override is removed before it runs, so it affects exactly one call. Without an
    /// override, the call is forwarded to the wrapped `TcpStream`.
    fn poll_write(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        // Guard is dropped at the end of this statement, before the override runs; a
        // poisoned lock is recovered so a pending override is never silently skipped.
        let action = this
            .overrides
            .0
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .next_write
            .take();
        let stream = Pin::new(&mut this.stream);
        match action {
            Some(f) => f(stream, ctx, buf),
            None => stream.poll_write(ctx, buf),
        }
    }

    /// Always forwarded to the wrapped `TcpStream`; flushing cannot be overridden.
    fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.stream).poll_flush(ctx)
    }

    /// Runs the pending shutdown override, if any, in place of the real shutdown.
    ///
    /// The override is removed before it runs, so it affects exactly one call. Without an
    /// override, the call is forwarded to the wrapped `TcpStream`.
    fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        // Same locking discipline as `poll_write`: release the guard before running test
        // code, and recover a poisoned lock instead of ignoring the override.
        let action = this
            .overrides
            .0
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .next_shutdown
            .take();
        let stream = Pin::new(&mut this.stream);
        match action {
            Some(f) => f(stream, ctx),
            None => stream.poll_shutdown(ctx),
        }
    }
}