use rusty_machine::linalg::{Matrix, BaseMatrix}; use rusty_machine::learning::k_means::KMeansClassifier; use rusty_machine::learning::UnSupModel; use rand::thread_rng; use rand::distributions::IndependentSample; use rand::distributions::normal::Normal; use test::{Bencher, black_box}; fn generate_data(centroids: &Matrix, points_per_centroid: usize, noise: f64) -> Matrix { assert!(centroids.cols() > 0, "Centroids cannot be empty."); assert!(centroids.rows() > 0, "Centroids cannot be empty."); assert!(noise >= 0f64, "Noise must be non-negative."); let mut raw_cluster_data = Vec::with_capacity(centroids.rows() * points_per_centroid * centroids.cols()); let mut rng = thread_rng(); let normal_rv = Normal::new(0f64, noise); for _ in 0..points_per_centroid { // Generate points from each centroid for centroid in centroids.iter_rows() { // Generate a point randomly around the centroid let mut point = Vec::with_capacity(centroids.cols()); for feature in centroid { point.push(feature + normal_rv.ind_sample(&mut rng)); } // Push point to raw_cluster_data raw_cluster_data.extend(point); } } Matrix::new(centroids.rows() * points_per_centroid, centroids.cols(), raw_cluster_data) } #[bench] fn k_means_train(b: &mut Bencher) { const SAMPLES_PER_CENTROID: usize = 2000; // Choose two cluster centers, at (-0.5, -0.5) and (0, 0.5). let centroids = Matrix::new(2, 2, vec![-0.5, -0.5, 0.0, 0.5]); // Generate some data randomly around the centroids let samples = generate_data(¢roids, SAMPLES_PER_CENTROID, 0.4); b.iter(|| { let mut model = black_box(KMeansClassifier::new(2)); let _ = black_box(model.train(&samples).unwrap()); }); } #[bench] fn k_means_predict(b: &mut Bencher) { const SAMPLES_PER_CENTROID: usize = 2000; // Choose two cluster centers, at (-0.5, -0.5) and (0, 0.5). let centroids = Matrix::new(2, 2, vec![-0.5, -0.5, 0.0, 0.5]); // Generate some data randomly around the centroids let samples = generate_data(¢roids, SAMPLES_PER_CENTROID, 0.4); let mut model = KMeansClassifier::new(2); let _ = model.train(&samples).unwrap(); b.iter(|| { let _ = black_box(model.centroids().as_ref().unwrap()); let _ = black_box(model.predict(&samples).unwrap()); }); }