use async_trait::async_trait; use ocpi::{ types::{self, CredentialsToken, CsString}, Party, PartyStore, Result, Store, }; use std::{ collections::HashMap, sync::{ atomic::{AtomicU64, Ordering}, Arc, Mutex, }, }; use url::Url; #[derive(Clone)] pub struct TestParty { pub id: u64, pub name: types::CsString<100>, pub url: Url, pub token_they_use: types::CsString<64>, pub token_we_use: types::CsString<64>, pub version_details: types::VersionDetails, pub roles: Vec, } impl Party for TestParty { type Id = u64; fn id(&self) -> Self::Id { self.id } fn token_we_use(&self) -> types::CsString<64> { self.token_we_use.clone() } fn token_they_use(&self) -> types::CsString<64> { self.token_they_use.clone() } } #[derive(Default, Clone)] pub struct TestStore { pub id_counter: Arc, pub temp_token_gen: Arc, pub token_gen: Arc, pub reg_tokens: Arc>>, pub parties: Arc>>, } impl TestStore { pub fn create_reg_token(&self) -> CredentialsToken { let token = format!( "TEMPTOKEN-{}", self.temp_token_gen.fetch_add(1, Ordering::Relaxed) ); let mut lock = self.reg_tokens.lock().expect("Locking"); lock.insert(token.clone(), false); CredentialsToken::try_from(token).unwrap() } pub fn is_reg_token_used(&self, s: impl AsRef) -> Option { let lock = self.reg_tokens.lock().expect("Locking"); lock.get(s.as_ref()).copied() } pub fn by_token_we_use(&self, token: impl AsRef) -> Option { let lock = self.parties.lock().expect("Locking Parties"); lock.values() .find(|tp| tp.token_we_use() == token.as_ref()) .cloned() } pub fn by_token_they_use(&self, token: impl AsRef) -> Option { let lock = self.parties.lock().expect("Locking Parties"); lock.values() .find(|tp| tp.token_they_use() == token.as_ref()) .cloned() } } #[async_trait] impl Store for TestStore { type PartyModel = TestParty; type RegistrationModel = String; async fn get_authorized( &self, token: types::CredentialsToken, ) -> Result> { let mut lock = self.reg_tokens.lock().expect("Locking"); if let Some(b) = lock.get_mut(token.as_ref()) { *b = true; return Ok(ocpi::Authorized::Registration(token.to_string())); } drop(lock); if let Some(party) = self.by_token_they_use(&token) { return Ok(ocpi::Authorized::Party(party)); } Err(ocpi::Error::unauthorized("Invalid token")) } } #[async_trait] impl PartyStore for TestStore { async fn delete_party(&self, party_id: ::Id) -> Result<()> { let mut lock = self.parties.lock().expect("locking"); lock.remove(&party_id); Ok(()) } async fn save_new_party( &self, _temporary_model: Self::RegistrationModel, credentials: types::Credential, version_details: types::VersionDetails, ) -> Result { let id = self.id_counter.fetch_add(1, Ordering::Relaxed); let party = TestParty { id, name: credentials .roles .get(0) .expect("At least one role must be present") .business_details .name .clone(), url: credentials.url, token_we_use: credentials.token, token_they_use: self.generate_token(), roles: credentials.roles, version_details, }; let mut lock = self.parties.lock().expect("Locking parties"); lock.insert(id, party.clone()); Ok(party) } async fn update_party( &self, model: Self::PartyModel, credentials: types::Credential, details: types::VersionDetails, ) -> Result { let mut lock = self.parties.lock().expect("Locking parties"); let existing = lock .get_mut(&model.id) .ok_or_else(|| ocpi::Error::client_generic("Party not found"))?; existing.token_we_use = credentials.token; existing.name = credentials.roles[0].business_details.name.clone(); existing.url = credentials.url; existing.roles = credentials.roles; existing.version_details = details; existing.token_they_use = self.generate_token(); Ok(existing.clone()) } async fn get_our_roles(&self) -> Result> { Ok(vec![types::CredentialsRole { role: types::Role::Cpo, business_details: types::BusinessDetails { name: "TestStore details".parse().expect("Parsing name"), website: None, logo: None, }, party_id: "EXA".parse().expect("PartyId"), country_code: "se".parse().expect("CountryCode"), }]) } } impl TestStore { fn generate_token(&self) -> CsString<64> { format!("TOKEN-{}", self.token_gen.fetch_add(1, Ordering::Relaxed)) .parse::>() .expect("Token") } }