#![feature(proc_macro_hygiene, decl_macro)] use rocket::http::uri::Uri; use rocket::http::{Header, RawStr}; use rocket::local::{Client, LocalResponse}; use rocket::{get, routes, Rocket}; use rocket_api_base::{AuthError, BaseContent, BaseResponse, NoContent, RocketBasicAuthFairing}; use rocket_contrib::json::Json; use serde_json::Value; use std::fs::File; use std::io::Read; const MISSING_HEADER_FILE: &str = "tests/jwt_auth_handler_missing_header.json"; const INVALID_HEADER_FILE: &str = "tests/jwt_auth_handler_invalid_header.json"; const INVALID_FORMAT_FILE: &str = "tests/basic_auth_handler_invalid_format.json"; const INVALID_CREDENTIALS_FILE: &str = "tests/basic_auth_handler_invalid_credentials.json"; const OK_FILE: &str = "tests/jwt_auth_handler_ok.json"; fn file_to_value(file_name: &'static str) -> Value { let mut file = File::open(file_name).unwrap(); let mut expected_string = String::new(); file.read_to_string(&mut expected_string).unwrap(); serde_json::from_str(&expected_string).unwrap() } fn response_to_value(mut res: LocalResponse) -> Value { serde_json::from_str(&res.body_string().unwrap()).unwrap() } #[get("/test")] fn test_endpoint() -> Json> { Json(BaseResponse:: { result: "ok".to_string(), error_message: "".to_string(), error_code: 0, content: BaseContent::None, }) } #[get("/test2")] fn test2_endpoint() -> Json> { Json(BaseResponse:: { result: "ok".to_string(), error_message: "".to_string(), error_code: 0, content: BaseContent::None, }) } #[get("/other")] fn other_endpoint() -> Json> { Json(BaseResponse:: { result: "ok".to_string(), error_message: "".to_string(), error_code: 0, content: BaseContent::None, }) } #[get("/auth_error?")] fn auth_error_endpoint(message: &RawStr) -> Json> { Json(BaseResponse:: { result: "error".to_string(), error_message: Uri::percent_decode_lossy(message.as_ref()).to_string(), error_code: 403, content: BaseContent::None, }) } fn auth_error_handler(error: AuthError) -> String { "/auth_error?message=".to_string() + &error.get_message_encoded() } const USER: &str = "user"; const PASSWORD: &str = "password"; fn basic_auth_callback(user: String, password: String) -> Result<(), String> { if user.eq(USER) && password.eq(PASSWORD) { Ok(()) } else { Err("/auth_error?message=".to_string() + &Uri::percent_encode("Invalid credentials")) } } fn start_test_server() -> Rocket { rocket::ignite() .attach(RocketBasicAuthFairing::new( basic_auth_callback, auth_error_handler, )) .mount( "/", routes![ test_endpoint, test2_endpoint, other_endpoint, auth_error_endpoint ], ) } fn start_test_server_with_excludes(excludes: Vec<&str>) -> Rocket { rocket::ignite() .attach(RocketBasicAuthFairing::new_with_excludes( basic_auth_callback, auth_error_handler, excludes, )) .mount( "/", routes![ test_endpoint, test2_endpoint, other_endpoint, auth_error_endpoint ], ) } fn start_test_server_with_includes(includes: Vec<&str>) -> Rocket { rocket::ignite() .attach(RocketBasicAuthFairing::new_with_includes( basic_auth_callback, auth_error_handler, includes, )) .mount( "/", routes![ test_endpoint, test2_endpoint, other_endpoint, auth_error_endpoint ], ) } #[test] fn test_no_header() { let rocket = start_test_server(); let client = Client::new(rocket).expect("valid rocket instance"); let req = client.get("/test"); let response = req.dispatch(); assert_eq!( file_to_value(MISSING_HEADER_FILE), response_to_value(response) ); } #[test] fn test_invalid_header() { let rocket = start_test_server(); let client = Client::new(rocket).expect("valid rocket instance"); let mut req = client.get("/test"); req.add_header(Header::new("Authorization", "Bearer bad_scheme")); let response = req.dispatch(); assert_eq!( file_to_value(INVALID_HEADER_FILE), response_to_value(response) ); } #[test] fn test_invalid_format() { let rocket = start_test_server(); let credentials = base64::encode(format!("{}-{}", USER, PASSWORD)); let client = Client::new(rocket).expect("valid rocket instance"); let mut req = client.get("/test"); req.add_header(Header::new( "Authorization", format!("Basic {}", credentials), )); let response = req.dispatch(); assert_eq!( file_to_value(INVALID_FORMAT_FILE), response_to_value(response) ); } #[test] fn test_invalid_credentials() { let rocket = start_test_server(); let credentials = base64::encode("user:bad_password"); let client = Client::new(rocket).expect("valid rocket instance"); let mut req = client.get("/test"); req.add_header(Header::new( "Authorization", format!("Basic {}", credentials), )); let response = req.dispatch(); assert_eq!( file_to_value(INVALID_CREDENTIALS_FILE), response_to_value(response) ); } #[test] fn test_ok() { let rocket = start_test_server(); let credentials = base64::encode(format!("{}:{}", USER, PASSWORD)); let client = Client::new(rocket).expect("valid rocket instance"); let mut req = client.get("/test"); req.add_header(Header::new( "Authorization", format!("Basic {}", credentials), )); let response = req.dispatch(); assert_eq!(file_to_value(OK_FILE), response_to_value(response)); } #[test] fn test_excluded() { let rocket = start_test_server_with_excludes(vec!["/test"]); let client = Client::new(rocket).expect("valid rocket instance"); let req = client.get("/test"); let response = req.dispatch(); assert_eq!(file_to_value(OK_FILE), response_to_value(response)); let req = client.get("/test2"); let response = req.dispatch(); assert_eq!( file_to_value(MISSING_HEADER_FILE), response_to_value(response) ); } #[test] fn test_included() { let rocket = start_test_server_with_includes(vec!["/test"]); let client = Client::new(rocket).expect("valid rocket instance"); let req = client.get("/test2"); let response = req.dispatch(); assert_eq!(file_to_value(OK_FILE), response_to_value(response)); let req = client.get("/test"); let response = req.dispatch(); assert_eq!( file_to_value(MISSING_HEADER_FILE), response_to_value(response) ); } #[test] fn test_excluded_glob() { let rocket = start_test_server_with_excludes(vec!["/test*"]); let client = Client::new(rocket).expect("valid rocket instance"); let req = client.get("/test"); let response = req.dispatch(); assert_eq!(file_to_value(OK_FILE), response_to_value(response)); let req = client.get("/test2"); let response = req.dispatch(); assert_eq!(file_to_value(OK_FILE), response_to_value(response)); let req = client.get("/other"); let response = req.dispatch(); assert_eq!( file_to_value(MISSING_HEADER_FILE), response_to_value(response) ); } #[test] fn test_included_glob() { let rocket = start_test_server_with_includes(vec!["/test*"]); let client = Client::new(rocket).expect("valid rocket instance"); let req = client.get("/test"); let response = req.dispatch(); assert_eq!( file_to_value(MISSING_HEADER_FILE), response_to_value(response) ); let req = client.get("/test2"); let response = req.dispatch(); assert_eq!( file_to_value(MISSING_HEADER_FILE), response_to_value(response) ); let req = client.get("/other"); let response = req.dispatch(); assert_eq!(file_to_value(OK_FILE), response_to_value(response)); }