use std::{marker::PhantomData, num::NonZeroU32, sync::Arc}; use futures::Future; use crate::{Error, Message}; #[derive(Clone, Copy, Debug)] pub struct Config { pub queue_size: usize, pub queue_per_task: bool, pub ordered: bool, pub ordering_buffer_size: Option, pub task_count: u32, pub lazy_task_creation: bool, pub stream_per_message: bool, } impl Default for Config { fn default() -> Self { Self { queue_size: 4, queue_per_task: false, ordered: false, ordering_buffer_size: None, task_count: 1, lazy_task_creation: true, stream_per_message: false, } } } pub trait Builder: Send + Sync + 'static { type Context: 'static; fn config(&self, _stream_id: u32) -> Config { Default::default() } fn build( &self, stream_id: u32, _task_id: u32, ) -> impl Future> + Send + '_; } pub struct DefaultBuilder { config: Config, callback: C, _m: PhantomData<(M, H, F)>, } unsafe impl Sync for DefaultBuilder { } impl DefaultBuilder where M: Message, H: Sync + Send + 'static, F: Send + Future> + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static, { pub fn new(queue_size: usize, callback: C) -> Self { Self { config: Config { queue_size, ..Default::default() }, callback, _m: PhantomData, } } pub fn ordered(self, buf: Option) -> Self { let mut config = self.config; config.ordered = true; if let Some(buf) = buf { config.ordering_buffer_size = Some(NonZeroU32::new(buf).expect("Buffer length cannot be zero!")); } Self { config, callback: self.callback, _m: PhantomData, } } pub fn stream_per_message(self) -> Self { let mut config = self.config; config.stream_per_message = true; Self { config, callback: self.callback, _m: PhantomData, } } pub fn tasks(self, tasks: u32) -> Self { let mut config = self.config; config.task_count = tasks; Self { config, callback: self.callback, _m: PhantomData, } } } impl Builder for DefaultBuilder where M: Message, H: Sync + Send + 'static, F: Send + Future> + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static, { type Context = H; async fn build(&self, stream_id: u32, task_id: u32) -> Result { (self.callback)(stream_id, task_id).await } fn config(&self, _stream_id: u32) -> Config { self.config } } pub struct SharedBuilder { config: Config, stream_handlers: dashmap::DashMap>, callback: C, _m: PhantomData<(M, F)>, } unsafe impl Sync for SharedBuilder { } impl SharedBuilder where M: Message, H: Sync + Send + 'static, F: Send + Future> + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static, { pub fn new(queue_size: usize, task_count: u32, callback: C) -> Self { Self { config: Config { queue_size, task_count, ..Default::default() }, stream_handlers: Default::default(), callback, _m: PhantomData, } } pub fn stream_per_message(self) -> Self { let mut config = self.config; config.stream_per_message = true; Self { config, callback: self.callback, _m: PhantomData, stream_handlers: Default::default(), } } pub fn ordered(self, buf: Option) -> Self { let mut config = self.config; config.ordered = true; if let Some(buf) = buf { config.ordering_buffer_size = Some(NonZeroU32::new(buf).expect("Buffer length cannot be zero!")); } Self { config, stream_handlers: self.stream_handlers, callback: self.callback, _m: PhantomData, } } pub fn queue_per_task(self) -> Self { let mut config = self.config; config.queue_per_task = true; Self { config, stream_handlers: self.stream_handlers, callback: self.callback, _m: PhantomData, } } } impl Builder for SharedBuilder where M: Message, H: Sync + Send + 'static, F: Send + Future> + 'static, C: Sync + Send + Fn(u32, u32) -> F + 'static, { type Context = Arc; async fn build(&self, stream_id: u32, task_id: u32) -> Result { if self.stream_handlers.contains_key(&stream_id) { return Ok(self.stream_handlers.get(&stream_id).unwrap().clone()); } let val = match (self.callback)(stream_id, task_id).await { Ok(val) => Arc::new(val), Err(err) => return Err(err), }; self.stream_handlers.insert(stream_id, val.clone()); Ok(val) } fn config(&self, _stream_id: u32) -> Config { self.config } }