use axum::http::StatusCode; use axum::{async_trait, extract::FromRequestParts, http::request::Parts}; use di::{KeyedRef, KeyedRefMut, ServiceProvider}; use std::any::type_name; use std::convert::Infallible; /// Represents a container for an optional, injected, keyed service. #[derive(Clone, Debug)] pub struct TryInjectWithKey(pub Option>); /// Represents a container for a required, injected, keyed service. #[derive(Clone, Debug)] pub struct InjectWithKey(pub KeyedRef); /// Represents a container for an optional, mutable, injected, keyed service. #[derive(Clone, Debug)] pub struct TryInjectWithKeyMut(pub Option>); /// Represents a container for a required, mutable, injected, keyed service. #[derive(Clone, Debug)] pub struct InjectWithKeyMut(pub KeyedRefMut); /// Represents a container for a collection of injected, keyed services. #[derive(Clone, Debug)] pub struct InjectAllWithKey(pub Vec>); /// Represents a container for a collection of mutable, injected, keyed services. #[derive(Clone, Debug)] pub struct InjectAllWithKeyMut(pub Vec>); #[inline] fn unregistered_type_with_key() -> String { format!( "No service for type '{}' with the key '{}' has been registered.", type_name::(), type_name::() ) } #[async_trait] impl FromRequestParts for TryInjectWithKey where TSvc: ?Sized + 'static, S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_by_key::())) } else { Ok(Self(None)) } } } #[async_trait] impl FromRequestParts for InjectWithKey where TSvc: ?Sized + 'static, S: Send + Sync, { type Rejection = (StatusCode, String); async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { if let Some(service) = provider.get_by_key::() { return Ok(Self(service)); } } Err(( StatusCode::INTERNAL_SERVER_ERROR, unregistered_type_with_key::(), )) } } #[async_trait] impl FromRequestParts for TryInjectWithKeyMut where TSvc: ?Sized + 'static, S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_by_key_mut::())) } else { Ok(Self(None)) } } } #[async_trait] impl FromRequestParts for InjectWithKeyMut where TSvc: ?Sized + 'static, S: Send + Sync, { type Rejection = (StatusCode, String); async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { if let Some(service) = provider.get_by_key_mut::() { return Ok(Self(service)); } } Err(( StatusCode::INTERNAL_SERVER_ERROR, unregistered_type_with_key::(), )) } } #[async_trait] impl FromRequestParts for InjectAllWithKey where TSvc: ?Sized + 'static, S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_all_by_key::().collect())) } else { Ok(Self(Vec::with_capacity(0))) } } } #[async_trait] impl FromRequestParts for InjectAllWithKeyMut where TSvc: ?Sized + 'static, S: Send + Sync, { type Rejection = Infallible; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { if let Some(provider) = parts.extensions.get::() { Ok(Self(provider.get_all_by_key_mut::().collect())) } else { Ok(Self(Vec::with_capacity(0))) } } } #[cfg(test)] mod tests { use super::*; use crate::{RouterServiceProviderExtensions, TestClient}; use axum::{ routing::{get, post}, Router, extract::State, }; use di::{injectable, Injectable, ServiceCollection}; use http::StatusCode; mod key { pub struct Basic; pub struct Advanced; } #[tokio::test] async fn request_should_fail_with_500_for_unregistered_service_with_key() { // arrange struct Service; impl Service { fn do_work(&self) -> String { "Test".into() } } async fn handler(InjectWithKey(service): InjectWithKey) -> String { service.do_work() } let app = Router::new() .route("/test", get(handler)) .with_provider(ServiceProvider::default()); let client = TestClient::new(app); // act let response = client.get("/test").send().await; // assert assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); } #[tokio::test] async fn try_inject_with_key_into_handler() { // arrange #[injectable] struct Service; async fn handler( TryInjectWithKey(_service): TryInjectWithKey, ) -> StatusCode { StatusCode::NO_CONTENT } let app = Router::new() .route("/test", post(handler)) .with_provider(ServiceProvider::default()); let client = TestClient::new(app); // act let response = client.post("/test").send().await; // assert assert_eq!(response.status(), StatusCode::NO_CONTENT); } #[tokio::test] async fn inject_with_key_into_handler() { // arrange trait Service: Send + Sync { fn do_work(&self) -> String; } #[injectable(Service)] struct ServiceImpl; impl Service for ServiceImpl { fn do_work(&self) -> String { "Test".into() } } async fn handler(InjectWithKey(service): InjectWithKey) -> String { service.do_work() } let provider = ServiceCollection::new() .add(ServiceImpl::scoped().with_key::()) .build_provider() .unwrap(); let app = Router::new() .route("/test", get(handler)) .with_provider(provider); let client = TestClient::new(app); // act let response = client.get("/test").send().await; let text = response.text().await; // assert assert_eq!(&text, "Test"); } #[tokio::test] async fn inject_all_with_key_into_handler() { // arrange trait Thing: Send + Sync {} #[injectable(Thing)] struct Thing1; #[injectable(Thing)] struct Thing2; #[injectable(Thing)] struct Thing3; impl Thing for Thing1 {} impl Thing for Thing2 {} impl Thing for Thing3 {} async fn handler( InjectAllWithKey(things): InjectAllWithKey, ) -> String { things.len().to_string() } let provider = ServiceCollection::new() .try_add_to_all(Thing1::scoped().with_key::()) .try_add_to_all(Thing2::scoped().with_key::()) .try_add_to_all(Thing3::scoped().with_key::()) .build_provider() .unwrap(); let app = Router::new() .route("/test", get(handler)) .with_provider(provider); let client = TestClient::new(app); // act let response = client.get("/test").send().await; let text = response.text().await; // assert assert_eq!(&text, "2"); } #[tokio::test] async fn inject_with_key_and_state_into_handler() { // arrange trait Service: Send + Sync { fn do_work(&self) -> String; } #[injectable(Service)] struct ServiceImpl; impl Service for ServiceImpl { fn do_work(&self) -> String { "Test".into() } } #[derive(Clone)] struct AppState; async fn handler( InjectWithKey(service): InjectWithKey, State(_state): State) -> String { service.do_work() } let provider = ServiceCollection::new() .add(ServiceImpl::scoped().with_key::()) .build_provider() .unwrap(); let app = Router::new() .route("/test", get(handler)) .with_state(AppState) .with_provider(provider); let client = TestClient::new(app); // act let response = client.get("/test").send().await; let text = response.text().await; // assert assert_eq!(&text, "Test"); } }