use bson::doc; use bson::oid::ObjectId; use mongodb::Client; use mongodb_lock::Mutex; use mongodb_lock::RwLock; use serde::{Deserialize, Serialize}; use std::net::Ipv4Addr; use std::process::{Command, Stdio}; use std::sync::atomic::Ordering; use std::sync::atomic::{AtomicBool, AtomicU16}; use std::sync::Arc; use std::time::Duration; use tokio::task; use tokio::time::sleep; use tracing::info; struct MongodbClient { client: mongodb::Client, _server: Arc, } impl AsRef for MongodbClient { fn as_ref(&self) -> &Client { &self.client } } pub struct Mongodb { name: String, port: u16, } impl Mongodb { async fn new() -> Self { static PORT: AtomicU16 = AtomicU16::new(27021); // atlas deployments setup tester2 --type local --port 8082 --force let port = PORT.fetch_add(1, Ordering::SeqCst); let name = uuid::Uuid::new_v4().to_string(); info!("port: {port}"); info!("name: {name}"); let _mongodb = Command::new("atlas") .args([ "deployments", "setup", &name, "--type", "local", "--port", &port.to_string(), "--force", ]) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .output() .unwrap(); info!("setup deployment"); sleep(Duration::from_secs(3)).await; Mongodb { name, port } } fn client(this: Arc) -> MongodbClient { let host = mongodb::options::ServerAddress::Tcp { host: Ipv4Addr::LOCALHOST.to_string(), port: Some(this.port), }; let opts = mongodb::options::ClientOptions::builder() .hosts(vec![host]) .direct_connection(true) .build(); let client = mongodb::Client::with_options(opts).unwrap(); MongodbClient { client, _server: this, } } } impl Drop for Mongodb { fn drop(&mut self) { info!("deleting deployment"); let _mongodb = Command::new("atlas") .args(["deployments", "delete", &self.name, "--force"]) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .output() .unwrap(); info!("deleted"); } } #[tokio::test] async fn single() { use mongodb_lock::*; #[derive(Clone, Serialize, Deserialize)] struct MyDocument { _id: ObjectId, x: i32, } let _guard = tracing::subscriber::set_default( tracing_subscriber::fmt::Subscriber::builder() .with_test_writer() .finish(), ); let mongodb = Arc::new(Mongodb::new().await); let client = Mongodb::client(mongodb); let db = client.as_ref().database("basic"); let docs = db.collection::("docs"); let lock = Arc::new(Mutex::new(&db, "locks").await.unwrap()); let one = MyDocument { _id: ObjectId::new(), x: 1, }; let two = MyDocument { _id: ObjectId::new(), x: 1, }; let three = MyDocument { _id: ObjectId::new(), x: 1, }; docs.insert_many(vec![one.clone(), two.clone(), three.clone()]) .await .unwrap(); let one_id = one._id; let two_id = two._id; let clock = lock.clone(); let cdocs = docs.clone(); let first = task::spawn(async move { let _guard = clock.lock_default([one_id, two_id]).await.unwrap(); let a = cdocs .find_one(doc! { "_id": one_id }) .await .unwrap() .unwrap(); let b = cdocs .find_one(doc! { "_id": two_id }) .await .unwrap() .unwrap(); cdocs .update_many( doc! { "_id": { "$in": [one_id,two_id] }}, doc! { "$set": { "x": a.x + b.x } }, ) .await .unwrap(); }); let two_id = two._id; let three_id = three._id; let clock = lock.clone(); let cdocs = docs.clone(); let second = task::spawn(async move { let _guard = clock.lock_default([two_id, three_id]).await.unwrap(); let a = cdocs .find_one(doc! { "_id": two_id }) .await .unwrap() .unwrap(); let b = cdocs .find_one(doc! { "_id": three_id }) .await .unwrap() .unwrap(); cdocs .update_many( doc! { "_id": { "$in": [two_id,three_id] } }, doc! { "$set": { "x": a.x + b.x } }, ) .await .unwrap(); }); first.await.unwrap(); second.await.unwrap(); let a = docs .find_one(doc! { "_id": one_id }) .await .unwrap() .unwrap() .x; let b = docs .find_one(doc! { "_id": two_id }) .await .unwrap() .unwrap() .x; let c = docs .find_one(doc! { "_id": three_id }) .await .unwrap() .unwrap() .x; assert!((a == 2 && b == 3 && c == 3) || (a == 3 && b == 3 && c == 2)); } #[tokio::test] async fn adder() { #[derive(Debug, Serialize, Deserialize)] struct Number { _id: ObjectId, x: i32, } const N: i32 = 10; const A: i32 = 0; const B: i32 = 10; static CHECK_ONE: AtomicBool = AtomicBool::new(false); static CHECK_TWO: AtomicBool = AtomicBool::new(false); let _guard = tracing::subscriber::set_default( tracing_subscriber::fmt::Subscriber::builder() .with_test_writer() .finish(), ); let mongodb = Arc::new(Mongodb::new().await); let client = Mongodb::client(mongodb); let db = client.as_ref().database("adder"); let lock = Arc::new(Mutex::new(&db, "locks").await.unwrap()); let cola = db.collection::("first"); let colb = db.collection::("second"); let ida1 = ObjectId::new(); let ida2 = ObjectId::new(); cola.insert_one(Number { _id: ida1, x: A }).await.unwrap(); cola.insert_one(Number { _id: ida2, x: B }).await.unwrap(); let idb1 = ObjectId::new(); let idb2 = ObjectId::new(); colb.insert_one(Number { _id: idb1, x: A }).await.unwrap(); colb.insert_one(Number { _id: idb2, x: B }).await.unwrap(); let tasks = (0..N) .flat_map(|_| { let (clock, ccola, ccolb, cida1, cidb1) = (lock.clone(), cola.clone(), colb.clone(), ida1, idb1); let one = task::spawn(async move { let guard = clock.lock_default([cida1, cidb1]).await.unwrap(); assert!(!CHECK_ONE.swap(true, Ordering::SeqCst)); let num = ccola .find_one(doc! { "_id": cida1 }) .await .unwrap() .unwrap(); ccolb .update_one(doc! { "_id": cidb1 }, doc! { "$set": { "x": num.x + 1 } }) .await .unwrap(); let num = ccolb .find_one(doc! { "_id": cidb1 }) .await .unwrap() .unwrap(); ccola .update_one(doc! { "_id": cida1 }, doc! { "$set": { "x": num.x + 1 } }) .await .unwrap(); assert!(CHECK_ONE.swap(false, Ordering::SeqCst)); drop(guard); }); let (clock, ccola, ccolb, cida2, cidb2) = (lock.clone(), cola.clone(), colb.clone(), ida2, idb2); let two = task::spawn(async move { let guard = clock.lock_default([cida2, cidb2]).await.unwrap(); assert!(!CHECK_TWO.swap(true, Ordering::SeqCst)); let num = ccola .find_one(doc! { "_id": cida2 }) .await .unwrap() .unwrap(); ccolb .update_one(doc! { "_id": cidb2 }, doc! { "$set": { "x": num.x + 1 } }) .await .unwrap(); let num = ccolb .find_one(doc! { "_id": cidb2 }) .await .unwrap() .unwrap(); ccola .update_one(doc! { "_id": cida2 }, doc! { "$set": { "x": num.x + 1 } }) .await .unwrap(); assert!(CHECK_TWO.swap(false, Ordering::SeqCst)); drop(guard); }); [one, two] }) .collect::>(); for task in tasks { task.await.unwrap(); } const T: i32 = N * 2; let numa1 = cola.find_one(doc! { "_id": ida1 }).await.unwrap().unwrap(); let numb1 = colb.find_one(doc! { "_id": idb1 }).await.unwrap().unwrap(); info!("numa1: {numa1:?}"); info!("numb1: {numb1:?}"); assert!( (numa1.x == T - A + 1 && numb1.x == T + A) || (numb1.x == T + A - 1 && numa1.x == T + A) ); let numa2 = cola.find_one(doc! { "_id": ida2 }).await.unwrap().unwrap(); let numb2 = colb.find_one(doc! { "_id": idb2 }).await.unwrap().unwrap(); info!("numa2: {numa2:?}"); info!("numb2: {numb2:?}"); assert!( (numa2.x == T + B - 1 && numb2.x == T + B) || (numb2.x == T + B - 1 && numa2.x == T + B) ); } #[tokio::test] async fn reader() { #[derive(Debug, Serialize, Deserialize)] struct Number { _id: ObjectId, x: i32, } const READS: usize = 10; const WRITE: usize = 10; let _guard = tracing::subscriber::set_default( tracing_subscriber::fmt::Subscriber::builder() .with_test_writer() .finish(), ); let mongodb = Arc::new(Mongodb::new().await); let client = Mongodb::client(mongodb); let db = client.as_ref().database("adder"); let lock = Arc::new(RwLock::new(&db, "locks").await.unwrap()); let col = db.collection::("first"); let id = ObjectId::new(); col.insert_one(Number { _id: id, x: 0 }).await.unwrap(); let reads = (0..READS) .map(|_| { let clock = lock.clone(); let ccol = col.clone(); let cid = id.clone(); task::spawn(async move { let _guard = clock.read_default().await.unwrap(); let a = ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x; assert_eq!( ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x, a ); assert_eq!( ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x, a ); assert_eq!( ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x, a ); }) }) .collect::>(); let writes = (0..WRITE) .map(|_| { let clock = lock.clone(); let ccol = col.clone(); let cid = id.clone(); task::spawn(async move { let _guard = clock.write_default().await.unwrap(); let a = ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x; ccol.update_one(doc! {"_id": cid}, doc! { "$inc": { "x": 1i32 } }) .await .unwrap(); assert_eq!( ccol.find_one(doc! { "_id": cid }).await.unwrap().unwrap().x, a + 1 ); }) }) .collect::>(); for read in reads { read.await.unwrap(); } for write in writes { write.await.unwrap(); } }