#![allow(dead_code)] // REF: https://github.com/tokio-rs/axum/blob/main/axum/src/test_helpers/test_client.rs use axum::{extract::Request, response::Response, serve}; use bytes::Bytes; use http::{ header::{HeaderName, HeaderValue}, StatusCode, }; use std::{convert::Infallible, net::SocketAddr, str::FromStr}; use tokio::net::TcpListener; use tower::make::Shared; use tower_service::Service; pub(crate) struct TestClient { client: reqwest::Client, addr: SocketAddr, } impl TestClient { pub(crate) fn new(svc: S) -> Self where S: Service + Clone + Send + 'static, S::Future: Send, { let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); std_listener.set_nonblocking(true).unwrap(); let listener = TcpListener::from_std(std_listener).unwrap(); let addr = listener.local_addr().unwrap(); println!("Listening on {addr}"); tokio::spawn(async move { serve(listener, Shared::new(svc)) .await .expect("server error") }); let client = reqwest::Client::builder() .redirect(reqwest::redirect::Policy::none()) .build() .unwrap(); TestClient { client, addr } } pub(crate) fn get(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.get(format!("http://{}{}", self.addr, url)), } } pub(crate) fn head(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.head(format!("http://{}{}", self.addr, url)), } } pub(crate) fn post(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.post(format!("http://{}{}", self.addr, url)), } } pub(crate) fn put(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.put(format!("http://{}{}", self.addr, url)), } } pub(crate) fn patch(&self, url: &str) -> RequestBuilder { RequestBuilder { builder: self.client.patch(format!("http://{}{}", self.addr, url)), } } } pub(crate) struct RequestBuilder { builder: reqwest::RequestBuilder, } impl RequestBuilder { pub(crate) async fn send(self) -> TestResponse { TestResponse { response: self.builder.send().await.unwrap(), } } pub(crate) fn body(mut self, body: impl Into) -> Self { self.builder = self.builder.body(body); self } pub(crate) fn json(mut self, json: &T) -> Self where T: serde::Serialize, { self.builder = self.builder.json(json); self } pub(crate) fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, >::Error: Into, HeaderValue: TryFrom, >::Error: Into, { // reqwest still uses http 0.2 let key: HeaderName = key.try_into().map_err(Into::into).unwrap(); let key = reqwest::header::HeaderName::from_bytes(key.as_ref()).unwrap(); let value: HeaderValue = value.try_into().map_err(Into::into).unwrap(); let value = reqwest::header::HeaderValue::from_bytes(value.as_bytes()).unwrap(); self.builder = self.builder.header(key, value); self } #[allow(dead_code)] pub(crate) fn multipart(mut self, form: reqwest::multipart::Form) -> Self { self.builder = self.builder.multipart(form); self } } #[derive(Debug)] pub(crate) struct TestResponse { response: reqwest::Response, } impl TestResponse { pub(crate) async fn bytes(self) -> Bytes { self.response.bytes().await.unwrap() } pub(crate) async fn text(self) -> String { self.response.text().await.unwrap() } pub(crate) async fn json(self) -> T where T: serde::de::DeserializeOwned, { self.response.json().await.unwrap() } pub(crate) fn status(&self) -> StatusCode { StatusCode::from_u16(self.response.status().as_u16()).unwrap() } pub(crate) fn headers(&self) -> http::HeaderMap { // reqwest still uses http 0.2 so have to convert into http 1.0 let mut headers = http::HeaderMap::new(); for (key, value) in self.response.headers() { let key = http::HeaderName::from_str(key.as_str()).unwrap(); let value = http::HeaderValue::from_bytes(value.as_bytes()).unwrap(); headers.insert(key, value); } headers } pub(crate) async fn chunk(&mut self) -> Option { self.response.chunk().await.unwrap() } pub(crate) async fn chunk_text(&mut self) -> Option { let chunk = self.chunk().await?; Some(String::from_utf8(chunk.to_vec()).unwrap()) } }