// Framing for Connect streaming messages. Every frame on the wire is a
// `FrameHeader::LEN`-byte header (one flag byte, then the payload length as a
// big-endian u32) followed by the payload bytes. See `FrameHeader` below.
//
// NOTE(review): as visible here, generic arguments are missing throughout
// (`Frame` uses `T` without declaring it; `Result`, `Option`, `PhantomData`,
// `Stream>` and `bytes::buf::Writer` carry no type arguments). This looks like
// lost angle-bracket text rather than intentional code; confirm against the
// canonical source.
use std::marker::PhantomData;

use anyhow::Context;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_core::{Stream, TryStream};
use futures_util::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use tonic::metadata::MetadataMap;

use crate::{
    codec::{ConnectDecoder, ConnectEncoder, EncodeBuf},
    error::Error,
    status::{ConnectCode, ConnectStatus},
};

/// A single frame of a Connect stream: either a data message or the
/// terminating end-of-stream frame.
#[derive(Clone, Debug)]
pub enum Frame {
    /// A message payload, written by a `ConnectEncoder` / read by a `ConnectDecoder`.
    Data(T),
    /// The end-of-stream frame, carried as JSON (see `EndFrame`).
    End(EndFrame),
}

impl Frame {
    /// Serializes this frame into a complete, header-prefixed byte buffer.
    ///
    /// `Data` payloads are written by `encoder` into a `FrameBuf` with the end
    /// and compressed flags both cleared; `End` frames are delegated to
    /// `EndFrame::to_bytes`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `encoder.encode` / `EndFrame::to_bytes`, and
    /// fails if the encoded payload length does not fit in the header's u32
    /// length field.
    pub fn encode(self, encoder: &mut impl ConnectEncoder) -> Result {
        match self {
            Frame::Data(data) => {
                // Reserve room for the header plus the encoder's size estimate up front.
                let size_hint = encoder.size_hint(&data);
                let mut buf = FrameBuf::new(false, false, size_hint);
                // `&mut *buf` goes through `FrameBuf`'s `DerefMut` to the payload buffer.
                encoder.encode(data, &mut *buf)?;
                buf.into_bytes()
            }
            Frame::End(end_frame) => end_frame.to_bytes(),
        }
    }

    /// Parses one frame whose header has already been decoded.
    ///
    /// `buf` must start at the first payload byte (i.e. just past the header).
    /// If the header's end flag is set, the payload is deserialized as a JSON
    /// `EndFrame`; otherwise it is handed to `decoder`.
    ///
    /// # Errors
    ///
    /// Returns `Error::NeedMoreData` (via `FrameHeader::take_data`) if `buf`
    /// holds fewer bytes than the header's declared length, and propagates
    /// JSON / decoder errors.
    ///
    /// NOTE(review): `header.is_compressed()` is not consulted here, so a
    /// compressed payload would reach the decoder as-is. Confirm that
    /// decompression is handled elsewhere.
    pub fn decode(
        header: &FrameHeader,
        buf: impl Buf,
        decoder: &mut impl ConnectDecoder,
    ) -> Result {
        let data = header.take_data(buf)?;
        Ok(if header.is_end() {
            Self::End(serde_json::from_reader(data.reader())?)
        } else {
            Self::Data(decoder.decode(data)?)
        })
    }

    /// Encodes every frame of `stream` with a shared `encoder`.
    ///
    /// Each item is a result: errors are passed through unchanged, and
    /// successful frames are replaced by the output of `Frame::encode`.
    pub fn encode_stream<'a>(
        stream: impl Stream> + 'a,
        mut encoder: impl ConnectEncoder + 'a,
    ) -> impl Stream> + 'a {
        stream.map(move |res| res.and_then(|frame| frame.encode(&mut encoder)))
    }

    /// Apparently intended to turn a stream of byte chunks into a stream of
    /// decoded frames.
    ///
    /// NOTE(review): this function is incomplete and does not compile as written:
    /// - `try_f` is not a `StreamExt`/`TryStreamExt` method;
    /// - `bytes` is not `mut`, may already have been moved into `buf`, and is
    ///   parsed instead of the accumulated `buf` (presumably `buf` was meant);
    /// - the `Ok(header)` binding shadows the outer `header`, so
    ///   `header = Some(header)` targets the inner binding (a type mismatch)
    ///   rather than the outer `Option`;
    /// - the closure mixes `return None` with `return Err(..)`/`?` and never
    ///   produces a frame: `decoder` and `Frame::decode` are unused, and no
    ///   consumed bytes are removed from `buf`.
    pub fn decode_stream(
        stream: impl Stream>,
        mut decoder: impl ConnectDecoder,
    ) -> impl Stream, Error>> {
        // Presumably: header of the frame currently being assembled, once received.
        let mut header: Option = None;
        // Bytes received from `stream` so far.
        let mut buf = BytesMut::new();
        stream.try_f(|res| {
            let bytes = res?;
            // Take over the incoming chunk when nothing is buffered; otherwise
            // append it to the pending bytes.
            if buf.is_empty() {
                buf = bytes.into();
            } else {
                buf.extend_from_slice(&bytes);
            }
            if header.is_none() {
                match FrameHeader::decode(&mut bytes) {
                    Ok(header) => {
                        header = Some(header);
                    }
                    Err(Error::NeedMoreData(_)) => {
                        return None;
                    }
                    Err(err) => {
                        return Err(err)
                    }
                }
            }
        })
    }
}

