use crate::{ config, features::{choose_feature_groups_linear, choose_feature_groups_tree}, stats::ColumnStatsOutput, }; use itertools::iproduct; /// A `GridItem` is a description of a single entry in a hyperparameter grid. It specifies what feature engineering to perform on the training data, which model to train, and which hyperparameters to use. #[derive(Clone, Debug)] pub enum GridItem { LinearRegressor { target_column_index: usize, feature_groups: Vec, options: LinearModelTrainOptions, }, TreeRegressor { target_column_index: usize, feature_groups: Vec, options: TreeModelTrainOptions, }, LinearBinaryClassifier { target_column_index: usize, feature_groups: Vec, options: LinearModelTrainOptions, }, TreeBinaryClassifier { target_column_index: usize, feature_groups: Vec, options: TreeModelTrainOptions, }, LinearMulticlassClassifier { target_column_index: usize, feature_groups: Vec, options: LinearModelTrainOptions, }, TreeMulticlassClassifier { target_column_index: usize, feature_groups: Vec, options: TreeModelTrainOptions, }, } #[derive(Clone, Debug)] pub struct LinearModelTrainOptions { pub l2_regularization: Option, pub learning_rate: Option, pub max_epochs: Option, pub n_examples_per_batch: Option, pub early_stopping_options: Option, } impl Default for LinearModelTrainOptions { fn default() -> LinearModelTrainOptions { LinearModelTrainOptions { l2_regularization: None, learning_rate: None, max_epochs: None, n_examples_per_batch: None, early_stopping_options: None, } } } #[derive(Clone, Debug)] pub struct TreeModelTrainOptions { pub binned_features_layout: Option, pub early_stopping_options: Option, pub l2_regularization_for_continuous_splits: Option, pub l2_regularization_for_discrete_splits: Option, pub learning_rate: Option, pub max_depth: Option, pub max_examples_for_computing_bin_thresholds: Option, pub max_leaf_nodes: Option, pub max_rounds: Option, pub max_valid_bins_for_number_features: Option, pub min_examples_per_node: Option, pub min_gain_to_split: Option, pub min_sum_hessians_per_node: Option, pub smoothing_factor_for_discrete_bin_sorting: Option, } impl Default for TreeModelTrainOptions { fn default() -> TreeModelTrainOptions { TreeModelTrainOptions { binned_features_layout: None, early_stopping_options: None, l2_regularization_for_continuous_splits: None, l2_regularization_for_discrete_splits: None, learning_rate: None, max_depth: None, max_examples_for_computing_bin_thresholds: None, max_leaf_nodes: None, max_rounds: None, max_valid_bins_for_number_features: None, min_examples_per_node: None, min_gain_to_split: None, min_sum_hessians_per_node: None, smoothing_factor_for_discrete_bin_sorting: None, } } } #[derive(Clone, Debug)] pub enum BinnedFeaturesLayout { RowMajor, ColumnMajor, } #[derive(Clone, Debug)] pub struct EarlyStoppingOptions { pub early_stopping_fraction: f32, pub early_stopping_rounds: usize, pub early_stopping_threshold: f32, } impl Default for EarlyStoppingOptions { fn default() -> Self { EarlyStoppingOptions { early_stopping_fraction: 0.1, early_stopping_rounds: 5, early_stopping_threshold: 1e-5, } } } pub fn compute_regression_hyperparameter_grid( grid: &[config::GridItem], target_column_index: usize, column_stats: &[ColumnStatsOutput], config: &config::Config, ) -> Vec { grid.iter() .map(|item| match item { config::GridItem::Linear(item) => GridItem::LinearRegressor { target_column_index, feature_groups: choose_feature_groups_linear(column_stats, config), options: LinearModelTrainOptions { l2_regularization: item.l2_regularization, learning_rate: item.learning_rate, max_epochs: item.max_epochs, n_examples_per_batch: item.n_examples_per_batch, early_stopping_options: item.early_stopping_options.as_ref().map( |early_stopping_options| EarlyStoppingOptions { early_stopping_fraction: early_stopping_options.early_stopping_fraction, early_stopping_rounds: early_stopping_options .n_rounds_without_improvement_to_stop, early_stopping_threshold: early_stopping_options .min_decrease_in_loss_for_significant_change, }, ), }, }, config::GridItem::Tree(item) => GridItem::TreeRegressor { target_column_index, feature_groups: choose_feature_groups_tree(column_stats, config), options: TreeModelTrainOptions { binned_features_layout: item.binned_features_layout.as_ref().map( |binned_feature_layout| match binned_feature_layout { config::BinnedFeaturesLayout::RowMajor => { BinnedFeaturesLayout::RowMajor } config::BinnedFeaturesLayout::ColumnMajor => { BinnedFeaturesLayout::ColumnMajor } }, ), early_stopping_options: item.early_stopping_options.as_ref().map( |early_stopping_options| EarlyStoppingOptions { early_stopping_fraction: early_stopping_options.early_stopping_fraction, early_stopping_rounds: early_stopping_options .n_rounds_without_improvement_to_stop, early_stopping_threshold: early_stopping_options .min_decrease_in_loss_for_significant_change, }, ), l2_regularization_for_continuous_splits: item .l2_regularization_for_continuous_splits, l2_regularization_for_discrete_splits: item .l2_regularization_for_discrete_splits, learning_rate: item.learning_rate, max_depth: item.max_depth, max_examples_for_computing_bin_thresholds: item .max_examples_for_computing_bin_thresholds, max_leaf_nodes: item.max_leaf_nodes, max_rounds: item.max_rounds, max_valid_bins_for_number_features: item.max_valid_bins_for_number_features, min_examples_per_node: item.min_examples_per_node, min_gain_to_split: item.min_gain_to_split, min_sum_hessians_per_node: item.min_sum_hessians_per_node, smoothing_factor_for_discrete_bin_sorting: item .smoothing_factor_for_discrete_bin_sorting, }, }, }) .collect() } pub fn compute_binary_classification_hyperparameter_grid( grid: &[config::GridItem], target_column_index: usize, column_stats: &[ColumnStatsOutput], config: &config::Config, ) -> Vec { grid.iter() .map(|item| match item { config::GridItem::Linear(item) => GridItem::LinearBinaryClassifier { target_column_index, feature_groups: choose_feature_groups_linear(column_stats, config), options: LinearModelTrainOptions { l2_regularization: item.l2_regularization, learning_rate: item.learning_rate, max_epochs: item.max_epochs, n_examples_per_batch: item.n_examples_per_batch, early_stopping_options: item.early_stopping_options.as_ref().map( |early_stopping_options| EarlyStoppingOptions { early_stopping_fraction: early_stopping_options.early_stopping_fraction, early_stopping_rounds: early_stopping_options .n_rounds_without_improvement_to_stop, early_stopping_threshold: early_stopping_options .min_decrease_in_loss_for_significant_change, }, ), }, }, config::GridItem::Tree(item) => GridItem::TreeBinaryClassifier { target_column_index, feature_groups: choose_feature_groups_tree(column_stats, config), options: TreeModelTrainOptions { binned_features_layout: item.binned_features_layout.as_ref().map( |binned_feature_layout| match binned_feature_layout { config::BinnedFeaturesLayout::RowMajor => { BinnedFeaturesLayout::RowMajor } config::BinnedFeaturesLayout::ColumnMajor => { BinnedFeaturesLayout::ColumnMajor } }, ), early_stopping_options: item.early_stopping_options.as_ref().map( |early_stopping_options| EarlyStoppingOptions { early_stopping_fraction: early_stopping_options.early_stopping_fraction, early_stopping_rounds: early_stopping_options .n_rounds_without_improvement_to_stop, early_stopping_threshold: early_stopping_options .min_decrease_in_loss_for_significant_change, }, ), l2_regularization_for_continuous_splits: item .l2_regularization_for_continuous_splits, l2_regularization_for_discrete_splits: item .l2_regularization_for_discrete_splits, learning_rate: item.learning_rate, max_depth: item.max_depth, max_examples_for_computing_bin_thresholds: item .max_examples_for_computing_bin_thresholds, max_leaf_nodes: item.max_leaf_nodes, max_rounds: item.max_rounds, max_valid_bins_for_number_features: item.max_valid_bins_for_number_features, min_examples_per_node: item.min_examples_per_node, min_gain_to_split: item.min_gain_to_split, min_sum_hessians_per_node: item.min_sum_hessians_per_node, smoothing_factor_for_discrete_bin_sorting: item .smoothing_factor_for_discrete_bin_sorting, }, }, }) .collect() } pub fn compute_multiclass_classification_hyperparameter_grid( grid: &[config::GridItem], target_column_index: usize, column_stats: &[ColumnStatsOutput], config: &config::Config, ) -> Vec { grid.iter() .map(|item| match item { config::GridItem::Linear(item) => GridItem::LinearMulticlassClassifier { target_column_index, feature_groups: choose_feature_groups_linear(column_stats, config), options: LinearModelTrainOptions { l2_regularization: item.l2_regularization, learning_rate: item.learning_rate, max_epochs: item.max_epochs, n_examples_per_batch: item.n_examples_per_batch, early_stopping_options: item.early_stopping_options.as_ref().map( |early_stopping_options| EarlyStoppingOptions { early_stopping_fraction: early_stopping_options.early_stopping_fraction, early_stopping_rounds: early_stopping_options .n_rounds_without_improvement_to_stop, early_stopping_threshold: early_stopping_options .min_decrease_in_loss_for_significant_change, }, ), }, }, config::GridItem::Tree(item) => GridItem::TreeMulticlassClassifier { target_column_index, feature_groups: choose_feature_groups_tree(column_stats, config), options: TreeModelTrainOptions { binned_features_layout: item.binned_features_layout.as_ref().map( |binned_feature_layout| match binned_feature_layout { config::BinnedFeaturesLayout::RowMajor => { BinnedFeaturesLayout::RowMajor } config::BinnedFeaturesLayout::ColumnMajor => { BinnedFeaturesLayout::ColumnMajor } }, ), early_stopping_options: item.early_stopping_options.as_ref().map( |early_stopping_options| EarlyStoppingOptions { early_stopping_fraction: early_stopping_options.early_stopping_fraction, early_stopping_rounds: early_stopping_options .n_rounds_without_improvement_to_stop, early_stopping_threshold: early_stopping_options .min_decrease_in_loss_for_significant_change, }, ), l2_regularization_for_continuous_splits: item .l2_regularization_for_continuous_splits, l2_regularization_for_discrete_splits: item .l2_regularization_for_discrete_splits, learning_rate: item.learning_rate, max_depth: item.max_depth, max_examples_for_computing_bin_thresholds: item .max_examples_for_computing_bin_thresholds, max_leaf_nodes: item.max_leaf_nodes, max_rounds: item.max_rounds, max_valid_bins_for_number_features: item.max_valid_bins_for_number_features, min_examples_per_node: item.min_examples_per_node, min_gain_to_split: item.min_gain_to_split, min_sum_hessians_per_node: item.min_sum_hessians_per_node, smoothing_factor_for_discrete_bin_sorting: item .smoothing_factor_for_discrete_bin_sorting, }, }, }) .collect() } const DEFAULT_LINEAR_MODEL_LEARNING_RATE_VALUES: [f32; 2] = [0.1, 0.01]; const DEFAULT_LINEAR_L2_REGULARIZATION_VALUES: [f32; 2] = [1.0, 0.1]; const DEFAULT_LINEAR_MAX_EPOCHS_VALUES: [u64; 1] = [1000]; const DEFAULT_LINEAR_N_EXAMPLES_PER_BATCH_VALUES: [u64; 1] = [128]; const DEFAULT_TREE_LEARNING_RATE_VALUES: [f32; 2] = [0.1, 0.01]; const DEFAULT_TREE_L2_REGULARIZATION_VALUES_FOR_CONTINUOUS_SPLITS: [f32; 2] = [1.0, 0.1]; const DEFAULT_TREE_MAX_LEAF_NODES: [u64; 1] = [512]; const DEFAULT_TREE_MAX_ROUNDS_VALUES: [u64; 1] = [1000]; const DEFAULT_TREE_MAX_DEPTH: [u64; 1] = [50]; /// Compute the default hyperparameter grid for regression. pub fn default_regression_hyperparameter_grid( target_column_index: usize, column_stats: &[ColumnStatsOutput], config: &config::Config, ) -> Vec { let mut grid = Vec::new(); for (&l2_regularization, &learning_rate, &max_epochs, &n_examples_per_batch) in iproduct!( DEFAULT_LINEAR_L2_REGULARIZATION_VALUES.iter(), DEFAULT_LINEAR_MODEL_LEARNING_RATE_VALUES.iter(), DEFAULT_LINEAR_MAX_EPOCHS_VALUES.iter(), DEFAULT_LINEAR_N_EXAMPLES_PER_BATCH_VALUES.iter() ) { grid.push(GridItem::LinearRegressor { target_column_index, feature_groups: choose_feature_groups_linear(column_stats, config), options: LinearModelTrainOptions { l2_regularization: Some(l2_regularization), learning_rate: Some(learning_rate), max_epochs: Some(max_epochs), n_examples_per_batch: Some(n_examples_per_batch), early_stopping_options: Some(Default::default()), }, }); } for ( &max_leaf_nodes, &learning_rate, &l2_regularization_for_continuous_splits, &max_rounds, &max_depth, ) in iproduct!( DEFAULT_TREE_MAX_LEAF_NODES.iter(), DEFAULT_TREE_LEARNING_RATE_VALUES.iter(), DEFAULT_TREE_L2_REGULARIZATION_VALUES_FOR_CONTINUOUS_SPLITS.iter(), DEFAULT_TREE_MAX_ROUNDS_VALUES.iter(), DEFAULT_TREE_MAX_DEPTH.iter() ) { grid.push(GridItem::TreeRegressor { target_column_index, feature_groups: choose_feature_groups_tree(column_stats, config), options: TreeModelTrainOptions { max_leaf_nodes: Some(max_leaf_nodes), learning_rate: Some(learning_rate), max_rounds: Some(max_rounds), max_depth: Some(max_depth), l2_regularization_for_continuous_splits: Some( l2_regularization_for_continuous_splits, ), early_stopping_options: Some(Default::default()), ..Default::default() }, }); } grid } /// Compute the default hyperparameter grid for binary classification. pub fn default_binary_classification_hyperparameter_grid( target_column_index: usize, column_stats: &[ColumnStatsOutput], config: &config::Config, ) -> Vec { let mut grid = Vec::new(); for (&l2_regularization, &learning_rate, &max_epochs, &n_examples_per_batch) in iproduct!( DEFAULT_LINEAR_L2_REGULARIZATION_VALUES.iter(), DEFAULT_LINEAR_MODEL_LEARNING_RATE_VALUES.iter(), DEFAULT_LINEAR_MAX_EPOCHS_VALUES.iter(), DEFAULT_LINEAR_N_EXAMPLES_PER_BATCH_VALUES.iter() ) { grid.push(GridItem::LinearBinaryClassifier { target_column_index, feature_groups: choose_feature_groups_linear(column_stats, config), options: LinearModelTrainOptions { l2_regularization: Some(l2_regularization), learning_rate: Some(learning_rate), max_epochs: Some(max_epochs), n_examples_per_batch: Some(n_examples_per_batch), early_stopping_options: Some(Default::default()), }, }); } for ( &max_leaf_nodes, &learning_rate, &l2_regularization_for_continous_splits, &max_rounds, &max_depth, ) in iproduct!( DEFAULT_TREE_MAX_LEAF_NODES.iter(), DEFAULT_TREE_LEARNING_RATE_VALUES.iter(), DEFAULT_TREE_L2_REGULARIZATION_VALUES_FOR_CONTINUOUS_SPLITS.iter(), DEFAULT_TREE_MAX_ROUNDS_VALUES.iter(), DEFAULT_TREE_MAX_DEPTH.iter() ) { grid.push(GridItem::TreeBinaryClassifier { target_column_index, feature_groups: choose_feature_groups_tree(column_stats, config), options: TreeModelTrainOptions { max_leaf_nodes: Some(max_leaf_nodes), learning_rate: Some(learning_rate), max_rounds: Some(max_rounds), max_depth: Some(max_depth), l2_regularization_for_continuous_splits: Some( l2_regularization_for_continous_splits, ), early_stopping_options: Some(Default::default()), ..Default::default() }, }); } grid } /// Compute the default hyperparameter grid for multiclass classification. pub fn default_multiclass_classification_hyperparameter_grid( target_column_index: usize, column_stats: &[ColumnStatsOutput], config: &config::Config, ) -> Vec { let mut grid = Vec::new(); for (&l2_regularization, &learning_rate, &max_epochs, &n_examples_per_batch) in iproduct!( DEFAULT_LINEAR_L2_REGULARIZATION_VALUES.iter(), DEFAULT_LINEAR_MODEL_LEARNING_RATE_VALUES.iter(), DEFAULT_LINEAR_MAX_EPOCHS_VALUES.iter(), DEFAULT_LINEAR_N_EXAMPLES_PER_BATCH_VALUES.iter() ) { grid.push(GridItem::LinearMulticlassClassifier { target_column_index, feature_groups: choose_feature_groups_linear(column_stats, config), options: LinearModelTrainOptions { l2_regularization: Some(l2_regularization), learning_rate: Some(learning_rate), max_epochs: Some(max_epochs), n_examples_per_batch: Some(n_examples_per_batch), early_stopping_options: Some(Default::default()), }, }); } for ( &max_leaf_nodes, &learning_rate, &l2_regularization_for_continuous_splits, &max_rounds, &max_depth, ) in iproduct!( DEFAULT_TREE_MAX_LEAF_NODES.iter(), DEFAULT_TREE_LEARNING_RATE_VALUES.iter(), DEFAULT_TREE_L2_REGULARIZATION_VALUES_FOR_CONTINUOUS_SPLITS.iter(), DEFAULT_TREE_MAX_ROUNDS_VALUES.iter(), DEFAULT_TREE_MAX_DEPTH.iter() ) { grid.push(GridItem::TreeMulticlassClassifier { target_column_index, feature_groups: choose_feature_groups_tree(column_stats, config), options: TreeModelTrainOptions { max_leaf_nodes: Some(max_leaf_nodes), learning_rate: Some(learning_rate), max_rounds: Some(max_rounds), max_depth: Some(max_depth), l2_regularization_for_continuous_splits: Some( l2_regularization_for_continuous_splits, ), early_stopping_options: Some(Default::default()), ..Default::default() }, }); } grid }