use plotters::prelude::*; use tinguely::{UnsupervisedLearn}; use tinguely::clustering::{KMeans, KMeansInitializer}; use mathru::algebra::linear::{Vector, Matrix}; use mathru::matrix; use mathru::statistics::distrib::{Normal, Distribution}; fn generate_data(centroids: &Matrix, points_per_centroid: usize, sigma: f64) -> Matrix { let (rows, cols) = centroids.dim(); assert!(cols > 0, "Centroids cannot be empty."); assert!(rows > 0, "Centroids cannot be empty."); assert!(sigma >= 0.0f64, "Noise must be non-negative."); let mut raw_cluster_data = Vec::with_capacity(rows * points_per_centroid * cols); let normal: Normal = Normal::new(0.0f64, sigma); // Generate points for each centroid for column in centroids.column_into_iter() { for _ in 0..points_per_centroid { // Generate a point randomly around the centroid let mut point: Vec = Vec::with_capacity(rows); for feature in column.iter() { point.push(feature + normal.random()); } // Push point to raw_cluster_data raw_cluster_data.extend(point); } } Matrix::new(centroids.nrows() * points_per_centroid, centroids.ncols(), raw_cluster_data) } fn main() { let centroids: Matrix = matrix![1.0, 1.0; 2.5, 3.5; 3.5, 1.5]; let data: Matrix = generate_data(¢roids, 100, 0.1); let mut kmeans: KMeans = KMeans::new(3, 200, KMeansInitializer::Random); kmeans.train(&data); let pred: Vector = kmeans.predict(&data); let root_area = BitMapBackend::new("./figures/kmeans.png", (600, 400)).into_drawing_area(); root_area.fill(&WHITE).unwrap(); let mut ctx = ChartBuilder::on(&root_area) .margin(20) .set_label_area_size(LabelAreaPosition::Left, 40) .set_label_area_size(LabelAreaPosition::Bottom, 40) .build_cartesian_2d(0.0..4.0, 0.0..4.0) .unwrap(); ctx.configure_mesh() .x_desc("x") .y_desc("y") .axis_desc_style(("sans-serif", 15).into_font()) .draw() .unwrap(); ctx.draw_series( centroids .row_into_iter() .map(|coord| Circle::new((*coord.get(0), *coord.get(1)), 3, BLACK.filled())), ).unwrap(); ctx.draw_series( kmeans.centroids() .row_into_iter() .map(|coord| Circle::new((*coord.get(0), *coord.get(1)), 3, GREEN.filled())), ).unwrap(); ctx.draw_series( data.row_into_iter().zip(pred.iter()) .map(|(coord, class)| { let color: RGBColor = match *class { 0.0 => MAGENTA, 1.0 => BLUE, _ => RED }; Cross::new((*coord.get(0), *coord.get(1)), 2, color.filled()) } ) ).unwrap(); }