// https://connectrpc.com/docs/protocol/#error-end-stream
/// The final frame of a Connect stream, serialized as JSON.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct EndFrame {
    /// Error status for the stream; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option,
    /// Trailing metadata; omitted from the JSON when empty, and
    /// (de)serialized through `crate::metadata::serde`.
    #[serde(
        default,
        skip_serializing_if = "MetadataMap::is_empty",
        with = "crate::metadata::serde"
    )]
    pub metadata: MetadataMap,
}

impl EndFrame {
    /// An end frame reporting success (no error) with the given trailing metadata.
    pub fn success(metadata: MetadataMap) -> Self {
        Self {
            error: None,
            metadata,
        }
    }

    /// An end frame built from anything convertible into a status.
    ///
    /// If the conversion fails, the conversion error itself is turned into a
    /// status via `Into`. The `error` field is populated only when the
    /// resulting code is not `ConnectCode::Ok`; `metadata` is left at its
    /// default value.
    pub fn status(status: impl TryInto) -> Self {
        let status = status.try_into().unwrap_or_else(Into::into);
        let error = (status.code != ConnectCode::Ok).then_some(status);
        Self {
            error,
            metadata: Default::default(),
        }
    }

    /// Serializes this end frame as JSON into a complete frame with the end
    /// flag set and the compressed flag cleared.
    ///
    /// # Errors
    ///
    /// Propagates JSON serialization errors, and fails if the payload length
    /// does not fit in the header's u32 length field.
    pub fn to_bytes(&self) -> Result {
        // Size hint of 2: presumably room for the minimal `{}` body produced
        // when both fields are skipped.
        let mut buf = FrameBuf::new(true, false, 2);
        serde_json::to_writer(&mut buf.writer(), self)?;
        buf.into_bytes()
    }
}

/// Wraps an `EndFrame` as `Frame::End`.
///
/// NOTE(review): implementing `From<EndFrame> for Frame` instead is the
/// idiomatic form (it provides `Into` automatically; see clippy's
/// `from_over_into` lint).
impl Into> for EndFrame {
    fn into(self) -> Frame {
        Frame::End(self)
    }
}

/// Holds an `inner` value (presumably a stream) together with an encoder.
///
/// NOTE(review): no constructor or `Stream` impl for this type is visible
/// here; presumably it adapts a stream of frames into a stream of bytes.
pub struct FrameBytesStream {
    inner: T,
    encoder: U,
    _phantom: PhantomData,
}

/// Holds an `inner` value (presumably a stream) together with a decoder.
///
/// NOTE(review): no constructor or `Stream` impl for this type is visible
/// here; presumably it adapts a stream of bytes into a stream of frames.
pub struct BytesFrameStream {
    inner: T,
    decoder: U,
}

/// A buffer for building one frame: a reserved header region followed by the
/// payload, carved from a single allocation so the header can usually be
/// prepended without copying the payload.
pub struct FrameBuf {
    /// Value for the header's end-of-stream flag.
    end: bool,
    /// Value for the header's compressed flag.
    compressed: bool,
    /// Reserved space for the `FrameHeader::LEN`-byte header.
    header_buf: BytesMut,
    /// Payload being written; exposed through `Deref`/`DerefMut`.
    encode_buf: EncodeBuf,
}

impl FrameBuf {
    /// Creates a frame buffer with the given flags and initial room for
    /// `size_hint` payload bytes after the header.
    pub fn new(end: bool, compressed: bool, size_hint: usize) -> Self {
        let mut buf = BytesMut::with_capacity(FrameHeader::LEN + size_hint);
        // Split the (still empty) allocation by capacity: `buf` keeps the
        // first `LEN` bytes for the header, the remainder becomes the payload.
        let encode_buf = buf.split_off(FrameHeader::LEN).into();
        Self {
            end,
            compressed,
            header_buf: buf,
            encode_buf,
        }
    }

    /// Writes the header in front of the payload and returns the finished frame.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidStreamFrame` (via `FrameHeader::with_usize_len`)
    /// if the payload is longer than `u32::MAX` bytes.
    pub fn into_bytes(self) -> Result {
        let Self {
            end,
            compressed,
            mut header_buf,
            encode_buf,
        } = self;
        let data_buf = encode_buf.into_inner();
        FrameHeader::with_usize_len(end, compressed, data_buf.len())?.encode(&mut header_buf);
        // `unsplit` is O(1) when `data_buf` is still contiguous with
        // `header_buf` (e.g. it was not reallocated while encoding);
        // otherwise it falls back to copying.
        header_buf.unsplit(data_buf);
        Ok(header_buf.freeze())
    }
}

