use ndarray::prelude::*; use serde_json::json; use std::path::Path; use tangram_linear::Progress; use tangram_table::prelude::*; use tangram_zip::zip; fn main() { // Load the data. let csv_file_path_train = Path::new("data/heart_disease_train.csv"); let csv_file_path_test = Path::new("data/heart_disease_test.csv"); let target_column_index = 13; let mut features_train = Table::from_path(csv_file_path_train, Default::default(), &mut |_| {}).unwrap(); let labels_train = features_train.columns_mut().remove(target_column_index); let labels_train = labels_train.as_enum().unwrap(); let mut features_test = Table::from_path(csv_file_path_test, Default::default(), &mut |_| {}).unwrap(); let labels_test = features_test.columns_mut().remove(target_column_index); let labels_test = labels_test.as_enum().unwrap(); let feature_groups: Vec = features_train .columns() .iter() .map(|column| { tangram_features::FeatureGroup::Normalized( tangram_features::NormalizedFeatureGroup::compute_for_column(column.view()), ) }) .collect(); let features_train = tangram_features::compute_features_array_f32( &features_train.view(), feature_groups.as_slice(), &|| {}, ); let features_test = tangram_features::compute_features_array_f32( &features_test.view(), feature_groups.as_slice(), &|| {}, ); // Train the model. let train_output = tangram_linear::BinaryClassifier::train( features_train.view(), labels_train.view(), &tangram_linear::TrainOptions { learning_rate: 0.01, max_epochs: 1, n_examples_per_batch: 1, ..Default::default() }, Progress { kill_chip: &tangram_kill_chip::KillChip::default(), handle_progress_event: &mut |_| {}, }, ); // Make predictions on the test data. let mut probabilities = Array::zeros(labels_test.len()); train_output .model .predict(features_test.view(), probabilities.view_mut()); // Compute metrics. let input = zip!(probabilities.iter(), labels_test.iter()) .map(|(probability, label)| (*probability, label.unwrap())) .collect(); let auc_roc = tangram_metrics::AucRoc::compute(input); let output = json!({ "auc_roc": auc_roc, }); println!("{}", output); }