use std::collections::VecDeque; use std::pin::Pin; use std::task::{Context, Poll, Waker}; use futures_util::{Sink, Stream}; use pin_project::pin_project; #[pin_project] #[derive(Debug)] pub struct SimpleChannel { queue: VecDeque, waker: Option, closed: bool, } impl SimpleChannel { pub fn new() -> Self { SimpleChannel { queue: VecDeque::new(), waker: None, closed: false, } } } impl Default for SimpleChannel { fn default() -> Self { Self::new() } } impl Stream for SimpleChannel { type Item = T; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.project(); match this.queue.pop_front() { Some(item) => Poll::Ready(Some(item)), None if *this.closed => Poll::Ready(None), None => { *this.waker = Some(cx.waker().clone()); Poll::Pending } } } fn size_hint(&self) -> (usize, Option) { (self.queue.len(), None) } } impl Sink for SimpleChannel { type Error = (); fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { let this = self.project(); this.queue.push_back(item); if let Some(waker) = this.waker.take() { waker.wake(); } Ok(()) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { let this = self.project(); *this.closed = true; if let Some(waker) = this.waker.take() { waker.wake(); } Poll::Ready(Ok(())) } } #[cfg(test)] mod tests { use futures_util::future::join; use futures_util::stream::iter; use futures_util::{SinkExt, StreamExt}; use super::*; #[tokio::test] async fn test_write_then_read() { let channel = SimpleChannel::::new(); let (mut sink, mut stream) = channel.split(); send_many(&mut sink, vec![1, 2, 3]).await.unwrap(); assert_eq!(stream.next().await.unwrap(), 1); assert_eq!(stream.next().await.unwrap(), 2); send_many(&mut sink, vec![4, 5]).await.unwrap(); assert_eq!(stream.next().await.unwrap(), 3); assert_eq!(stream.next().await.unwrap(), 4); assert_eq!(stream.next().await.unwrap(), 5); sink.close().await.unwrap(); assert_eq!(stream.next().await, None); } #[tokio::test] async fn test_read_then_write() { let channel = SimpleChannel::::new(); let (mut sink, mut stream) = channel.split(); join( async { assert_eq!(stream.next().await.unwrap(), 1); assert_eq!(stream.next().await.unwrap(), 2); assert_eq!(stream.next().await.unwrap(), 3); }, async { send_many(&mut sink, vec![1, 2, 3, 4]).await.unwrap(); }, ) .await; } #[tokio::test] async fn test_read_then_close() { let channel = SimpleChannel::::new(); let (mut sink, mut stream) = channel.split(); join( async { assert_eq!(stream.next().await, None); }, async { sink.close().await.unwrap(); }, ) .await; } #[tokio::test] async fn test_close_then_read() { let channel = SimpleChannel::::new(); let (mut sink, stream) = channel.split(); send_many(&mut sink, vec![1, 2, 3]).await.unwrap(); sink.close().await.unwrap(); // should still read items from queue assert_eq!(stream.collect::>().await, vec![1, 2, 3]); } async fn send_many( sink: &mut Si, values: impl IntoIterator, ) -> Result<(), Si::Error> where Si: Sink + Unpin, { sink.send_all(&mut iter(values).map(Ok)).await } }