use anyhow::Context; use futures_core::Stream; use futures_util::StreamExt; use http::{header::CONTENT_TYPE, Extensions, HeaderMap, StatusCode}; use reqwest::{Client, IntoUrl, Url}; use tonic::{metadata::MetadataMap, Request, Response}; use crate::{ codec::{ConnectCodec, ConnectDecoder, ConnectEncoder}, error::Error, request::RequestEnvelope, status::ConnectStatus, stream::Frame, }; use super::ConnectChannel; pub struct ReqwestChannel { endpoint: Url, client: Client, } impl ReqwestChannel { pub fn new(endpoint: impl IntoUrl) -> reqwest::Result { Self::with_client(endpoint, Default::default()) } pub fn with_client(endpoint: impl IntoUrl, client: Client) -> reqwest::Result { Ok(Self { endpoint: endpoint.into_url()?, client, }) } fn method_url(&self, method: &str) -> Result { self.endpoint .join(method) .context("invalid method") .map_err(Error::ChannelError) } fn request_builder( &self, method: &str, envelope: RequestEnvelope, codec: &impl ConnectCodec, ) -> Result { let RequestEnvelope { metadata, is_streaming, timeout, .. } = envelope; let url = self.method_url(method)?; let mut builder = self .client .post(url) .headers(metadata.into_headers()) .header( CONTENT_TYPE, codec.message_codec().content_type(is_streaming), ); if let Some(timeout) = timeout { builder = builder.timeout(*timeout); } Ok(builder) } } impl ConnectChannel for ReqwestChannel { async fn unary_call( &self, method: &str, request: Request, codec: &mut impl ConnectCodec, ) -> Result, Error> { let req = { let (envelope, message) = RequestEnvelope::from_tonic_request(request, false)?; let body = codec .encoder() .encode_to_bytes(message) .context("error encoding message body") .map_err(Error::CodecError)?; self.request_builder(method, envelope, codec)?.body(body) }; let mut resp = req.send().await.map_err(reqwest_error)?; let (status, metadata, resp_extensions) = take_response_parts(&mut resp); let body = resp.bytes().await.map_err(reqwest_error)?; if !status.is_success() { return Err(ConnectStatus::from_connect_response(status, body) .unwrap() .into_error(metadata)); } let resp_message = codec.decoder().decode(body)?; Ok(Response::from_parts( metadata, resp_message, resp_extensions, )) } async fn streaming_call( &self, method: &str, request: Request> + Send>, codec: &mut impl ConnectCodec, ) -> Result>>, Error> { let req = { let (envelope, stream) = RequestEnvelope::from_tonic_request(request, true)?; let encoder = codec.encoder(); let body = reqwest::Body::wrap_stream(stream.map(|frame| frame.encode(&mut encoder))); self.request_builder(method, envelope, codec)?.body(body) }; let mut resp = req.send().await.map_err(reqwest_error)?; let (status, metadata, resp_extensions) = take_response_parts(&mut resp); if !status.is_success() { let resp_body = resp.bytes().await.map_err(reqwest_error)?; return Err(ConnectStatus::from_connect_response(status, resp_body) .unwrap() .into_error(metadata)); } // resp.bytes_stream().map_err(reqwest_error) Err(Error::TimedOut) } } fn take_response_parts(resp: &mut reqwest::Response) -> (StatusCode, MetadataMap, Extensions) { let metadata = MetadataMap::from_headers(std::mem::take(resp.headers_mut())); let extensions = std::mem::take(resp.extensions_mut()); (resp.status(), metadata, extensions) } fn reqwest_error(err: reqwest::Error) -> Error { if err.is_timeout() { Error::status( crate::status::ConnectCode::DeadlineExceeded, "timeout elapsed", ) } else { Error::ChannelError(err.into()) } }