use std::sync::Arc; use axum::extract::State; use axum::http::{self, HeaderMap, HeaderValue, Request}; use axum::middleware::Next; use axum::response::IntoResponse; use lib::model::{ModelId, ValidShardedId}; use lib::types::ProjectId; use tracing::error; use crate::auth::SecretApiKey; use crate::errors::ApiError; use crate::AppState; fn get_auth_key( header_map: &HeaderMap, ) -> Result { let auth_header = header_map .get(http::header::AUTHORIZATION) .and_then(|header| header.to_str().ok()); let auth_header = if let Some(auth_header) = auth_header { auth_header } else { return Err(ApiError::Unauthorized); }; if auth_header.is_empty() { return Err(ApiError::Unauthorized); } match auth_header.split_once(' ') { | Some((name, content)) if name == "Bearer" => Ok(content.to_string()), | _ => { Err(ApiError::BadRequest( "Authentication header is malformed, please use \ `Authorization: Bearer sk_...`" .to_owned(), )) } } } /// Ensures that the caller is authenticated with an admin key AND acting on /// behalf of a project. The `ProjectId` is then injected in the request /// extensions. pub async fn admin_only_auth_for_project( State(state): State>, mut req: Request, next: Next, ) -> Result { let auth_key = get_auth_key(req.headers())?; let admin_keys = &state.config.api.admin_api_keys; if admin_keys.contains(&auth_key) { let project = extract_project_from_request(&req)?; req.extensions_mut().insert(project.clone()); Ok(next.run(req).await) } else { Err(ApiError::Forbidden) } } /// Ensures that the caller is authenticated with an admin key. No project is /// required. Handlers using this middleware shouldn't rely on a `ProjectId` /// being set in the request extensions. pub async fn admin_only_auth( State(state): State>, req: Request, next: Next, ) -> Result { let auth_key = get_auth_key(req.headers())?; let admin_keys = &state.config.api.admin_api_keys; if admin_keys.contains(&auth_key) { Ok(next.run(req).await) } else { Err(ApiError::Forbidden) } } fn extract_project_from_request( req: &Request, ) -> Result, ApiError> { // This is an admin user which is acting on behalf of some project. const ON_BEHALF_OF_HEADER_NAME: &str = "X-On-Behalf-Of"; if let Some(project) = req.headers().get(ON_BEHALF_OF_HEADER_NAME) { let project = project.to_str().map_err(|_| { ApiError::BadRequest(format!( "{ON_BEHALF_OF_HEADER_NAME} header is not a valid UTF-8 string" )) })?; let validated_project = ProjectId::from(project.to_owned()) .validated() .map_err(|_| { ApiError::BadRequest(format!( "Invalid project id in {ON_BEHALF_OF_HEADER_NAME} header" )) }); return validated_project; } error!("Admin user didn't set {} header", ON_BEHALF_OF_HEADER_NAME); Err(ApiError::BadRequest( "Super privilege header(s) missing!".to_owned(), )) } pub async fn auth( State(state): State>, mut req: Request, next: Next, ) -> Result { let auth_key = get_auth_key(req.headers())?; let admin_keys = &state.config.api.admin_api_keys; if admin_keys.contains(&auth_key) { let project = extract_project_from_request(&req)?; req.extensions_mut().insert(project.clone()); return Ok(next.run(req).await); } let Ok(user_provided_secret) = auth_key.to_string().parse::() else { return Err(ApiError::Unauthorized); }; let project = state .authenicator .authenticate(&user_provided_secret) .await?; req.extensions_mut().insert(project.clone()); let mut resp = next.run(req).await; // Inject project_id in the response extensions as well. resp.extensions_mut().insert(project); Ok(resp) }