// SPDX-License-Identifier: MIT

use std::{
    future::Future,
    io,
    pin::Pin,
    task::{Context, Poll},
};

use crate::{AsyncSocket, SocketAddr};

/// Support trait for [`AsyncSocket`]
///
/// Provides awaitable variants of the poll functions from [`AsyncSocket`].
/// Every method only builds a future that mutably borrows the socket (and the
/// given buffers); nothing is sent or received until that future is `.await`ed
/// or polled.
pub trait AsyncSocketExt: AsyncSocket {
    /// `async fn send(&mut self, buf: &[u8]) -> io::Result<usize>`
    ///
    /// Resolves to the result of [`AsyncSocket::poll_send`].
    fn send<'a, 'b>(&'a mut self, buf: &'b [u8]) -> PollSend<'a, 'b, Self> {
        PollSend { socket: self, buf }
    }

    /// `async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> io::Result<usize>`
    ///
    /// Resolves to the result of [`AsyncSocket::poll_send_to`].
    fn send_to<'a, 'b>(
        &'a mut self,
        buf: &'b [u8],
        addr: &'b SocketAddr,
    ) -> PollSendTo<'a, 'b, Self> {
        PollSendTo {
            socket: self,
            buf,
            addr,
        }
    }

    /// `async fn recv<B: BufMut>(&mut self, buf: &mut B) -> io::Result<()>`
    ///
    /// Resolves to the result of [`AsyncSocket::poll_recv`].
    fn recv<'a, 'b, B>(&'a mut self, buf: &'b mut B) -> PollRecv<'a, 'b, Self, B>
    where
        B: bytes::BufMut,
    {
        PollRecv { socket: self, buf }
    }

    /// `async fn recv_from<B: BufMut>(&mut self, buf: &mut B) -> io::Result<SocketAddr>`
    ///
    /// Resolves to the result of [`AsyncSocket::poll_recv_from`].
    fn recv_from<'a, 'b, B>(&'a mut self, buf: &'b mut B) -> PollRecvFrom<'a, 'b, Self, B>
    where
        B: bytes::BufMut,
    {
        PollRecvFrom { socket: self, buf }
    }

    /// `async fn recv_from_full(&mut self) -> io::Result<(Vec<u8>, SocketAddr)>`
    ///
    /// Resolves to the result of [`AsyncSocket::poll_recv_from_full`].
    fn recv_from_full(&mut self) -> PollRecvFromFull<'_, Self> {
        PollRecvFromFull { socket: self }
    }
}

// Blanket impl: every `AsyncSocket` gets the awaitable helpers for free.
impl<S: AsyncSocket> AsyncSocketExt for S {}

/// Future returned by [`AsyncSocketExt::send`].
///
/// Resolves to the output of [`AsyncSocket::poll_send`] on the borrowed socket.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PollSend<'a, 'b, S> {
    socket: &'a mut S,
    buf: &'b [u8],
}

impl<S> Future for PollSend<'_, '_, S>
where
    S: AsyncSocket,
{
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The future only holds references, so it is `Unpin` and can be
        // unwrapped from the `Pin` without any unsafe projection.
        let this = self.get_mut();
        this.socket.poll_send(cx, this.buf)
    }
}

/// Future returned by [`AsyncSocketExt::send_to`].
///
/// Resolves to the output of [`AsyncSocket::poll_send_to`] on the borrowed
/// socket, sending `buf` to `addr`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PollSendTo<'a, 'b, S> {
    socket: &'a mut S,
    buf: &'b [u8],
    addr: &'b SocketAddr,
}

impl<S> Future for PollSendTo<'_, '_, S>
where
    S: AsyncSocket,
{
    type Output = io::Result<usize>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The future only holds references, so it is `Unpin` and can be
        // unwrapped from the `Pin` without any unsafe projection.
        let this = self.get_mut();
        this.socket.poll_send_to(cx, this.buf, this.addr)
    }
}

/// Future returned by [`AsyncSocketExt::recv`].
///
/// Resolves to the output of [`AsyncSocket::poll_recv`], which is given the
/// borrowed `BufMut` buffer to fill.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PollRecv<'a, 'b, S, B> {
    socket: &'a mut S,
    buf: &'b mut B,
}

impl<S, B> Future for PollRecv<'_, '_, S, B>
where
    S: AsyncSocket,
    B: bytes::BufMut,
{
    type Output = io::Result<()>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The future only holds references, so it is `Unpin` and can be
        // unwrapped from the `Pin` without any unsafe projection.
        let this = self.get_mut();
        this.socket.poll_recv(cx, this.buf)
    }
}

/// Future returned by [`AsyncSocketExt::recv_from`].
///
/// Resolves to the output of [`AsyncSocket::poll_recv_from`], which is given
/// the borrowed `BufMut` buffer and yields the sender's [`SocketAddr`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PollRecvFrom<'a, 'b, S, B> {
    socket: &'a mut S,
    buf: &'b mut B,
}

impl<S, B> Future for PollRecvFrom<'_, '_, S, B>
where
    S: AsyncSocket,
    B: bytes::BufMut,
{
    type Output = io::Result<SocketAddr>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The future only holds references, so it is `Unpin` and can be
        // unwrapped from the `Pin` without any unsafe projection.
        let this = self.get_mut();
        this.socket.poll_recv_from(cx, this.buf)
    }
}

/// Future returned by [`AsyncSocketExt::recv_from_full`].
///
/// Resolves to the output of [`AsyncSocket::poll_recv_from_full`]: an owned
/// `Vec<u8>` together with the sender's [`SocketAddr`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PollRecvFromFull<'a, S> {
    socket: &'a mut S,
}

impl<S> Future for PollRecvFromFull<'_, S>
where
    S: AsyncSocket,
{
    type Output = io::Result<(Vec<u8>, SocketAddr)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The future only holds a reference, so it is `Unpin` and can be
        // unwrapped from the `Pin` without any unsafe projection.
        let this = self.get_mut();
        this.socket.poll_recv_from_full(cx)
    }
}