// SPDX-License-Identifier: MIT use std::{ future::Future, io, pin::Pin, task::{Context, Poll}, }; use crate::{AsyncSocket, SocketAddr}; /// Support trait for [`AsyncSocket`] /// /// Provides awaitable variants of the poll functions from [`AsyncSocket`]. pub trait AsyncSocketExt: AsyncSocket { /// `async fn send(&mut self, buf: &[u8]) -> io::Result` fn send<'a, 'b>(&'a mut self, buf: &'b [u8]) -> PollSend<'a, 'b, Self> { PollSend { socket: self, buf } } /// `async fn send(&mut self, buf: &[u8]) -> io::Result` fn send_to<'a, 'b>( &'a mut self, buf: &'b [u8], addr: &'b SocketAddr, ) -> PollSendTo<'a, 'b, Self> { PollSendTo { socket: self, buf, addr, } } /// `async fn recv(&mut self, buf: &mut [u8]) -> io::Result<()>` fn recv<'a, 'b, B>( &'a mut self, buf: &'b mut B, ) -> PollRecv<'a, 'b, Self, B> where B: bytes::BufMut, { PollRecv { socket: self, buf } } /// `async fn recv(&mut self, buf: &mut [u8]) -> io::Result` fn recv_from<'a, 'b, B>( &'a mut self, buf: &'b mut B, ) -> PollRecvFrom<'a, 'b, Self, B> where B: bytes::BufMut, { PollRecvFrom { socket: self, buf } } /// `async fn recrecv_from_full(&mut self) -> io::Result<(Vec, /// SocketAddr)>` fn recv_from_full(&mut self) -> PollRecvFromFull<'_, Self> { PollRecvFromFull { socket: self } } } impl AsyncSocketExt for S {} pub struct PollSend<'a, 'b, S> { socket: &'a mut S, buf: &'b [u8], } impl Future for PollSend<'_, '_, S> where S: AsyncSocket, { type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this: &mut Self = Pin::into_inner(self); this.socket.poll_send(cx, this.buf) } } pub struct PollSendTo<'a, 'b, S> { socket: &'a mut S, buf: &'b [u8], addr: &'b SocketAddr, } impl Future for PollSendTo<'_, '_, S> where S: AsyncSocket, { type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this: &mut Self = Pin::into_inner(self); this.socket.poll_send_to(cx, this.buf, this.addr) } } pub struct PollRecv<'a, 'b, S, B> { socket: &'a mut S, buf: &'b mut B, } impl Future for PollRecv<'_, '_, S, B> where S: AsyncSocket, B: bytes::BufMut, { type Output = io::Result<()>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this: &mut Self = Pin::into_inner(self); this.socket.poll_recv(cx, this.buf) } } pub struct PollRecvFrom<'a, 'b, S, B> { socket: &'a mut S, buf: &'b mut B, } impl Future for PollRecvFrom<'_, '_, S, B> where S: AsyncSocket, B: bytes::BufMut, { type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this: &mut Self = Pin::into_inner(self); this.socket.poll_recv_from(cx, this.buf) } } pub struct PollRecvFromFull<'a, S> { socket: &'a mut S, } impl Future for PollRecvFromFull<'_, S> where S: AsyncSocket, { type Output = io::Result<(Vec, SocketAddr)>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this: &mut Self = Pin::into_inner(self); this.socket.poll_recv_from_full(cx) } }