use super::{aggregate::NumericMetricsAggregate, Aggregate, Direction, Event, EventStore, Split}; use crate::logger::MetricLogger; #[derive(Default)] pub(crate) struct LogEventStore { loggers_train: Vec>, loggers_valid: Vec>, aggregate_train: NumericMetricsAggregate, aggregate_valid: NumericMetricsAggregate, } impl EventStore for LogEventStore { fn add_event(&mut self, event: Event, split: Split) { match event { Event::MetricsUpdate(update) => match split { Split::Train => { update .entries .iter() .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) .for_each(|entry| { self.loggers_train .iter_mut() .for_each(|logger| logger.log(entry)); }); } Split::Valid => { update .entries .iter() .chain(update.entries_numeric.iter().map(|(entry, _value)| entry)) .for_each(|entry| { self.loggers_valid .iter_mut() .for_each(|logger| logger.log(entry)); }); } }, Event::EndEpoch(epoch) => match split { Split::Train => self .loggers_train .iter_mut() .for_each(|logger| logger.end_epoch(epoch)), Split::Valid => self .loggers_valid .iter_mut() .for_each(|logger| logger.end_epoch(epoch)), }, } } fn find_epoch( &mut self, name: &str, aggregate: Aggregate, direction: Direction, split: Split, ) -> Option { match split { Split::Train => { self.aggregate_train .find_epoch(name, aggregate, direction, &mut self.loggers_train) } Split::Valid => { self.aggregate_valid .find_epoch(name, aggregate, direction, &mut self.loggers_valid) } } } fn find_metric( &mut self, name: &str, epoch: usize, aggregate: Aggregate, split: Split, ) -> Option { match split { Split::Train => { self.aggregate_train .aggregate(name, epoch, aggregate, &mut self.loggers_train) } Split::Valid => { self.aggregate_valid .aggregate(name, epoch, aggregate, &mut self.loggers_valid) } } } } impl LogEventStore { /// Register a logger for training metrics. pub(crate) fn register_logger_train(&mut self, logger: ML) { self.loggers_train.push(Box::new(logger)); } /// Register a logger for validation metrics. pub(crate) fn register_logger_valid(&mut self, logger: ML) { self.loggers_valid.push(Box::new(logger)); } }