use self::support::{into_text, serve}; use hyper::{Body, Client, Request, Response, StatusCode}; use routerify::prelude::RequestExt; use routerify::{Middleware, RequestInfo, RouteError, Router}; use std::io; use std::sync::{Arc, Mutex}; mod support; #[tokio::test] async fn can_perform_simple_get_request() { const RESPONSE_TEXT: &str = "Hello world"; let router: Router = Router::builder() .get("/", |_| async move { Ok(Response::new(RESPONSE_TEXT.into())) }) .err_handler(|_: RouteError| async move { todo!() }) .build() .unwrap(); let serve = serve(router).await; let resp = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}/", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); let resp = into_text(resp.into_body()).await; assert_eq!(resp, RESPONSE_TEXT.to_owned()); serve.shutdown(); } #[tokio::test] async fn can_perform_simple_get_request_boxed_error() { const RESPONSE_TEXT: &str = "Hello world"; type BoxedError = Box; let router: Router = Router::builder() .get("/", |_| async move { Ok(Response::new(RESPONSE_TEXT.into())) }) .err_handler(|_: RouteError| async move { todo!() }) .build() .unwrap(); let serve = serve(router).await; let resp = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}/", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); let resp = into_text(resp.into_body()).await; assert_eq!(resp, RESPONSE_TEXT.to_owned()); serve.shutdown(); } #[tokio::test] async fn can_respond_with_data_from_scope_state() { // Creating two modules containing separate state and routes which expose that state directly... mod service1 { use super::*; struct State { count: Arc>, } async fn list(req: Request) -> Result, io::Error> { let count = req.data::().unwrap().count.lock().unwrap(); Ok(Response::new(Body::from(format!("{}", count)))) } pub fn router() -> Router { let state = State { count: Arc::new(Mutex::new(1)), }; Router::builder().data(state).get("/", list).build().unwrap() } } mod service2 { use super::*; struct State { count: Arc>, } async fn list(req: Request) -> Result, io::Error> { let count = req.data::().unwrap().count.lock().unwrap(); Ok(Response::new(Body::from(format!("{}", count)))) } pub fn router() -> Router { let state = State { count: Arc::new(Mutex::new(2)), }; Router::builder().data(state).get("/", list).build().unwrap() } } let router = Router::builder() .scope( "/v1", Router::builder() .scope("/service1", service1::router()) .scope("/service2", service2::router()) .build() .unwrap(), ) .build() .unwrap(); let serve = serve(router).await; // Ensure response contains service1's unique data. let resp = Client::new() .request(serve.new_request("GET", "/v1/service1").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(200, resp.status().as_u16()); assert_eq!("1", into_text(resp.into_body()).await); // Ensure response contains service2's unique data. let resp = Client::new() .request(serve.new_request("GET", "/v1/service2").body(Body::empty()).unwrap()) .await .unwrap(); assert_eq!(200, resp.status().as_u16()); assert_eq!(into_text(resp.into_body()).await, "2"); serve.shutdown(); } #[tokio::test] async fn can_propagate_request_context() { use std::io; #[derive(Debug, Clone, PartialEq)] struct Id(u32); #[derive(Debug, Clone, PartialEq)] struct Id2(u32); let before = |req: Request| async move { req.set_context(Id(42)); let (parts, body) = req.into_parts(); parts.set_context(Id2(42)); Ok(Request::from_parts(parts, body)) }; let index = |req: Request| async move { // Check `id` from `before()`. let id = req.context::().unwrap(); assert_eq!(id, Id(42)); // Check that non-existent context value is None. let none = req.context::(); assert!(none.is_none()); // Add a String value to the context. req.set_context("index".to_string()); let (parts, _) = req.into_parts(); // Check `id2` from `before()`. let id2 = parts.context::().unwrap(); assert_eq!(id2, Id2(42)); // Update the Id2 value in the context. parts.set_context(Id2(1)); // Trigger this error in order to invoke // the error handler. Err(io::Error::new(io::ErrorKind::AddrInUse, "bogus error")) }; let error_handler = |_err, req_info: RequestInfo| async move { // Check `id` from `before()`. let id = req_info.context::().unwrap(); assert_eq!(id, Id(42)); // Check String from `index()`. let name = req_info.context::().unwrap(); assert_eq!(name, "index"); // Check updated `id2` from `index()`. let id2 = req_info.context::().unwrap(); assert_eq!(id2, Id2(1)); Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from("Something went wrong")) .unwrap() }; let after = |res, req_info: RequestInfo| async move { // Check `id` from `before()`. let id = req_info.context::().unwrap(); assert_eq!(id, Id(42)); // Check String from `index()`. let name = req_info.context::().unwrap(); assert_eq!(name, "index"); // Check updated `id2` from `index()`. let id2 = req_info.context::().unwrap(); assert_eq!(id2, Id2(1)); Ok(res) }; let router: Router = Router::builder() .middleware(Middleware::pre(before)) .middleware(Middleware::post_with_info(after)) .err_handler_with_info(error_handler) .get("/", index) .build() .unwrap(); let serve = serve(router).await; let _ = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}/", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); serve.shutdown(); } #[tokio::test] async fn can_extract_path_params() { const RESPONSE_TEXT: &str = "Hello world"; let router: Router = Router::builder() .get("/api/:first/plus/:second", |req| async move { let first = req.param("first").unwrap(); let second = req.param("second").unwrap(); assert_eq!(first, "40"); assert_eq!(second, "2"); let (parts, _) = req.into_parts(); let first = parts.param("first").unwrap(); let second = parts.param("second").unwrap(); assert_eq!(first, "40"); assert_eq!(second, "2"); Ok(Response::new(RESPONSE_TEXT.into())) }) .build() .unwrap(); let serve = serve(router).await; let resp = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}/api/40/plus/2", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); let resp = into_text(resp.into_body()).await; assert_eq!(resp, RESPONSE_TEXT.to_owned()); serve.shutdown(); } #[tokio::test] async fn can_extract_extension_path_params_1() { const RESPONSE_TEXT: &str = "Hello world"; let router: Router = Router::builder() .get("/api/:id.json", |req| async move { let id = req.param("id").unwrap(); assert_eq!(id, "40"); let (parts, _) = req.into_parts(); let id = parts.param("id").unwrap(); assert_eq!(id, "40"); Ok(Response::new(RESPONSE_TEXT.into())) }) .build() .unwrap(); let serve = serve(router).await; let resp = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}/api/40.json", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); let resp = into_text(resp.into_body()).await; assert_eq!(resp, RESPONSE_TEXT.to_owned()); serve.shutdown(); } #[tokio::test] async fn can_extract_extension_path_params_2() { const RESPONSE_TEXT: &str = "Hello world"; let router: Router = Router::builder() .get("/api/:fileName", |req| async move { let file_name = req.param("fileName").unwrap(); assert_eq!(file_name, "data.json"); let (parts, _) = req.into_parts(); let file_name = parts.param("fileName").unwrap(); assert_eq!(file_name, "data.json"); Ok(Response::new(RESPONSE_TEXT.into())) }) .build() .unwrap(); let serve = serve(router).await; let resp = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}/api/data.json", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); let resp = into_text(resp.into_body()).await; assert_eq!(resp, RESPONSE_TEXT.to_owned()); serve.shutdown(); } #[tokio::test] async fn do_not_execute_scoped_middleware_for_unscoped_path() { let api_router: Router = Router::builder() .middleware(Middleware::pre(|_| async { panic!("should not be executed") })) .middleware(Middleware::post(|_| async { panic!("should not be executed") })) .get("/api/todo", |_| async { Ok(Response::new("".into())) }) .build() .unwrap(); let router: Router = Router::builder() .get("/", |_| async { Ok(Response::new("".into())) }) .scope("/api", api_router) .get("/api/login", |_| async { Ok(Response::new("".into())) }) .build() .unwrap(); let serve = serve(router).await; let _ = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}/api/login", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); serve.shutdown(); } #[tokio::test] async fn execute_scoped_middleware_when_no_unscoped_match() { use std::sync::atomic::{AtomicBool, Ordering::SeqCst}; use std::sync::Arc; struct ExecPre(AtomicBool); struct ExecPost(AtomicBool); let executed_pre = Arc::new(ExecPre(AtomicBool::new(false))); let executed_post = Arc::new(ExecPost(AtomicBool::new(false))); // Record the execution of pre and post middleware. let api_router: Router = Router::builder() .middleware(Middleware::pre(|req| async { let pre = req.data::>().unwrap(); pre.0.store(true, SeqCst); Ok(req) })) .middleware(Middleware::pre(|req| async { let post = req.data::>().unwrap(); post.0.store(true, SeqCst); Ok(req) })) .get("/api/todo", |_| async { Ok(Response::new("".into())) }) .build() .unwrap(); let router: Router = Router::builder() .data(executed_pre.clone()) .data(executed_post.clone()) .get("/", |_| async { Ok(Response::new("".into())) }) .scope("/api", api_router) .get("/api/login", |_| async { Ok(Response::new("".into())) }) .build() .unwrap(); let serve = serve(router).await; let _ = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}/api/nomatch", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert!(executed_pre.0.load(SeqCst)); assert!(executed_post.0.load(SeqCst)); serve.shutdown(); } #[tokio::test] async fn can_handle_custom_errors() { #[derive(Debug)] enum ApiError { Generic(String), } impl std::error::Error for ApiError {} impl std::fmt::Display for ApiError { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { ApiError::Generic(s) => write!(f, "Generic: {}", s), } } } const RESPONSE_TEXT: &str = "Something went wrong!"; let router: Router = Router::builder() .get("/", |_| async move { Err(ApiError::Generic(RESPONSE_TEXT.into())) }) .err_handler(|err: RouteError| async move { let api_err = err.downcast::().unwrap(); match api_err.as_ref() { ApiError::Generic(s) => Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from(s.to_string())) .unwrap(), } }) .build() .unwrap(); let serve = serve(router).await; let resp = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}/", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); let resp = into_text(resp.into_body()).await; assert_eq!(resp, RESPONSE_TEXT.to_owned()); serve.shutdown(); } #[tokio::test] async fn can_handle_pre_middleware_errors() { struct State {} #[derive(Clone)] struct Ctx(i32); let state = State {}; // If pre middleware fails, then `data` and `req.context` should // propagate to the error handler and post middleware. The route // handler should not be executed. let router: Router = Router::builder() .data(state) .middleware(Middleware::pre(|req| async move { req.set_context(Ctx(42)); Err(routerify::Error::new("Error!")) })) .err_handler_with_info(|err, req_info| async move { let _ctx = req_info.context::().expect("No Ctx"); let _state = req_info.data::().expect("No state"); Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(Body::from(err.to_string())) .unwrap() }) .middleware(Middleware::post_with_info(|resp, req_info| async move { let _ctx = req_info.context::().expect("No Ctx"); let _state = req_info.data::().expect("No state"); Ok(resp) })) .get("/", |_| async { panic!("should not be executed") }) .build() .unwrap(); let serve = serve(router).await; let _ = Client::new() .request( Request::builder() .method("GET") .uri(format!("http://{}", serve.addr())) .body(Body::empty()) .unwrap(), ) .await .unwrap(); serve.shutdown(); }