#![allow(clippy::bool_comparison)] /// This module contains helpers for managing lifetime of async functions. /// In particular, it wraps the following two things from tokio: /// * cancellation token to notify "async functions/actors" that they have to stop and /// * tokio::sync::mpsc::Sender object (when it is dropped, parent knows that all its "children" has stopped) /// We use a channel to wait until all children has died (likely, because cancellation_token had told them) /// See more at https://tokio.rs/tokio/topics/shutdown use anyhow::Result; #[derive(Default)] pub struct SharedState { children_names: std::collections::HashMap, } type SS = std::sync::Arc>; /// Usage: /// ```ignore /// async fn my_task(..., pill: Pill) { /// loop { /// tokio::select! { /// _ = pill.received() => { /// log::info!("Exiting..."); /// return; /// } /// ... /// } /// ``` pub struct Pill { pub cancellation_token: tokio_util::sync::CancellationToken, pub child_died_signal: tokio::sync::mpsc::Sender<()>, pub name: String, pub ss: SS, } impl Drop for Pill { fn drop(&mut self) { if self.name.is_empty() == false { let (alive, total) = { let mut ss = self.ss.lock().unwrap(); if let std::collections::hash_map::Entry::Occupied(mut entry) = ss.children_names.entry(self.name.clone()) { *entry.get_mut() -= 1; if *entry.get() == 0 { entry.remove(); } } let total: u64 = ss.children_names.values().sum(); let alive: Vec = ss.children_names.keys().take(3).cloned().collect(); (alive, total) }; if total as usize > alive.len() { log::info!("Child \"{}\" finished. Waiting for the remaining children: {alive:?} and {} other", self.name, total as usize - alive.len()); } else { log::info!("Child \"{}\" finished. Waiting for the remaining children: {alive:?}", self.name); } } } } impl Pill { pub fn received(&self) -> impl std::future::Future + '_ { self.cancellation_token.cancelled() } } /// This struct is created and used by parent. /// Example: /// ```ignore /// use digester::poison_pill::ChildrenStopper; /// let ct = tokio_util::sync::CancellationToken::new(); /// let mut ck = ChildrenStopper::from_existing_cancellation_token(ct); /// tokio::spawn(write_tetra_events(client, /// tetra_slot, /// ck.register_child("write_tetra_events"))); /// ck.stop_and_wait().await; /// ``` pub struct ChildrenStopper { cancellation_token: tokio_util::sync::CancellationToken, child_died_slot: tokio::sync::mpsc::Receiver<()>, child_died_signal: tokio::sync::mpsc::Sender<()>, ss: SS, } impl Default for ChildrenStopper { fn default() -> Self { Self::from_existing_cancellation_token(tokio_util::sync::CancellationToken::new()) } } impl ChildrenStopper { pub fn from_existing_cancellation_token(ct: tokio_util::sync::CancellationToken) -> Self { let (signal, slot) = tokio::sync::mpsc::channel(1); ChildrenStopper { cancellation_token: ct, child_died_slot: slot, child_died_signal: signal, ss: Default::default(), } } pub fn register_child(&mut self, name: &str) -> Pill { *self.ss.lock().unwrap().children_names.entry(name.into()).or_insert(0) += 1; Pill { cancellation_token: self.cancellation_token.clone(), child_died_signal: self.child_died_signal.clone(), name: name.into(), ss: self.ss.clone(), } } pub async fn wait(mut self) -> Result<()> { drop(self.child_died_signal); // drop our sender first because the recv() call otherwise sleeps forever. // let total: u64 = self.children_names.values().sum(); self.child_died_slot.recv().await; log::info!("All children finished"); Ok(()) } }