use std::{ collections::HashMap, fmt::Debug, marker::PhantomData, sync::{Arc, RwLock}, }; use axum::{ async_trait, extract::State, response::IntoResponse, routing::{get, post}, Extension, Json, Router, }; use axum_extra::extract::{cookie::Cookie, CookieJar}; use axum_session_manager::{SessionManage, SessionManagerLayer, UserData, UserState}; use http::StatusCode; use serde::{Deserialize, Serialize}; use thiserror::Error; use uuid::Uuid; const COOKIE_KEY: &str = "test-id"; #[tokio::main] async fn main() { let service = app(); let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") .await .unwrap(); axum::serve(listener, service).await.unwrap(); } fn app() -> Router { let session_pool = SessionPool::new(); let phantom = PhantomData::default(); let layer = SessionManagerLayer::new(session_pool.clone(), COOKIE_KEY, phantom); Router::new() .route("/", get(root)) .route("/login", post(login)) .route("/session/data", get(get_session_data)) .layer(layer) .with_state(session_pool) } async fn root() -> impl IntoResponse { "Hello" } async fn get_session_data( Extension(user_data): Extension>, ) -> Result { let user_state = user_data.0; match user_state { UserState::HaveSession(a) => Ok((StatusCode::OK, Json(a))), UserState::NoCookie => Err((StatusCode::UNAUTHORIZED, "no cookie")), UserState::NoSession => Err((StatusCode::UNAUTHORIZED, "you need login")), } } async fn login( State(session_pool): State, jar: CookieJar, Json(payload): Json, ) -> Result { let session_id = session_pool .add_session(payload) .await .map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?; let cookie = Cookie::new(COOKIE_KEY, session_id); Ok((StatusCode::OK, jar.add(cookie))) } #[derive(Debug, Clone, Serialize, Deserialize)] struct Credential { id: i32, name: String, password: String, } #[derive(Debug, Clone)] struct UserPool { pool: Arc>>, } impl UserPool { fn new() -> Self { Self { pool: Arc::default(), } } } #[derive(Debug, Clone)] struct SessionPool { pool: Arc>>, } impl SessionPool { fn new() -> Self { Self { pool: Arc::default(), } } } #[derive(Debug, Clone, Error)] enum ServerError { #[error("unexpected error")] Unexpect, } #[async_trait] impl SessionManage for SessionPool { type SessionID = String; type UserInfo = Credential; type Error = ServerError; async fn add_session(&self, session_data: Credential) -> Result { let session_id = Uuid::new_v4().to_string(); { let _ = &self .pool .write() .map_err(|_e| ServerError::Unexpect)? .insert(session_id.clone(), session_data); } Ok(session_id) } async fn verify_session( &self, session_id: &str, ) -> Result, Self::Error> { let data = self.pool.read().map_err(|_e| ServerError::Unexpect)?.to_owned(); match data.get(session_id) { Some(user) => { let user = user.to_owned(); Ok(Some(user)) } None => Ok(None), } } async fn delete_session(&self, session_id: &str) -> Result<(), Self::Error> { self.pool .write() .map_err(|_e| ServerError::Unexpect)? .remove(session_id); Ok(()) } }