/* * Licensed to Elasticsearch B.V. under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch B.V. licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /* * SPDX-License-Identifier: Apache-2.0 * * The OpenSearch Contributors require contributions made to * this file be licensed under the Apache-2.0 license or a * compatible open source license. * * Modifications Copyright OpenSearch Contributors. See * GitHub history for details. */ // From reqwest crate // Licensed under Apache License, Version 2.0 // https://github.com/seanmonstar/reqwest/blob/master/LICENSE-APACHE use std::{convert::identity, net::SocketAddr, sync::mpsc as std_mpsc, thread, time::Duration}; use bytes::Bytes; use http_body_util::Empty; use hyper::{ body::Incoming, server::conn::http1, service::service_fn, HeaderMap, Method, Request, Response, Uri, }; use hyper_util::rt::TokioIo; use opensearch::{http::transport::TransportBuilder, OpenSearch}; use tokio::{ net::{TcpListener, TcpStream}, pin, runtime, select, sync::{mpsc, watch}, task, time::sleep, }; use super::client::TestClientBuilder; #[derive(Clone)] struct RequestState { requests_tx: mpsc::UnboundedSender, response_delay: Option, shutdown_rx: watch::Receiver, } #[derive(Default)] pub struct MockServerBuilder { response_delay: Option, } impl MockServerBuilder { pub fn response_delay(mut self, delay: Duration) -> Self { self.response_delay = Some(delay); self } async fn handle_request( req: Request, state: RequestState, ) -> anyhow::Result>> { state.requests_tx.send(req.into())?; if let Some(response_delay) = state.response_delay { sleep(response_delay).await; } Ok(Default::default()) } async fn serve_connection(io: TokioIo, state: RequestState) { let mut shutdown_rx = state.shutdown_rx.clone(); let conn = http1::Builder::new().serve_connection( io, service_fn(move |req| Self::handle_request(req, state.clone())), ); pin!(conn); select! { _ = conn.as_mut() => {}, _ = shutdown_rx.changed() => conn.as_mut().graceful_shutdown() } } async fn serve(listener: TcpListener, state: RequestState) -> anyhow::Result<()> { let mut shutdown_rx = state.shutdown_rx.clone(); loop { let (stream, _) = tokio::select! { res = listener.accept() => res?, _ = shutdown_rx.changed() => break }; let io = TokioIo::new(stream); task::spawn(Self::serve_connection(io, state.clone())); } Ok(()) } fn start_inner(self, thread_name: String) -> anyhow::Result { let rt = runtime::Builder::new_current_thread() .enable_all() .build()?; let _ = rt.enter(); let (shutdown_tx, shutdown_rx) = watch::channel(false); let (requests_tx, requests_rx) = mpsc::unbounded_channel(); let listener = rt.block_on(TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))))?; let addr = listener.local_addr()?; let srv = Self::serve( listener, RequestState { requests_tx, response_delay: self.response_delay, shutdown_rx, }, ); let (panic_tx, panic_rx) = std_mpsc::channel(); thread::Builder::new() .name(format!("test({})-support-server", thread_name)) .spawn(move || { rt.block_on(srv).unwrap(); let _ = panic_tx.send(()); })?; Ok(MockServer { uri: format!("http://{}", addr), requests_rx, panic_rx, shutdown_tx: Some(shutdown_tx), }) } pub fn start(self) -> anyhow::Result { let thread_name = thread::current().name().unwrap_or("").to_owned(); match thread::spawn(move || self.start_inner(thread_name)).join() { Ok(r) => r, Err(e) => Err(anyhow::anyhow!("MockServer startup panicked: {:?}", e)), } } } pub struct MockServer { uri: String, requests_rx: mpsc::UnboundedReceiver, panic_rx: std_mpsc::Receiver<()>, shutdown_tx: Option>, } impl MockServer { pub fn builder() -> MockServerBuilder { MockServerBuilder::default() } pub fn start() -> anyhow::Result { Self::builder().start() } pub fn client(&self) -> OpenSearch { self.client_with(identity) } pub fn client_with( &self, configurator: impl FnOnce(TransportBuilder) -> TransportBuilder, ) -> OpenSearch { self.client_builder().with(configurator).build() } pub fn client_builder(&self) -> TestClientBuilder { super::client::builder_with_url(&self.uri) } pub async fn received_request(&mut self) -> anyhow::Result { self.requests_rx .recv() .await .ok_or_else(|| anyhow::anyhow!("no request received")) } } impl Drop for MockServer { fn drop(&mut self) { if let Some(tx) = self.shutdown_tx.take() { tx.send(true).unwrap(); } if !::std::thread::panicking() { self.panic_rx .recv_timeout(Duration::from_secs(3)) .expect("test server should not panic"); } } } pub struct ReceivedRequest { method: Method, uri: Uri, headers: HeaderMap, } impl ReceivedRequest { pub fn method(&self) -> &Method { &self.method } pub fn path(&self) -> &str { self.uri.path() } pub fn query(&self) -> Option<&str> { self.uri.query() } pub fn header(&self, name: &str) -> Option<&str> { self.headers.get(name).and_then(|v| v.to_str().ok()) } } impl From> for ReceivedRequest { fn from(req: Request) -> Self { ReceivedRequest { method: req.method().clone(), uri: req.uri().clone(), headers: req.headers().clone(), } } }