/// Exposes the payload buffer so encoders can write into a `FrameBuf` directly.
impl std::ops::Deref for FrameBuf {
    type Target = EncodeBuf;
    fn deref(&self) -> &Self::Target {
        &self.encode_buf
    }
}

/// Mutable access to the payload buffer (used as `&mut *buf` in `Frame::encode`).
impl std::ops::DerefMut for FrameBuf {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.encode_buf
    }
}

/// A `bytes::buf::Writer` paired with the `FrameBuf` it mutably borrows.
///
/// NOTE(review): `frame_buf` is never read and no constructor is visible
/// here; confirm whether this type is still needed.
pub struct FrameBufWriter<'a> {
    frame_buf: &'a mut FrameBuf,
    writer: bytes::buf::Writer,
}

impl<'a> std::ops::Deref for FrameBufWriter<'a> {
    type Target = bytes::buf::Writer;
    fn deref(&self) -> &Self::Target {
        &self.writer
    }
}

impl<'a> std::ops::DerefMut for FrameBufWriter<'a> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.writer
    }
}

/// The fixed-size header that precedes every frame payload.
///
/// Wire layout (`LEN` = 5 bytes): one flag byte, then the payload length as a
/// big-endian `u32`.
pub struct FrameHeader {
    /// Bit 0: compressed; bit 1: end of stream; other bits reserved.
    flags: u8,
    /// Length in bytes of the payload that follows the header.
    data_len: u32,
}

impl FrameHeader {
    /// Encoded header size: 1 flag byte + 4 length bytes.
    pub const LEN: usize = 5;

    /// Builds a header from its two flags and the payload length.
    pub fn new(end: bool, compressed: bool, data_len: u32) -> Self {
        // Flag byte, most to least significant: six reserved bits, then the
        // end bit (0b10), then the compressed bit (0b01). The two bits never
        // overlap, so `+` here behaves like `|`.
        let flags = ((end as u8) << 1) + (compressed as u8);
        Self { flags, data_len }
    }

    /// Like `new`, but takes the payload length as a `usize`.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidStreamFrame` (context "too large") if
    /// `data_len` does not fit in a `u32`.
    pub fn with_usize_len(end: bool, compressed: bool, data_len: usize) -> Result {
        let data_len = data_len
            .try_into()
            .context("too large")
            .map_err(Error::InvalidStreamFrame)?;
        Ok(Self::new(end, compressed, data_len))
    }

    /// Reads a header from the front of `buf`.
    ///
    /// On success `LEN` bytes are consumed from `buf`; because `buf` is taken
    /// by value, the caller only sees that advance when passing a `&mut`
    /// buffer. Reserved flag bits are not validated.
    ///
    /// # Errors
    ///
    /// Returns `Error::NeedMoreData(n)`, where `n` is the number of missing
    /// bytes, if fewer than `LEN` bytes are available. Nothing is consumed in
    /// that case.
    pub fn decode(mut buf: impl Buf) -> Result {
        if Self::LEN > buf.remaining() {
            return Err(Error::NeedMoreData(Self::LEN - buf.remaining()));
        }
        let flags = buf.get_u8();
        let len = buf.get_u32();
        Ok(Self {
            flags,
            data_len: len,
        })
    }

    /// Whether the end-of-stream flag (bit 1) is set.
    pub fn is_end(&self) -> bool {
        (self.flags & 0b10) != 0
    }

    /// Whether the compressed flag (bit 0) is set.
    pub fn is_compressed(&self) -> bool {
        (self.flags & 0b1) != 0
    }

    /// Returns a view limited to the first `data_len` bytes of `buf`, i.e.
    /// this header's payload. `buf` must already be positioned past the header.
    ///
    /// # Errors
    ///
    /// Returns `Error::Other` if `data_len` does not fit in `usize`, and
    /// `Error::NeedMoreData(n)`, where `n` is the number of missing bytes, if
    /// `buf` holds fewer than `data_len` bytes.
    pub fn take_data(&self, buf: impl Buf) -> Result {
        let data_len = self
            .data_len
            .try_into()
            .context("frame data len > usize")
            .map_err(Error::Other)?;
        if data_len > buf.remaining() {
            return Err(Error::NeedMoreData(data_len - buf.remaining()));
        }
        Ok(buf.take(data_len))
    }

    /// Writes the header to `dst`: the flag byte, then the big-endian length.
    pub fn encode(&self, mut dst: impl BufMut) {
        dst.put_u8(self.flags);
        dst.put_u32(self.data_len);
    }
}