use hyperdrive::{ body::Json, http::{Method, Request, StatusCode}, hyper::Body, BoxedError, Error, FromRequest, Guard, NoContext, RequestContext, }; use serde::Deserialize; use std::str::FromStr; use std::sync::Arc; /// Simulates receiving `request`, and decodes a `FromRequest` implementor `T`. /// /// `T` has to take a `NoContext`. fn invoke(request: Request) -> Result where T: FromRequest, { T::from_request_sync(request, NoContext) } fn invoke_with(request: Request, context: T::Context) -> Result where T: FromRequest, { T::from_request_sync(request, context) } #[derive(Debug, PartialEq, Eq)] struct MyGuard; impl Guard for MyGuard { type Context = NoContext; type Result = Result; fn from_request(_request: &Arc>, _context: &Self::Context) -> Self::Result { Ok(MyGuard) } } /// A few demo routes for user management (login, user info, user edit). #[test] fn user_app() { #[derive(FromRequest, Debug)] #[allow(dead_code)] enum Routes { #[post("/login")] Login { #[body] data: Json, #[query_params] params: (), gourd: MyGuard, }, #[get("/users/{id}")] User { id: u32 }, #[patch("/users/{id}")] PatchUser { id: u32, #[body] data: Json, }, } #[derive(Deserialize, Debug)] #[allow(dead_code)] struct LoginData { email: String, password: String, } #[derive(Deserialize, Debug)] #[serde(untagged)] #[allow(dead_code)] enum PatchUser { General { display_name: String, }, ChangePassword { old_password: String, new_password: String, }, } let login = invoke::( Request::post("/login") .body( r#" { "email": "test@example.com", "password": "hunter2" } "# .into(), ) .unwrap(), ) .expect("/login not routed properly"); match login { Routes::Login { params: (), gourd: MyGuard, data: Json(body), } => { assert_eq!(body.email, "test@example.com"); assert_eq!(body.password, "hunter2"); } _ => panic!("unexpected result: {:?}", login), } let get_login = invoke::(Request::get("/login").body(Body::empty()).unwrap()); let error: Box = get_login.unwrap_err().downcast().unwrap(); 
assert_eq!(error.http_status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!( error.allowed_methods().expect("allowed_methods()"), &[&Method::POST] ); let post_user = invoke::(Request::post("/users/0").body(Body::empty()).unwrap()); let error: Box = post_user.unwrap_err().downcast().unwrap(); assert_eq!(error.http_status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!( error.allowed_methods().expect("allowed_methods()"), &[&Method::GET, &Method::PATCH, &Method::HEAD] ); let user = invoke::(Request::get("/users/wrong").body(Body::empty()).unwrap()); let error: Box = user.unwrap_err().downcast().unwrap(); assert_eq!(error.http_status(), StatusCode::NOT_FOUND); } /// Tests that `#[context]` can be used to change the context accepted by the /// `FromRequest` impl. It should still be possible to use guards that take a /// `NoContext` instead. #[test] fn context() { #[derive(FromRequest, Debug)] #[context(SpecialContext)] enum Routes { #[get("/")] Variant { /// Takes a `SpecialContext`. special: SpecialGuard, /// Takes a `NoContext`. normal: MyGuard, }, } #[derive(RequestContext, Debug)] struct SpecialContext; #[derive(Debug)] struct SpecialGuard; impl Guard for SpecialGuard { type Context = SpecialContext; type Result = Result; fn from_request( _request: &Arc>, _context: &Self::Context, ) -> Self::Result { Ok(SpecialGuard) } } invoke_with::( Request::get("/").body(Body::empty()).unwrap(), SpecialContext, ) .unwrap(); invoke_with::( Request::get("/bla").body(Body::empty()).unwrap(), SpecialContext, ) .unwrap_err(); } #[test] fn struct_context() { #[derive(FromRequest, Debug)] #[context(SpecialContext)] #[get("/")] struct Route { /// Takes a `SpecialContext`. special: SpecialGuard, /// Takes a `NoContext`. 
normal: MyGuard, } #[derive(RequestContext, Debug)] struct SpecialContext; #[derive(Debug)] struct SpecialGuard; impl Guard for SpecialGuard { type Context = SpecialContext; type Result = Result; fn from_request( _request: &Arc>, _context: &Self::Context, ) -> Self::Result { Ok(SpecialGuard) } } invoke_with::( Request::get("/").body(Body::empty()).unwrap(), SpecialContext, ) .unwrap(); invoke_with::( Request::get("/bla").body(Body::empty()).unwrap(), SpecialContext, ) .unwrap_err(); } #[test] fn any_placeholder() { #[derive(FromRequest, Debug, PartialEq, Eq)] enum Routes { #[get("/{ph}/{rest...}")] Variant { ph: u32, rest: String }, } let route = invoke::( Request::get("/1234/bla/bli?param=123") .body(Body::empty()) .unwrap(), ) .unwrap(); assert_eq!( route, Routes::Variant { ph: 1234, rest: "bla/bli".to_string() } ); let route = invoke::(Request::get("/1234/").body(Body::empty()).unwrap()).unwrap(); assert_eq!( route, Routes::Variant { ph: 1234, rest: "".to_string() } ); invoke::(Request::get("/1234").body(Body::empty()).unwrap()).unwrap_err(); } #[test] fn asterisk() { #[derive(FromRequest, Debug)] enum Routes { #[options("*")] ServerOptions, } invoke::(Request::options("*").body(Body::empty()).unwrap()).unwrap(); invoke::(Request::options("/").body(Body::empty()).unwrap()).unwrap_err(); invoke::(Request::head("/").body(Body::empty()).unwrap()).unwrap_err(); #[derive(FromRequest, Debug)] #[options("*")] struct Options; invoke::(Request::options("*").body(Body::empty()).unwrap()).unwrap(); invoke::(Request::options("/").body(Body::empty()).unwrap()).unwrap_err(); invoke::(Request::head("/").body(Body::empty()).unwrap()).unwrap_err(); } #[test] fn implicit_head_route() { #[derive(FromRequest, Debug, PartialEq, Eq)] enum Routes { #[get("/")] Index, #[get("/2/other")] Other, // We should still be able to define our own HEAD route instead #[head("/2/other")] OtherHead, } let head = invoke::(Request::head("/").body(Body::empty()).unwrap()).unwrap(); assert_eq!(head, 
Routes::Index); let anyhead = invoke::(Request::head("/2/other").body(Body::empty()).unwrap()).unwrap(); assert_eq!(anyhead, Routes::OtherHead); let anyhead = invoke::(Request::get("/2/other").body(Body::empty()).unwrap()).unwrap(); assert_eq!(anyhead, Routes::Other); } #[test] fn query_params() { #[derive(FromRequest, PartialEq, Eq, Debug)] enum Routes { #[get("/users")] UserList { #[query_params] pagination: Pagination, }, } #[derive(Deserialize, PartialEq, Eq, Debug)] struct Pagination { #[serde(default)] start: u32, #[serde(default = "default_count")] count: u32, } fn default_count() -> u32 { 10 } let route = invoke::(Request::get("/users").body(Body::empty()).unwrap()).unwrap(); assert_eq!( route, Routes::UserList { pagination: Pagination { start: 0, count: 10, } } ); let route = invoke::(Request::get("/users?count=30").body(Body::empty()).unwrap()).unwrap(); assert_eq!( route, Routes::UserList { pagination: Pagination { start: 0, count: 30, } } ); let route = invoke::( Request::get("/users?start=543") .body(Body::empty()) .unwrap(), ) .unwrap(); assert_eq!( route, Routes::UserList { pagination: Pagination { start: 543, count: 10, } } ); let route = invoke::( Request::get("/users?start=123&count=30") .body(Body::empty()) .unwrap(), ) .unwrap(); assert_eq!( route, Routes::UserList { pagination: Pagination { start: 123, count: 30, } } ); } /// Tests that the derive works on generic enums and structs. 
#[test] fn generic() { #[derive(FromRequest, Debug, PartialEq, Eq)] enum Routes { #[get("/{path}")] OmniRoute { path: U, #[query_params] qp: Q, #[body] body: B, guard: G, }, } #[derive(RequestContext, Debug)] struct SpecialContext; #[derive(FromRequest, Debug, PartialEq, Eq)] #[get("/{path}")] #[context(SpecialContext)] struct Struct { path: U, #[query_params] qp: Q, #[body] body: B, guard: G, } #[derive(PartialEq, Eq, Debug)] struct SpecialGuard; impl Guard for SpecialGuard { type Context = SpecialContext; type Result = Result; fn from_request(_request: &Arc>, _context: &Self::Context) -> Self::Result { Ok(SpecialGuard) } } #[derive(Deserialize, PartialEq, Eq, Debug)] struct Pagination { start: u32, count: u32, } #[derive(Deserialize, PartialEq, Eq, Debug)] struct LoginData { email: String, password: String, } let url = "/users?start=123&count=30"; let body = r#" { "email": "test@example.com", "password": "hunter2" } "#; let route: Routes, MyGuard> = invoke(Request::get(url).body(body.into()).unwrap()).unwrap(); assert_eq!( route, Routes::OmniRoute { path: "users".to_string(), qp: Pagination { start: 123, count: 30 }, body: Json(LoginData { email: "test@example.com".to_string(), password: "hunter2".to_string() }), guard: MyGuard, } ); // Make sure the `SpecialContext` is turned into whatever context is needed by the fields, and // that we have the right where-clauses for it let route: Struct, MyGuard> = invoke_with(Request::get(url).body(body.into()).unwrap(), SpecialContext).unwrap(); assert_eq!( route, Struct { path: "users".to_string(), qp: Pagination { start: 123, count: 30 }, body: Json(LoginData { email: "test@example.com".to_string(), password: "hunter2".to_string() }), guard: MyGuard, } ); // A guard that needs a `SpecialContext` must also work: let _route: Struct, SpecialGuard> = invoke_with(Request::get(url).body(body.into()).unwrap(), SpecialContext).unwrap(); } #[test] fn forward() { #[derive(FromRequest, PartialEq, Eq, Debug)] enum Inner { #[get("/")] 
Index, #[get("/flabberghast")] Flabberghast, #[post("/")] Post, } #[derive(FromRequest, PartialEq, Eq, Debug)] #[get("/")] // FIXME: forbid this? struct Req { #[forward] _inner: Inner, } #[derive(FromRequest, PartialEq, Eq, Debug)] enum Enum { #[get("/")] First { #[forward] _inner: Inner, }, Second { #[forward] _inner: Inner, }, } invoke::(Request::get("/").body(Body::empty()).unwrap()).unwrap(); let route = invoke::(Request::get("/").body(Body::empty()).unwrap()).unwrap(); assert_eq!( route, Enum::First { _inner: Inner::Index }, "GET /" ); let route = invoke::(Request::head("/").body(Body::empty()).unwrap()).unwrap(); assert_eq!( route, Enum::First { _inner: Inner::Index }, "HEAD /" ); let route = invoke::(Request::get("/flabberghast").body(Body::empty()).unwrap()).unwrap(); assert_eq!( route, Enum::Second { _inner: Inner::Flabberghast }, "GET /flabberghast" ); let route = invoke::(Request::post("/").body(Body::empty()).unwrap()).unwrap(); assert_eq!( route, Enum::Second { _inner: Inner::Post }, "POST /" ); } /// Tests that invalid methods return the right set of allowed methods, even in the presence of /// `#[forward]`. 
#[test]
fn forward_allowed_methods() {
    #[derive(FromRequest, PartialEq, Eq, Debug)]
    enum Inner {
        #[get("/")]
        #[post("/")]
        Index,

        #[get("/customhead")]
        GetCustomHead,

        #[head("/customhead")]
        HeadCustomHead,

        #[post("/post")]
        Post,

        #[post("/shared")]
        Shared,

        #[post("/shared/{s}")]
        Shared2 { s: u32 },
    }

    #[derive(FromRequest, PartialEq, Eq, Debug)]
    enum Wrapper {
        #[get("/shared")]
        Shared,

        #[get("/shared/{s}")]
        Shared2 { s: u8 },

        Fallback {
            #[forward]
            inner: Inner,
        },
    }

    // NOTE(review): `AlwaysErr` is not referenced by any route in this test. Either use it as a
    // path-segment type or delete it together with the file-level `FromStr` import, which has no
    // other user.
    #[derive(PartialEq, Eq, Debug)]
    struct AlwaysErr;

    impl FromStr for AlwaysErr {
        type Err = BoxedError;

        fn from_str(_: &str) -> Result<Self, Self::Err> {
            Err(String::new().into())
        }
    }

    // `/post` only has a POST route (in `Inner`).
    let err: Box<Error> = invoke::<Wrapper>(Request::get("/post").body(Body::empty()).unwrap())
        .unwrap_err()
        .downcast()
        .unwrap();
    assert_eq!(err.http_status(), StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(
        err.allowed_methods().expect("allowed_methods()"),
        &[&Method::POST]
    );

    // `/customhead` has a GET route and an explicit HEAD route.
    let err: Box<Error> =
        invoke::<Wrapper>(Request::post("/customhead").body(Body::empty()).unwrap())
            .unwrap_err()
            .downcast()
            .unwrap();
    assert_eq!(err.http_status(), StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(
        err.allowed_methods().expect("allowed_methods()"),
        &[&Method::GET, &Method::HEAD]
    );

    // `/shared` is defined in both. Outer takes precedence over inner, if it matches.
    let route =
        invoke::<Wrapper>(Request::post("/shared").body(Body::empty()).unwrap()).unwrap();
    assert_eq!(
        route,
        Wrapper::Fallback {
            inner: Inner::Shared
        }
    );
    let route = invoke::<Wrapper>(Request::get("/shared").body(Body::empty()).unwrap()).unwrap();
    assert_eq!(route, Wrapper::Shared);

    // Methods not accepted by either result in `allowed_methods()` being merged together.
    let err: Box<Error> = invoke::<Wrapper>(Request::put("/shared").body(Body::empty()).unwrap())
        .unwrap_err()
        .downcast()
        .unwrap();
    assert_eq!(err.http_status(), StatusCode::METHOD_NOT_ALLOWED);
    assert_eq!(
        err.allowed_methods().expect("allowed_methods()"),
        &[&Method::GET, &Method::HEAD, &Method::POST]
    );

    // Also with FromStr segments
    let route =
        invoke::<Wrapper>(Request::post("/shared/123").body(Body::empty()).unwrap()).unwrap();
    assert_eq!(
        route,
        Wrapper::Fallback {
            inner: Inner::Shared2 { s: 123 }
        }
    );
    let route =
        invoke::<Wrapper>(Request::get("/shared/123").body(Body::empty()).unwrap()).unwrap();
    assert_eq!(route, Wrapper::Shared2 { s: 123 });
}

/// Tests a generic enum whose fallback variant combines a generic guard with a generic
/// `#[forward]`ed field.
#[test]
fn generic_forward() {
    #[derive(FromRequest, Debug, PartialEq, Eq)]
    enum Generic<G, I> {
        #[get("/unused")]
        Unused,

        Fallback {
            guard: G,
            #[forward]
            inner: I,
        },
    }

    #[derive(FromRequest, Debug, PartialEq, Eq)]
    enum Inner {
        #[get("/")]
        Index,
    }

    let route: Generic<MyGuard, Inner> =
        invoke(Request::get("/").body(Body::empty()).unwrap()).unwrap();
    assert_eq!(
        route,
        Generic::Fallback {
            guard: MyGuard,
            inner: Inner::Index
        }
    );
}

/// Tests a generic struct with a generic guard and a generic `#[forward]`ed field, including the
/// 404 and 405 errors produced by the inner type.
#[test]
fn generic_guard_struct() {
    #[derive(FromRequest, Debug, PartialEq, Eq)]
    struct Generic<G, I> {
        guard: G,
        #[forward]
        inner: I,
    }

    #[derive(FromRequest, Debug, PartialEq, Eq)]
    enum Inner {
        #[get("/")]
        Index,
    }

    let route: Generic<MyGuard, Inner> =
        invoke(Request::get("/").body(Body::empty()).unwrap()).unwrap();
    assert_eq!(
        route,
        Generic {
            guard: MyGuard,
            inner: Inner::Index
        }
    );

    let err: Box<Error> =
        invoke::<Generic<MyGuard, Inner>>(Request::get("/notfound").body(Body::empty()).unwrap())
            .unwrap_err()
            .downcast()
            .unwrap();
    assert_eq!(err.http_status(), StatusCode::NOT_FOUND);

    let err: Box<Error> =
        invoke::<Generic<MyGuard, Inner>>(Request::post("/").body(Body::empty()).unwrap())
            .unwrap_err()
            .downcast()
            .unwrap();
    assert_eq!(err.http_status(), StatusCode::METHOD_NOT_ALLOWED);
}

/// Like `generic_guard_struct`, but only the guard is generic; the `#[forward]`ed field has a
/// concrete type.
#[test]
fn generic_guard_struct_2() {
    #[derive(FromRequest, Debug, PartialEq, Eq)]
    struct Generic<G> {
        guard: G,
        #[forward]
        inner: Inner,
    }

    #[derive(FromRequest, Debug, PartialEq, Eq)]
    enum Inner {
        #[get("/")]
        Index,
    }

    let err: Box<Error> =
        invoke::<Generic<MyGuard>>(Request::get("/notfound").body(Body::empty()).unwrap())
            .unwrap_err()
            .downcast()
            .unwrap();
    assert_eq!(err.http_status(), StatusCode::NOT_FOUND);

    let err: Box<Error> =
        invoke::<Generic<MyGuard>>(Request::post("/").body(Body::empty()).unwrap())
            .unwrap_err()
            .downcast()
            .unwrap();
    assert_eq!(err.http_status(), StatusCode::METHOD_NOT_ALLOWED);
}

/// Keeps another `Arc` around pointing to the request, while the `#[forward]`ed `from_request` is
/// invoked.
///
/// This will invoke the slow-path that manually clones the request, since `Arc::try_unwrap` will
/// now fail.
#[test]
fn klepto_arc() {
    /// A guard that holds on to its own `Arc` handle to the request.
    ///
    /// It is named distinctly so that it does not shadow the file-level `MyGuard`.
    struct KleptoGuard {
        request: Arc<Request<()>>,
    }

    impl Guard for KleptoGuard {
        type Context = NoContext;
        type Result = Result<Self, BoxedError>;

        fn from_request(request: &Arc<Request<()>>, _: &Self::Context) -> Self::Result {
            Ok(KleptoGuard {
                request: Arc::clone(request),
            })
        }
    }

    #[derive(FromRequest)]
    #[get("/")]
    struct Route {
        guard: KleptoGuard,
        #[forward]
        _inner: Inner,
    }

    #[derive(FromRequest)]
    #[get("/")]
    struct Inner {}

    let route: Route = invoke(Request::get("/").body(Body::empty()).unwrap()).unwrap();
    assert_eq!(route.guard.request.uri(), "/");
    assert_eq!(route.guard.request.method(), "GET");
}