use std::sync::Arc; use authentication_module::{authenticate, AuthProvider, FakeAuthenticator, User}; use log::info; use thruster::{ context::typed_hyper_context::TypedHyperContext, hyper_server::HyperServer, m, middleware_fn, App, HyperRequest, MiddlewareNext, MiddlewareResult, ThrusterServer, }; use thruster_jab::{provide, JabDI}; use thruster_proc::context_state; type Ctx = TypedHyperContext; #[derive(Default)] #[context_state] pub struct State(Option>, Option); #[derive(Default)] struct ServerConfig { di: Arc, } fn generate_context(request: HyperRequest, state: &ServerConfig, _path: &str) -> Ctx { Ctx::new(request, State(Some(state.di.clone()), None)) } #[middleware_fn] async fn hello(mut context: Ctx, _next: MiddlewareNext) -> MiddlewareResult { let user: &Option = context.extra.get(); let user = user.as_ref().unwrap(); context.body(&format!("Hello, {} {}", user.first, user.last)); Ok(context) } #[tokio::main] async fn main() { env_logger::init(); info!("Starting server..."); let mut jab = JabDI::default(); provide!(jab, dyn AuthProvider, FakeAuthenticator::default()); let app = App::::create( generate_context, ServerConfig { di: Arc::new(jab) }, ) .get("/hello", m![authenticate, hello]); let server = HyperServer::new(app); server.build("0.0.0.0", 4321).await; } // This stection starts the custom middleware mod authentication_module { use async_trait::async_trait; use std::collections::HashMap; use std::sync::Arc; use thruster::{ context::typed_hyper_context::TypedHyperContext, errors::{ErrorSet, ThrusterError}, middleware::cookies::HasCookies, middleware_fn, ContextState, MiddlewareNext, MiddlewareResult, }; use thruster_jab::{fetch, JabDI}; use tokio::sync::Mutex; #[derive(Clone)] pub struct User { pub first: String, pub last: String, } #[async_trait] pub trait AuthProvider { async fn authenticate(&self, session_token: &str) -> Result; } pub struct FakeAuthenticator { fake_db: Mutex>, } impl Default for FakeAuthenticator { fn default() -> Self { Self { fake_db: Mutex::new(HashMap::from([ ( "lukes-secret-session-token".to_string(), User { first: "Luke".to_string(), last: "Skywalker".to_string(), }, ), ( "vaders-secret-session-token".to_string(), User { first: "Anakin".to_string(), last: "Skywalker".to_string(), }, ), ])), } } } #[async_trait] impl AuthProvider for FakeAuthenticator { async fn authenticate(&self, session_token: &str) -> Result { self.fake_db.lock().await.remove(session_token).ok_or(()) } } #[middleware_fn] pub async fn authenticate< S: Send + Sync + Default + ContextState>> + ContextState>, >( mut context: TypedHyperContext, next: MiddlewareNext>, ) -> MiddlewareResult> { let di: &Option> = context.extra.get(); let auth = fetch!(di.as_ref().unwrap(), dyn AuthProvider); let session_header = context .get_header("Session-Token") .pop() .unwrap_or_default(); let user = auth .authenticate(&session_header) .await .map_err(|_| ThrusterError::unauthorized_error(TypedHyperContext::default()))?; *context.extra.get_mut() = Some(user); next(context).await } }