//! Run with //! //! ```not_rust //! cargo run --example jwt_groups --features="axum" //! ``` //! use std::{collections::HashSet, sync::Arc}; use anyhow::Context; use axum::{response::IntoResponse, routing::get, Json, Router}; use composable_tower_http::{ authorize::{ header::bearer::DefaultBearerExtractor, jwt::{ jwk_set::{fetch::HttpJwkSetFetcher, rotating::RotatingJwkSetProvider}, DefaultJwtAuthorizerBuilder, Validation, }, }, extension::{ExtensionLayerExt, ModificationLayerExt}, extract::Extracted, modify::Modifier, }; use http::StatusCode; use reqwest::Client; use serde::{Deserialize, Serialize}; #[path = "../util/util.rs"] mod util; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Claims { pub email_verified: bool, pub name: String, pub preferred_username: String, pub given_name: String, pub family_name: String, pub email: String, pub groups: Vec<String>, } async fn claims(Extracted(claims): Extracted<Claims>) -> impl IntoResponse { Json(claims) } #[tokio::main] async fn main() -> anyhow::Result<()> { util::init("jwt_groups")?; let jwks_uri = std::env::var("JWKS_URI").unwrap_or_else(|_| { String::from("https://keycloak.com/realms/master/protocol/openid-connect/certs") }); let iss = std::env::var("ISS").unwrap_or_else(|_| String::from("https://keycloak.com/realms/master")); tracing::info!(%jwks_uri, %iss); let auth_layer = DefaultJwtAuthorizerBuilder::new( DefaultBearerExtractor::new(), RotatingJwkSetProvider::new(30, HttpJwkSetFetcher::new(jwks_uri, Client::new())) .await .context("Failed to create jwk set provider")?, Validation::new().aud(&["account"]).iss(&[iss]), ) .build::<Claims>() .extension_layer(); // These layers will look into the extracted claims in the request extensions and perform a modification removing the old claims and inserting modified claims. let admins: HashSet<String> = ["/admins"].into_iter().map(Into::into).collect(); let admins_modify_layer = GroupsValidator::new(admins).modification_layer::<Claims>(); let super_admins: HashSet<String> = ["/super-admins"].into_iter().map(Into::into).collect(); let super_admins_modify_layer = GroupsValidator::new(super_admins).modification_layer::<Claims>(); let app = Router::new() // curl -H "Authorization: Bearer <token>" localhost:5000/super-admins .route( "/super-admins", get(claims).layer(super_admins_modify_layer), ) // curl -H "Authorization: Bearer <token>" localhost:5000/admins .route("/admins", get(claims).layer(admins_modify_layer)) // curl -H "Authorization: Bearer <token>" localhost:5000 .route("/", get(claims)) // The auth layer will extract the claims from the request and insert them into the request extensions. .layer(auth_layer) .layer(util::trace_layer()); util::serve(app).await } #[derive(Debug, Clone)] struct GroupsValidator { groups: Arc<HashSet<String>>, } impl GroupsValidator { fn new(groups: HashSet<String>) -> Self { Self { groups: Arc::new(groups), } } } impl Modifier<Claims> for GroupsValidator { type Modified = Claims; type Error = GroupsValidationError; async fn modify(&self, claims: Claims) -> Result<Claims, Self::Error> { if claims .groups .iter() .any(|group| self.groups.contains(group)) { return Ok(claims); }; Err(GroupsValidationError) } } #[derive(Debug, thiserror::Error)] #[error("Not in groups")] struct GroupsValidationError; impl IntoResponse for GroupsValidationError { fn into_response(self) -> axum::response::Response { (StatusCode::UNAUTHORIZED, "Not in groups").into_response() } }