// Copyright 2021 Jonathan Manly. // This file is part of . // rml is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // rml is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // You should have received a copy of the GNU Lesser General Public License // along with rml. If not, see . use rml::knn; use rml::math; use std::error::Error; use std::time::Instant; const TRAIN_FILE_NAME: &str = "./data/optdigits.tra"; const TEST_FILE_NAME: &str = "./data/optdigits.tes"; type CSVOutput = (Vec>, Vec); fn parse_csv(data: &str) -> Result> { let mut out_data: CSVOutput = (Vec::new(), Vec::new()); let mut reader = csv::ReaderBuilder::new() .has_headers(false) .from_path(data)?; for line in reader.records() { let result = line?; let mut line_data: (Vec, i32) = (Vec::new(), 0); line_data.1 = (result.get(result.len() - 1).unwrap()).parse()?; for i in 0..result.len() - 1 { line_data.0.push((result.get(i).unwrap()).parse()?); } out_data.0.push(line_data.0); out_data.1.push(line_data.1); } Ok(out_data) } fn main() -> Result<(), Box> { // Format: (Vectors of each feature, Vector of class label) let training_data = parse_csv(TRAIN_FILE_NAME)?; let testing_data = parse_csv(TEST_FILE_NAME)?; let start = Instant::now(); let knn = knn::KNN::new( 5, training_data.0, training_data.1, None, Some(math::norm::Norm::L2), ); let pred: Vec = testing_data.0.iter().map(|x| knn.predict(x)).collect(); let num_correct = pred .iter() .cloned() .zip(&testing_data.1) .filter(|(a, b)| *a == **b) .count(); println!( "Accuracy: {} Runtime: {}s", (num_correct as f64) / (pred.len() as f64), start.elapsed().as_secs_f64() ); Ok(()) }