use kvarn::prelude::*; use kvarn_testing::prelude::*; #[tokio::test] async fn default_deny() { let server = ServerBuilder::default().run().await; let response = server .get("/") .header("origin", "https://kvarn.org") .send() .await .unwrap(); test_cors_response(response, false, line!()).await; } #[tokio::test] async fn default_options() { let server = ServerBuilder::default().run().await; test_cors_options( &server, "/", "https://doc.kvarn.org", &[Method::POST], &[], false, false, line!(), ) .await; } fn get_extensions() -> Extensions { let mut extensions = Extensions::new(); let cors = Cors::empty() .add("/logo.svg", CorsAllowList::default().allow_all_origins()) .add( "/api/*", CorsAllowList::default() .add_origin("https://icelk.dev") .add_origin("http://kvarn.org") .add_origin("https://kvarn.org") .add_method(Method::PUT) .add_method(Method::DELETE), ) .add( "/images/*", CorsAllowList::new(Duration::from_secs(60 * 60 * 24 * 365)) .add_origin("https://example.org") .add_origin("https://foo.bar"), ) .arc(); extensions.with_cors(cors); extensions } async fn test_cors_response(response: reqwest::Response, valid_expected: bool, line: u32) { if valid_expected { assert_eq!( response.status(), reqwest::StatusCode::NO_CONTENT, "On line {} Response: {:#?}", line, response ); } else { assert_eq!( response.status(), reqwest::StatusCode::FORBIDDEN, "On line {} Response: {:#?}", line, response ); assert_eq!( response.text().await.unwrap(), "CORS request denied", "On line {}", line ); } } #[allow(clippy::too_many_arguments)] async fn test_cors_options( server: &Server, path: impl AsRef, origin: impl AsRef, methods: &[Method], headers: &[&str], valid_expected: bool, test_methods_and_headers: bool, line: u32, ) { let mut request = server .options(path.as_ref()) .header("origin", origin.as_ref()); if !methods.is_empty() { let mut methods = methods .iter() .map(Method::as_str) .fold(String::new(), |mut s, method| { s.push_str(method); s.push_str(", "); s }); methods.pop(); methods.pop(); request = request.header("access-control-request-method", methods); } if !headers.is_empty() { let mut headers = headers.iter().fold(String::new(), |mut s, header| { s.push_str(header); s.push_str(", "); s }); headers.pop(); headers.pop(); request = request.header("access-control-request-headers", headers); } let response = request.send().await.unwrap(); if test_methods_and_headers { let mut all_all_here = true; if let Some(accepted_methods) = response .headers() .get("access-control-allow-methods") .and_then(|h| h.to_str().ok()) { let mut all_here = true; for expected_method in methods { if !accepted_methods.contains(expected_method.as_str()) { println!("NOT HERE!"); all_here = false; break; } } if !all_here { all_all_here = false; } } if let Some(accepted_headers) = response .headers() .get("access-control-allow-headers") .and_then(|h| h.to_str().ok()) { let mut all_here = true; for expected_header in headers { if !accepted_headers.contains(expected_header) { all_here = false; break; } } println!("All headers here"); if !all_here { all_all_here = false; } } assert_eq!(all_all_here, valid_expected, "On line {}", line); } else { test_cors_response(response, valid_expected, line).await; } } #[tokio::test] async fn options() { let server = ServerBuilder::from(get_extensions()).run().await; test_cors_options( &server, "/logo.svg", "ftp://foo.bar", &[Method::GET], &[], true, false, line!(), ) .await; test_cors_options( &server, "/api/test", "ftp://foo.bar", &[Method::PUT], &[], false, false, line!(), ) .await; test_cors_options( &server, "/api/test", "http://icelk.dev", &[], &[], false, false, line!(), ) .await; test_cors_options( &server, "/api/test", "https://icelk.dev", &[Method::GET, Method::PUT, Method::DELETE], &[], true, false, line!(), ) .await; test_cors_options( &server, "/api/test", "https://icelk.dev", &[Method::GET, Method::PUT, Method::DELETE, Method::POST], &[], false, true, line!(), ) .await; test_cors_options( &server, "/api/test", "https://icelk.dev", &[Method::GET, Method::PUT, Method::DELETE], &["content-type"], false, true, line!(), ) .await; test_cors_options( &server, "/", "https://icelk.dev", &[Method::GET], &[], false, false, line!(), ) .await; test_cors_options( &server, "/images", "https://example.org", &[Method::GET], &[], false, false, line!(), ) .await; test_cors_options( &server, "/images/", "https://example.org", &[Method::GET], &[], true, false, line!(), ) .await; test_cors_options( &server, "/images/my-funny-cat-pic.png", "https://example.org", &[Method::GET], &[], true, false, line!(), ) .await; test_cors_options( &server, "/images/my-funny-cat-pic.png", "https://kvarn.org", &[Method::GET], &[], false, false, line!(), ) .await; let max_age_response = server .options("/images/my-funny-cat-pic.png") .header("origin", "https://example.org") .header("access-control-request-method", "GET") .send() .await .unwrap(); assert_eq!( max_age_response .headers() .get("access-control-max-age") .unwrap() .to_str() .unwrap(), (60 * 60 * 24 * 365).to_string() ); }