mod google_openid; use std::collections::HashSet; use axum::http::request::Parts; pub use google_openid::{GOOGLE_CLIENT, WITH_GOOGLE_AUTH}; use crate::*; use axum_login::{ AuthManagerLayer, AuthManagerLayerBuilder, AuthSession, AuthUser, AuthnBackend, AuthzBackend, }; pub use openidconnect::{CsrfToken as OAuthCSRF, Nonce as OAuthNonce}; use password_auth::{generate_hash, verify_password}; use serde::{Deserialize, Serialize}; pub use tower_sessions::Session; use tower_sessions::{ session::{Id, Record}, session_store::{Error, Result}, Expiry, SessionManagerLayer, SessionStore, }; pub type UserId = Uuid; pub type AuthLayer = AuthManagerLayer; pub type Auth = AuthSession; pub type OAuthCode = String; #[derive(Table, Clone, Debug, Serialize, Deserialize)] pub struct User { pub id: Uuid, pub permissions: Vec, pub group: UserGroup, #[unique_column] pub username: Option, #[unique_column] pub email: Option, pub password_hash: Option, } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub enum UserGroup { Admin, Visitor, Custom(String), } #[derive(Clone, Debug)] pub enum Credentials { UsernamePassword { username: String, password: String }, EmailPassword { email: String, password: String }, GoogleOpenID { code: OAuthCode, nonce: OAuthNonce }, } pub type OAuthQuery = Query; #[derive(Debug, serde::Deserialize)] pub struct OAuthQueryParams { pub code: OAuthCode, pub state: OAuthCSRF, } #[allow(dead_code)] trait AuthBackend: AuthnBackend + SessionStore { } impl AuthBackend for Db {} pub const LOGIN_ROUTE: &str = "/auth/login"; pub const LOGOUT_ROUTE: &str = "/auth/logout"; pub const GOOGLE_LOGIN_ROUTE: &str = "/auth/google"; pub const GOOGLE_CALLBACK_ROUTE: &str = "/auth/google/callback"; pub fn init_auth_module() -> (AuthLayer, Router) { SessionRow::prepare_table(); User::prepare_table(); let mut session_layer = SessionManagerLayer::new(DB.cloned()) .with_name("prest_session") .with_same_site(tower_sessions::cookie::SameSite::Lax) .with_expiry(Expiry::OnInactivity(time::Duration::days(30))); if let Some(domain) = APP_CONFIG.check().domain.clone() { session_layer = session_layer.with_domain(domain); } let layer = AuthManagerLayerBuilder::new(DB.cloned(), session_layer).build(); let mut router = route(LOGIN_ROUTE, post(login)).route(LOGOUT_ROUTE, get(logout)); if *WITH_GOOGLE_AUTH { router = router .route(GOOGLE_LOGIN_ROUTE, get(init_google_oauth)) .route(GOOGLE_CALLBACK_ROUTE, get(google_oauth_callback)); } (layer, router) } impl User { pub fn from_email(email: String) -> Self { Self { id: Uuid::new_v4(), permissions: vec![], group: UserGroup::Visitor, username: None, email: Some(email), password_hash: None, } } pub fn from_username_password(username: String, password: String) -> Self { Self { id: Uuid::new_v4(), permissions: vec![], group: UserGroup::Visitor, username: Some(username), email: None, password_hash: Some(generate_hash(password)), } } pub fn from_email_password(email: String, password: String) -> Self { Self { id: Uuid::new_v4(), permissions: vec![], group: UserGroup::Visitor, username: None, email: Some(email), password_hash: Some(generate_hash(password)), } } pub fn is_admin(&self) -> bool { self.group == UserGroup::Admin } } #[derive(Debug, Default, serde::Deserialize)] struct AuthForm { username: Option, email: Option, password: String, signup: bool, next: Option, } async fn login(mut auth: Auth, Form(form): Form) -> impl IntoResponse { let AuthForm { username, email, password, signup, next, } = form; let user = if signup { let new = if let Some(username) = username { if User::find_by_username(&username).is_some() { return StatusCode::CONFLICT.into_response(); } User::from_username_password(username, password) } else if let Some(email) = email { if User::find_by_email(&email).is_some() { return StatusCode::CONFLICT.into_response(); } User::from_email_password(email, password) } else { return StatusCode::BAD_REQUEST.into_response(); }; let Ok(_) = new.save() else { return StatusCode::INTERNAL_SERVER_ERROR.into_response(); }; new } else { if let Some(username) = username { let credentials = Credentials::UsernamePassword { username, password }; let Ok(Some(user)) = auth.authenticate(credentials).await else { return StatusCode::UNAUTHORIZED.into_response(); }; user } else if let Some(email) = email { let credentials = Credentials::EmailPassword { email, password }; let Ok(Some(user)) = auth.authenticate(credentials).await else { return StatusCode::UNAUTHORIZED.into_response(); }; user } else { return StatusCode::BAD_REQUEST.into_response(); } }; if auth.login(&user).await.is_err() { #[cfg(debug_assertions)] return StatusCode::INTERNAL_SERVER_ERROR.into_response(); #[cfg(not(debug_assertions))] return StatusCode::UNAUTHORIZED.into_response(); } if let Some(next) = next { Redirect::to(&next).into_response() } else { Redirect::to("/").into_response() } } #[derive(Debug, serde::Deserialize)] struct NextUrl { next: Option, } const CSRF_KEY: &str = "oauth_csrf"; const NONCE_KEY: &str = "oauth_nonce"; const REDIRECT_KEY: &str = "after_auth_redirect"; async fn init_google_oauth( session: Session, Query(NextUrl { next }): Query, ) -> impl IntoResponse { let (authz_url, csrf_token, nonce) = GOOGLE_CLIENT.authz_request(); let ins1 = session.insert(NONCE_KEY, nonce).await; let ins2 = session.insert(CSRF_KEY, csrf_token).await; let ins3 = if let Some(next) = next { session.insert(REDIRECT_KEY, next).await } else { Ok(()) }; if ins1.is_err() || ins2.is_err() || ins3.is_err() { return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } Redirect::to(authz_url.as_str()).into_response() } async fn google_oauth_callback( session: Session, Query(query): OAuthQuery, mut auth: Auth, ) -> impl IntoResponse { let Ok(Some(initial_csrf)) = session.remove::(CSRF_KEY).await else { return StatusCode::UNAUTHORIZED.into_response(); }; let Ok(Some(nonce)) = session.remove::(NONCE_KEY).await else { return StatusCode::UNAUTHORIZED.into_response(); }; if initial_csrf.secret() != query.state.secret() { return StatusCode::UNAUTHORIZED.into_response(); } let credentials = Credentials::GoogleOpenID { code: query.code, nonce, }; let Ok(Some(user)) = auth.authenticate(credentials).await else { return StatusCode::UNAUTHORIZED.into_response(); }; if auth.login(&user).await.is_err() { #[cfg(debug_assertions)] return StatusCode::INTERNAL_SERVER_ERROR.into_response(); #[cfg(not(debug_assertions))] return StatusCode::UNAUTHORIZED.into_response(); } if let Ok(Some(next)) = session.remove::(REDIRECT_KEY).await { Redirect::to(&next).into_response() } else { Redirect::to("/").into_response() } } async fn logout(mut auth: Auth) -> impl IntoResponse { if let Some(_) = auth.user { auth.logout().await.unwrap(); } Redirect::to("/") } #[async_trait] impl FromRequestParts for User where S: Send + Sync, { type Rejection = StatusCode; async fn from_request_parts( parts: &mut Parts, _state: &S, ) -> std::result::Result { let Some(auth_session) = parts.extensions.get::().cloned() else { #[cfg(debug_assertions)] return Err(StatusCode::INTERNAL_SERVER_ERROR); #[cfg(not(debug_assertions))] return Err(StatusCode::UNAUTHORIZED); }; let Some(user) = auth_session.user else { return Err(StatusCode::UNAUTHORIZED); }; if parts.uri.path().starts_with("/admin/") && !user.is_admin() { return Err(StatusCode::UNAUTHORIZED); } Ok(user) } } impl AuthUser for User { type Id = Uuid; fn id(&self) -> Self::Id { self.id } fn session_auth_hash(&self) -> &[u8] { if let Some(password_hash) = &self.password_hash { password_hash.as_bytes() } else if let Some(email) = &self.email { email.as_bytes() } else if let Some(username) = &self.username { username.as_bytes() } else { self.id.as_bytes() } } } use thiserror::Error; #[derive(Error, Debug)] pub enum AuthError { // TODO } #[async_trait] impl AuthnBackend for Db { type User = User; type Credentials = Credentials; type Error = AuthError; async fn authenticate( &self, creds: Self::Credentials, ) -> std::result::Result, Self::Error> { match creds { Credentials::GoogleOpenID { code, nonce } => { if !*WITH_GOOGLE_AUTH { warn!("Attempted to authenticate with google credentials without google credentials!"); return Ok(None); // TODO an error here } let Ok(email) = GOOGLE_CLIENT.get_email(code, nonce).await else { return Ok(None); // TODO an error here }; match User::find_by_email(&email) { Some(user) => Ok(Some(user)), None => { let user = User::from_email(email); user.save().unwrap(); Ok(Some(user)) } } } Credentials::UsernamePassword { username, password } => { let Some(user) = User::find_by_username(&username) else { return Ok(None); // TODO an error here }; let Some(pw_hash) = &user.password_hash else { return Ok(None); // TODO an error here }; let Ok(()) = verify_password(password, pw_hash) else { return Ok(None); // TODO an error here }; Ok(Some(user)) } Credentials::EmailPassword { email, password } => { let Some(user) = User::find_by_email(&email) else { return Ok(None); // TODO an error here }; let Some(pw_hash) = &user.password_hash else { return Ok(None); // TODO an error here }; let Ok(()) = verify_password(password, pw_hash) else { return Ok(None); // TODO an error here }; Ok(Some(user)) } } } async fn get_user( &self, user_id: &axum_login::UserId, ) -> std::result::Result, Self::Error> { Ok(User::find_by_id(user_id)) } } pub type Permission = String; #[async_trait] impl AuthzBackend for Db { type Permission = Permission; async fn get_user_permissions( &self, user: &Self::User, ) -> std::result::Result, Self::Error> { Ok(user.permissions.iter().map(|s| s.to_owned()).collect()) } async fn get_group_permissions( &self, user: &Self::User, ) -> std::result::Result, Self::Error> { Ok(user.permissions.iter().map(|s| s.to_owned()).collect()) } async fn get_all_permissions( &self, user: &Self::User, ) -> std::result::Result, Self::Error> { Ok(user.permissions.iter().map(|s| s.to_owned()).collect()) } async fn has_perm( &self, user: &Self::User, perm: Self::Permission, ) -> std::result::Result { Ok(user.permissions.iter().find(|p| **p == perm).is_some()) } } #[derive(Table, Debug, Serialize, Deserialize)] pub struct SessionRow { pub id: Uuid, pub record: String, } #[async_trait] impl SessionStore for Db { async fn save(&self, record: &Record) -> Result<()> { let id = record.id.0; let record = match serde_json::to_string(record) { Ok(s) => s, Err(e) => return Err(Error::Encode(format!("{e}"))), }; match (SessionRow { id, record }).save() { Ok(_) => Ok(()), Err(e) => Err(Error::Backend(format!("Session save error: {e}"))), } } async fn load(&self, session_id: &Id) -> Result> { let search = SessionRow::find_by_id(&session_id.0); let Some(session_row) = search else { return Ok(None); }; match serde_json::from_str(&session_row.record) { Ok(record) => Ok(Some(record)), Err(e) => Err(Error::Decode(format!("Session load error: {e}"))), } } async fn delete(&self, session_id: &Id) -> Result<()> { match SessionRow::delete_by_key(&session_id.0) { Ok(_) => Ok(()), Err(e) => Err(Error::Backend(format!("Session deletion error: {e}"))), } } }