#[cfg(feature = "timing")] use crate::timing::Timing; use crate::{ binary_classifier::{BinaryClassifier, BinaryClassifierTrainOutput}, compute_bin_stats::{BinStats, BinStatsEntry}, compute_binned_features::{ compute_binned_features_column_major, compute_binned_features_row_major, }, compute_binning_instructions::compute_binning_instructions, compute_feature_importances::compute_feature_importances, multiclass_classifier::{MulticlassClassifier, MulticlassClassifierTrainOutput}, pool::Pool, regressor::{Regressor, RegressorTrainOutput}, train_tree::{ train_tree, TrainBranchNode, TrainBranchSplit, TrainBranchSplitContinuous, TrainBranchSplitDiscrete, TrainLeafNode, TrainNode, TrainTree, TrainTreeOptions, }, BinnedFeaturesLayout, BranchNode, BranchSplit, BranchSplitContinuous, BranchSplitDiscrete, LeafNode, Node, Progress, TrainOptions, TrainProgressEvent, Tree, }; use ndarray::prelude::*; use num::ToPrimitive; use rayon::prelude::*; use tangram_progress_counter::ProgressCounter; use tangram_table::prelude::*; /// This enum is used by the common `train` function below to customize the training code slightly for each task. #[derive(Clone, Copy, Debug)] pub enum Task { Regression, BinaryClassification, MulticlassClassification { n_classes: usize }, } /// This is the return type of the common `train` function. #[derive(Debug)] pub enum TrainOutput { Regressor(RegressorTrainOutput), BinaryClassifier(BinaryClassifierTrainOutput), MulticlassClassifier(MulticlassClassifierTrainOutput), } /// To avoid code duplication, this shared `train` function is called by `Regressor::train`, `BinaryClassifier::train`, and `MulticlassClassifier::train`. pub fn train( task: Task, features: TableView, labels: TableColumnView, train_options: &TrainOptions, progress: Progress, ) -> TrainOutput { #[cfg(feature = "timing")] let timing = Timing::new(); // If early stopping is enabled, split the features and labels into train and early stopping sets. let early_stopping_enabled = train_options.early_stopping_options.is_some(); let ( features_train, labels_train, features_early_stopping, labels_early_stopping, mut early_stopping_monitor, ) = if let Some(early_stopping_options) = &train_options.early_stopping_options { let (features_train, labels_train, features_early_stopping, labels_early_stopping) = train_early_stopping_split( features, labels, early_stopping_options.early_stopping_fraction, ); let early_stopping_monitor = EarlyStoppingMonitor::new( early_stopping_options.min_decrease_in_loss_for_significant_change, early_stopping_options.n_rounds_without_improvement_to_stop, ); ( features_train, labels_train, Some(features_early_stopping), Some(labels_early_stopping), Some(early_stopping_monitor), ) } else { (features, labels, None, None, None) }; let n_features = features_train.ncols(); let n_examples_train = features_train.nrows(); // Determine how to bin each feature. #[cfg(feature = "timing")] let start = std::time::Instant::now(); let binning_instructions = compute_binning_instructions(&features_train, train_options); #[cfg(feature = "timing")] timing.compute_binning_instructions.inc(start.elapsed()); // Use the binning instructions from the previous step to compute the binned features. 
    // Use the binning instructions from the previous step to compute the binned features.
    let binned_features_layout = train_options.binned_features_layout;
    let progress_counter = ProgressCounter::new(features_train.nrows().to_u64().unwrap());
    (progress.handle_progress_event)(TrainProgressEvent::Initialize(progress_counter.clone()));
    #[cfg(feature = "timing")]
    let start = std::time::Instant::now();
    let compute_binned_features_column_major_output = compute_binned_features_column_major(
        &features_train,
        &binning_instructions,
        train_options,
        &|| progress_counter.inc(1),
    );
    let features_train = features_train
        .view_columns(&compute_binned_features_column_major_output.used_feature_indexes);
    let features_early_stopping = features_early_stopping
        .as_ref()
        .map(|features_early_stopping| {
            features_early_stopping
                .view_columns(&compute_binned_features_column_major_output.used_feature_indexes)
                .to_rows()
        });
    let used_features_binning_instructions = compute_binned_features_column_major_output
        .used_feature_indexes
        .iter()
        .map(|original_feature_index| binning_instructions[*original_feature_index].clone())
        .collect::<Vec<_>>();
    let binned_features_row_major =
        if let BinnedFeaturesLayout::RowMajor = train_options.binned_features_layout {
            Some(compute_binned_features_row_major(
                &features_train,
                &used_features_binning_instructions,
                &|| progress_counter.inc(1),
            ))
        } else {
            None
        };
    #[cfg(feature = "timing")]
    timing.compute_binned_features.inc(start.elapsed());
    // Regression and binary classification train one tree per round. Multiclass classification trains one tree per class per round.
    let n_trees_per_round = match task {
        Task::Regression => 1,
        Task::BinaryClassification => 1,
        Task::MulticlassClassification { n_classes } => n_classes,
    };
    // The mean squared error loss used in regression has a constant second derivative, so there is no need to use hessians for regression tasks.
    let hessians_are_constant = match task {
        Task::Regression => true,
        Task::BinaryClassification => false,
        Task::MulticlassClassification { .. } => false,
    };
    // Compute the biases. A tree model's prediction is a bias plus the sum of the outputs of each tree. The bias produces the baseline prediction.
    let biases = match task {
        // For regression, the bias is the mean of the labels.
        Task::Regression => {
            let labels_train = labels_train.as_number().unwrap();
            let labels_train = labels_train.as_slice().into();
            crate::regressor::compute_biases(labels_train)
        }
        // For binary classification, the bias is the log of the ratio of positive examples to negative examples in the training set, so the baseline prediction is the majority class.
        Task::BinaryClassification => {
            let labels_train = labels_train.as_enum().unwrap();
            let labels_train = labels_train.as_slice().into();
            crate::binary_classifier::compute_biases(labels_train)
        }
        // For multiclass classification, the biases are the logs of each class's proportion in the training set, so the baseline prediction is the majority class.
        Task::MulticlassClassification { .. } => {
            let labels_train = labels_train.as_enum().unwrap();
            let labels_train = labels_train.as_slice().into();
            crate::multiclass_classifier::compute_biases(labels_train, n_trees_per_round)
        }
    };
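    // As a worked example of the binary case: with 75 positive and 25 negative
    // training examples, the bias is ln(75 / 25) ≈ 1.0986, and applying the
    // sigmoid to that baseline logit recovers the positive class proportion:
    // 1 / (1 + e^-1.0986) = 0.75.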
    // Pre-allocate memory to be used in training.
    let mut predictions =
        unsafe { Array::uninit((n_examples_train, n_trees_per_round).f()).assume_init() };
    let mut gradients = unsafe { Array::uninit(n_examples_train).assume_init() };
    let mut hessians = unsafe { Array::uninit(n_examples_train).assume_init() };
    let mut gradients_ordered_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
    let mut hessians_ordered_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
    let mut examples_index = unsafe { Array::uninit(n_examples_train).assume_init() };
    let mut examples_index_left_buffer = unsafe { Array::uninit(n_examples_train).assume_init() };
    let mut examples_index_right_buffer =
        unsafe { Array::uninit(n_examples_train).assume_init() };
    let mut predictions_early_stopping = if early_stopping_enabled {
        let mut predictions_early_stopping = unsafe {
            Array::uninit((
                n_trees_per_round,
                labels_early_stopping.as_ref().unwrap().len(),
            ))
            .assume_init()
        };
        for mut predictions in predictions_early_stopping.axis_iter_mut(Axis(1)) {
            predictions.assign(&biases);
        }
        Some(predictions_early_stopping)
    } else {
        None
    };
    let binning_instructions_for_pool = used_features_binning_instructions.clone();
    let bin_stats_pool = match binned_features_layout {
        BinnedFeaturesLayout::ColumnMajor => Pool::new(
            train_options.max_leaf_nodes,
            Box::new(move || {
                BinStats::ColumnMajor(
                    binning_instructions_for_pool
                        .iter()
                        .map(|binning_instructions| {
                            vec![BinStatsEntry::default(); binning_instructions.n_bins()]
                        })
                        .collect(),
                )
            }),
        ),
        BinnedFeaturesLayout::RowMajor => Pool::new(
            train_options.max_leaf_nodes,
            Box::new(move || {
                BinStats::RowMajor(
                    binning_instructions_for_pool
                        .iter()
                        .flat_map(|binning_instructions| {
                            vec![BinStatsEntry::default(); binning_instructions.n_bins()]
                        })
                        .collect(),
                )
            }),
        ),
    };
    // This is the total number of rounds that have been trained thus far.
    let mut n_rounds_trained = 0;
    // These are the trees in round-major order. After training, this will be converted to an array of shape (n_rounds, n_trees_per_round).
    let mut trees: Vec<TrainTree> = Vec::new();
    // Collect the loss on the training dataset for each round if enabled.
    let mut losses: Option<Vec<f32>> = if train_options.compute_losses {
        Some(Vec::new())
    } else {
        None
    };
    // Before the first round, fill the predictions with the biases, which are the baseline predictions.
    for mut predictions in predictions.axis_iter_mut(Axis(0)) {
        predictions.assign(&biases)
    }
    (progress.handle_progress_event)(TrainProgressEvent::InitializeDone);
    // Train rounds of trees until we hit max_rounds or the early stopping monitor indicates we should stop early.
    let round_counter = ProgressCounter::new(train_options.max_rounds.to_u64().unwrap());
    (progress.handle_progress_event)(TrainProgressEvent::Train(round_counter.clone()));
    for _ in 0..train_options.max_rounds {
        round_counter.inc(1);
        // Train n_trees_per_round trees.
        let mut trees_for_round = Vec::with_capacity(n_trees_per_round);
        for tree_per_round_index in 0..n_trees_per_round {
            // Before training the next tree, we need to determine what value for each example we would like the tree to learn.
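            // Concretely, each tree is fit to the negative gradient of the loss
            // with respect to the current predictions. For the squared error
            // used in regression, the gradient for an example is
            // `prediction - label` and the hessian is the constant 1, which is
            // why `hessians_are_constant` is true above. For the logistic loss
            // used in binary classification, with `p = sigmoid(logit)`, the
            // gradient is `p - label` and the hessian is `p * (1.0 - p)`.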
#[cfg(feature = "timing")] let start = std::time::Instant::now(); match task { Task::Regression => { let labels_train = labels_train.as_number().unwrap(); crate::regressor::compute_gradients_and_hessians( gradients.as_slice_mut().unwrap(), hessians.as_slice_mut().unwrap(), labels_train.as_slice(), predictions.column(0).as_slice().unwrap(), ); } Task::BinaryClassification => { let labels_train = labels_train.as_enum().unwrap(); crate::binary_classifier::compute_gradients_and_hessians( gradients.as_slice_mut().unwrap(), hessians.as_slice_mut().unwrap(), labels_train.as_slice(), predictions.column(0).as_slice().unwrap(), ); } Task::MulticlassClassification { .. } => { let labels_train = labels_train.as_enum().unwrap(); crate::multiclass_classifier::compute_gradients_and_hessians( tree_per_round_index, gradients.as_slice_mut().unwrap(), hessians.as_slice_mut().unwrap(), labels_train.as_slice(), predictions.view(), ); } }; #[cfg(feature = "timing")] timing.compute_gradients_and_hessians.inc(start.elapsed()); // Reset the examples_index. examples_index .as_slice_mut() .unwrap() .par_iter_mut() .enumerate() .for_each(|(index, value)| { *value = index.to_u32().unwrap(); }); // Train the tree. let tree = train_tree(TrainTreeOptions { binning_instructions: &used_features_binning_instructions, binned_features_row_major: &binned_features_row_major, binned_features_column_major: &compute_binned_features_column_major_output .binned_features, gradients: gradients.as_slice().unwrap(), hessians: hessians.as_slice().unwrap(), gradients_ordered_buffer: gradients_ordered_buffer.as_slice_mut().unwrap(), hessians_ordered_buffer: hessians_ordered_buffer.as_slice_mut().unwrap(), examples_index: examples_index.as_slice_mut().unwrap(), examples_index_left_buffer: examples_index_left_buffer.as_slice_mut().unwrap(), examples_index_right_buffer: examples_index_right_buffer.as_slice_mut().unwrap(), bin_stats_pool: &bin_stats_pool, hessians_are_constant, train_options, #[cfg(feature = "timing")] timing: &timing, }); // Update the predictions using the leaf values from the tree. update_predictions_with_tree( predictions .column_mut(tree_per_round_index) .as_slice_mut() .unwrap(), examples_index.as_slice().unwrap(), &tree, #[cfg(feature = "timing")] &timing, ); trees_for_round.push(tree); } // If loss computation is enabled, compute the loss for this round. if let Some(losses) = losses.as_mut() { let loss = match task { Task::Regression => { let labels_train = labels_train.as_number().unwrap(); let labels_train = labels_train.as_slice().into(); crate::regressor::compute_loss(predictions.view(), labels_train) } Task::BinaryClassification => { let labels_train = labels_train.as_enum().unwrap(); let labels_train = labels_train.as_slice().into(); crate::binary_classifier::compute_loss(predictions.view(), labels_train) } Task::MulticlassClassification { .. } => { let labels_train = labels_train.as_enum().unwrap(); let labels_train = labels_train.as_slice().into(); crate::multiclass_classifier::compute_loss(predictions.view(), labels_train) } }; losses.push(loss); } // If early stopping is enabled, compute the early stopping metric and update the early stopping monitor to see if we should stop training at this round. 
        // If early stopping is enabled, compute the early stopping metric and update the early stopping monitor to see if we should stop training at this round.
        let should_stop = if early_stopping_enabled {
            let features_early_stopping = features_early_stopping.as_ref().unwrap();
            let labels_early_stopping = labels_early_stopping.as_ref().unwrap();
            let predictions_early_stopping = predictions_early_stopping.as_mut().unwrap();
            let early_stopping_monitor = early_stopping_monitor.as_mut().unwrap();
            let value = compute_early_stopping_metric(
                &task,
                trees_for_round.as_slice(),
                features_early_stopping.view(),
                labels_early_stopping.view(),
                predictions_early_stopping.view_mut(),
            );
            early_stopping_monitor.update(value)
        } else {
            false
        };
        // Add the trees for this round to the list of trees.
        trees.extend(trees_for_round);
        n_rounds_trained += 1;
        // Exit the training loop if the early stopping monitor indicates we should stop.
        if should_stop {
            break;
        }
        // Exit the training loop if the kill chip has been activated.
        if progress.kill_chip.is_activated() {
            break;
        }
    }
    (progress.handle_progress_event)(TrainProgressEvent::TrainDone);
    // Compute the feature importances.
    let feature_importances = Some(compute_feature_importances(&trees, n_features));
    // Print out the timing information if the timing feature is enabled.
    #[cfg(feature = "timing")]
    eprintln!("{:?}", timing);
    // Assemble the model.
    let trees: Vec<Tree> = trees
        .into_iter()
        .map(|train_tree| {
            tree_from_train_tree(
                train_tree,
                compute_binned_features_column_major_output
                    .used_feature_indexes
                    .as_slice(),
            )
        })
        .collect();
    match task {
        Task::Regression => TrainOutput::Regressor(RegressorTrainOutput {
            model: Regressor {
                bias: *biases.get(0).unwrap(),
                trees,
            },
            feature_importances,
            losses,
        }),
        Task::BinaryClassification => TrainOutput::BinaryClassifier(BinaryClassifierTrainOutput {
            model: BinaryClassifier {
                bias: *biases.get(0).unwrap(),
                trees,
            },
            feature_importances,
            losses,
        }),
        Task::MulticlassClassification { .. } => {
            let trees =
                Array2::from_shape_vec((n_rounds_trained, n_trees_per_round), trees).unwrap();
            TrainOutput::MulticlassClassifier(MulticlassClassifierTrainOutput {
                model: MulticlassClassifier { biases, trees },
                feature_importances,
                losses,
            })
        }
    }
}
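
// For orientation, here is a sketch of how a task-specific entry point might
// wrap the shared `train` function above (hypothetical shape; the real wrappers
// live in regressor.rs, binary_classifier.rs, and multiclass_classifier.rs):
//
//     let output = train(Task::Regression, features, labels, &train_options, progress);
//     let regressor_train_output = match output {
//         TrainOutput::Regressor(output) => output,
//         _ => unreachable!(),
//     };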

fn update_predictions_with_tree(
    predictions: &mut [f32],
    examples_index: &[u32],
    tree: &TrainTree,
    #[cfg(feature = "timing")] timing: &Timing,
) {
    #[cfg(feature = "timing")]
    let start = std::time::Instant::now();
    // `train_tree` partitions `examples_index` such that each leaf owns a
    // disjoint range of it, so the parallel writes below never alias. The raw
    // pointer wrapper exists only to let rayon share the mutable slice across
    // threads.
    struct PredictionsPtr(*mut [f32]);
    unsafe impl Send for PredictionsPtr {}
    unsafe impl Sync for PredictionsPtr {}
    let predictions_ptr = PredictionsPtr(predictions);
    tree.leaf_values.par_iter().for_each(|(range, value)| {
        examples_index[range.clone()]
            .iter()
            .for_each(|example_index| unsafe {
                let predictions = &mut *predictions_ptr.0;
                let example_index = example_index.to_usize().unwrap();
                *predictions.get_unchecked_mut(example_index) += *value as f32;
            });
    });
    #[cfg(feature = "timing")]
    timing.update_predictions.inc(start.elapsed());
}

#[derive(Clone)]
pub struct EarlyStoppingMonitor {
    tolerance: f32,
    max_rounds_no_improve: usize,
    previous_stopping_metric: Option<f32>,
    num_rounds_no_improve: usize,
}

impl EarlyStoppingMonitor {
    /// Create a new early stopping monitor.
    pub fn new(tolerance: f32, max_rounds_no_improve: usize) -> EarlyStoppingMonitor {
        EarlyStoppingMonitor {
            tolerance,
            max_rounds_no_improve,
            previous_stopping_metric: None,
            num_rounds_no_improve: 0,
        }
    }

    /// Update with the next round's stopping metric value. Returns true if training should stop.
    pub fn update(&mut self, value: f32) -> bool {
        let stopping_metric = value;
        let result = if let Some(previous_stopping_metric) = self.previous_stopping_metric {
            if stopping_metric > previous_stopping_metric
                || f32::abs(stopping_metric - previous_stopping_metric) < self.tolerance
            {
                self.num_rounds_no_improve += 1;
                self.num_rounds_no_improve >= self.max_rounds_no_improve
            } else {
                self.num_rounds_no_improve = 0;
                false
            }
        } else {
            false
        };
        self.previous_stopping_metric = Some(stopping_metric);
        result
    }
}

/// Split the features and labels into train and early stopping datasets, where the early stopping dataset will have `early_stopping_fraction * features.nrows()` rows.
fn train_early_stopping_split<'features, 'labels>(
    features: TableView<'features>,
    labels: TableColumnView<'labels>,
    early_stopping_fraction: f32,
) -> (
    TableView<'features>,
    TableColumnView<'labels>,
    TableView<'features>,
    TableColumnView<'labels>,
) {
    let split_index = (early_stopping_fraction * labels.len().to_f32().unwrap())
        .to_usize()
        .unwrap();
    let (features_early_stopping, features_train) = features.split_at_row(split_index);
    let (labels_early_stopping, labels_train) = labels.split_at_row(split_index);
    (
        features_train,
        labels_train,
        features_early_stopping,
        labels_early_stopping,
    )
}

/// Compute the early stopping metric value for the set of trees that have been trained thus far.
fn compute_early_stopping_metric(
    task: &Task,
    trees_for_round: &[TrainTree],
    features: ArrayView2<TableValue>,
    labels: TableColumnView,
    mut predictions: ArrayViewMut2<f32>,
) -> f32 {
    match task {
        Task::Regression => {
            let labels = labels.as_number().unwrap();
            let labels = labels.as_slice().into();
            crate::regressor::update_logits(
                trees_for_round,
                features.view(),
                predictions.view_mut(),
            );
            crate::regressor::compute_loss(predictions.view(), labels)
        }
        Task::BinaryClassification => {
            let labels = labels.as_enum().unwrap();
            let labels = labels.as_slice().into();
            crate::binary_classifier::update_logits(
                trees_for_round,
                features.view(),
                predictions.view_mut(),
            );
            crate::binary_classifier::compute_loss(predictions.view(), labels)
        }
        Task::MulticlassClassification { .. } => {
            let labels = labels.as_enum().unwrap();
            let labels = labels.as_slice().into();
            crate::multiclass_classifier::update_logits(
                trees_for_round,
                features.view(),
                predictions.view_mut(),
            );
            crate::multiclass_classifier::compute_loss(predictions.view(), labels)
        }
    }
}
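
// When the features are binned, only the used feature columns are retained, so
// a trained tree indexes features in that pruned space. The conversion below
// maps each train-time feature index back to its original column index. For
// example, if `used_feature_indexes` is `[0, 2, 5]`, a split on train feature
// `1` becomes a split on original feature `2`.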

fn tree_from_train_tree(
    train_tree: TrainTree,
    train_feature_index_to_feature_index: &[usize],
) -> Tree {
    let nodes = train_tree
        .nodes
        .into_iter()
        .map(|node| node_from_train_node(node, train_feature_index_to_feature_index))
        .collect();
    Tree { nodes }
}

fn node_from_train_node(
    train_node: TrainNode,
    train_feature_index_to_feature_index: &[usize],
) -> Node {
    match train_node {
        TrainNode::Branch(TrainBranchNode {
            left_child_index,
            right_child_index,
            split,
            examples_fraction,
            ..
        }) => Node::Branch(BranchNode {
            left_child_index: left_child_index.unwrap(),
            right_child_index: right_child_index.unwrap(),
            split: match split {
                TrainBranchSplit::Continuous(TrainBranchSplitContinuous {
                    feature_index,
                    invalid_values_direction,
                    split_value,
                    ..
                }) => BranchSplit::Continuous(BranchSplitContinuous {
                    feature_index: train_feature_index_to_feature_index[feature_index],
                    split_value,
                    invalid_values_direction,
                }),
                TrainBranchSplit::Discrete(TrainBranchSplitDiscrete {
                    feature_index,
                    directions,
                    ..
                }) => BranchSplit::Discrete(BranchSplitDiscrete {
                    feature_index: train_feature_index_to_feature_index[feature_index],
                    directions,
                }),
            },
            examples_fraction,
        }),
        TrainNode::Leaf(TrainLeafNode {
            value,
            examples_fraction,
            ..
        }) => Node::Leaf(LeafNode {
            value,
            examples_fraction,
        }),
    }
}
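
#[cfg(test)]
mod test {
    use super::EarlyStoppingMonitor;

    // A minimal sketch exercising the early stopping logic above: the monitor
    // reports that training should stop after `max_rounds_no_improve`
    // consecutive rounds in which the metric fails to decrease by more than
    // `tolerance`.
    #[test]
    fn test_early_stopping_monitor() {
        let mut monitor = EarlyStoppingMonitor::new(0.01, 2);
        // The first round never stops training because there is no previous metric.
        assert!(!monitor.update(1.0));
        // A significant decrease resets the counter.
        assert!(!monitor.update(0.5));
        // A change smaller than the tolerance counts as a round without improvement.
        assert!(!monitor.update(0.495));
        // A second consecutive round without improvement triggers the stop.
        assert!(monitor.update(0.495));
    }
}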