use futures::future::poll_fn; use std::{ io::IoSlice, pin::Pin, task::{Context, Poll}, }; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio_util::io::{InspectReader, InspectWriter}; /// An AsyncRead implementation that works byte-by-byte, to catch out callers /// who don't allow for `buf` being part-filled before the call struct SmallReader { contents: Vec, } impl Unpin for SmallReader {} impl AsyncRead for SmallReader { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { if let Some(byte) = self.contents.pop() { buf.put_slice(&[byte]) } Poll::Ready(Ok(())) } } #[tokio::test] async fn read_tee() { let contents = b"This could be really long, you know".to_vec(); let reader = SmallReader { contents: contents.clone(), }; let mut altout: Vec = Vec::new(); let mut teeout = Vec::new(); { let mut tee = InspectReader::new(reader, |bytes| altout.extend(bytes)); tee.read_to_end(&mut teeout).await.unwrap(); } assert_eq!(teeout, altout); assert_eq!(altout.len(), contents.len()); } /// An AsyncWrite implementation that works byte-by-byte for poll_write, and /// that reads the whole of the first buffer plus one byte from the second in /// poll_write_vectored. /// /// This is designed to catch bugs in handling partially written buffers #[derive(Debug)] struct SmallWriter { contents: Vec, } impl Unpin for SmallWriter {} impl AsyncWrite for SmallWriter { fn poll_write( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { // Just write one byte at a time if buf.is_empty() { return Poll::Ready(Ok(0)); } self.contents.push(buf[0]); Poll::Ready(Ok(1)) } fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn poll_shutdown( self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll> { Poll::Ready(Ok(())) } fn poll_write_vectored( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll> { // Write all of the first buffer, then one byte from the second buffer // This should trip up anything that doesn't correctly handle multiple // buffers. if bufs.is_empty() { return Poll::Ready(Ok(0)); } let mut written_len = bufs[0].len(); self.contents.extend_from_slice(&bufs[0]); if bufs.len() > 1 { let buf = bufs[1]; if !buf.is_empty() { written_len += 1; self.contents.push(buf[0]); } } Poll::Ready(Ok(written_len)) } fn is_write_vectored(&self) -> bool { true } } #[tokio::test] async fn write_tee() { let mut altout: Vec = Vec::new(); let mut writeout = SmallWriter { contents: Vec::new(), }; { let mut tee = InspectWriter::new(&mut writeout, |bytes| altout.extend(bytes)); tee.write_all(b"A testing string, very testing") .await .unwrap(); } assert_eq!(altout, writeout.contents); } // This is inefficient, but works well enough for test use. // If you want something similar for real code, you'll want to avoid all the // fun of manipulating `bufs` - ideally, by the time you read this, // IoSlice::advance_slices will be stable, and you can use that. async fn write_all_vectored( mut writer: W, mut bufs: Vec>, ) -> Result { let mut res = 0; while !bufs.is_empty() { let mut written = poll_fn(|cx| { let bufs: Vec = bufs.iter().map(|v| IoSlice::new(v)).collect(); Pin::new(&mut writer).poll_write_vectored(cx, &bufs) }) .await?; res += written; while written > 0 { let buf_len = bufs[0].len(); if buf_len <= written { bufs.remove(0); written -= buf_len; } else { let buf = &mut bufs[0]; let drain_len = written.min(buf.len()); buf.drain(..drain_len); written -= drain_len; } } } Ok(res) } #[tokio::test] async fn write_tee_vectored() { let mut altout: Vec = Vec::new(); let mut writeout = SmallWriter { contents: Vec::new(), }; let original = b"A very long string split up"; let bufs: Vec> = original .split(|b| b.is_ascii_whitespace()) .map(Vec::from) .collect(); assert!(bufs.len() > 1); let expected: Vec = { let mut out = Vec::new(); for item in &bufs { out.extend_from_slice(item) } out }; { let mut bufcount = 0; let tee = InspectWriter::new(&mut writeout, |bytes| { bufcount += 1; altout.extend(bytes) }); assert!(tee.is_write_vectored()); write_all_vectored(tee, bufs.clone()).await.unwrap(); assert!(bufcount >= bufs.len()); } assert_eq!(altout, writeout.contents); assert_eq!(writeout.contents, expected); }