use cetana::{ backend::DeviceManager, nn::{ activation::{ReLU, Softmax}, Layer, Linear, }, tensor::Tensor, MlResult, }; use flate2::read::GzDecoder; use std::io::Write; use std::{fs::File, io::Read, path::Path, time::Instant}; type Float = f32; struct MnistConfig { learning_rate: Float, epochs: usize, display_interval: usize, early_stopping_patience: usize, early_stopping_min_delta: Float, } impl Default for MnistConfig { fn default() -> Self { Self { learning_rate: 0.0005, epochs: 500, display_interval: 1, early_stopping_patience: 50, early_stopping_min_delta: 1e-6, } } } struct MnistClassifier { fc1: Linear, relu1: ReLU, fc2: Linear, relu2: ReLU, fc3: Linear, softmax: Softmax, } impl MnistClassifier { fn new() -> MlResult { Ok(Self { // 784 = 28*28 input image size fc1: Linear::new(784, 128, true)?, relu1: ReLU::new(), fc2: Linear::new(128, 64, true)?, relu2: ReLU::new(), fc3: Linear::new(64, 10, true)?, // 10 output classes softmax: Softmax::new(), }) } fn forward(&mut self, x: &Tensor) -> MlResult { // Reshape input from [batch_size, 1, 28, 28] to [batch_size, 784] let batch_size = x.shape()[0]; let flattened = x.reshape(&[batch_size, 784])?; let h1 = self.fc1.forward(&flattened)?; let h1 = self.relu1.forward(&h1)?; let h2 = self.fc2.forward(&h1)?; let h2 = self.relu2.forward(&h2)?; let out = self.fc3.forward(&h2)?; self.softmax.forward(&out) } fn train_step(&mut self, x: &Tensor, y: &Tensor, learning_rate: f32) -> MlResult { // Forward pass let batch_size = x.shape()[0]; let flattened = x.reshape(&[batch_size, 784])?; let h1 = self.fc1.forward(&flattened)?; let h1_activated = self.relu1.forward(&h1)?; let h2 = self.fc2.forward(&h1_activated)?; let h2_activated = self.relu2.forward(&h2)?; let out = self.fc3.forward(&h2_activated)?; let predictions = self.softmax.forward(&out)?; // Compute cross-entropy loss let epsilon = 1e-15; let clipped_preds = predictions.clip(epsilon, 1.0 - epsilon)?; let log_preds = clipped_preds.log()?; let batch_loss = log_preds.mul(y)?.sum(1)?.mul_scalar(-1.0)?; // Backward pass let grad_output = predictions.sub(y)?; let h2_grad = self .fc3 .backward(&h2_activated, &grad_output, learning_rate)?; let h2_grad = self.relu2.backward(&h2, &h2_grad, learning_rate)?; let h1_grad = self.fc2.backward(&h1_activated, &h2_grad, learning_rate)?; let h1_grad = self.relu1.backward(&h1, &h1_grad, learning_rate)?; self.fc1.backward(&flattened, &h1_grad, learning_rate)?; batch_loss.mean() } fn evaluate(&mut self, x: &Tensor, y: &Tensor) -> MlResult<(f32, f32)> { let predictions = self.forward(x)?; // Calculate accuracy by comparing max indices let mut correct = 0; let batch_size = predictions.shape()[0]; for i in 0..batch_size { // Get predicted class (max value index) let mut max_pred_idx = 0; let mut max_pred_val = f32::NEG_INFINITY; for j in 0..10 { let val = predictions.data()[i * 10 + j]; if val > max_pred_val { max_pred_val = val; max_pred_idx = j; } } // Get true class (max value index) let mut max_true_idx = 0; let mut max_true_val = f32::NEG_INFINITY; for j in 0..10 { let val = y.data()[i * 10 + j]; if val > max_true_val { max_true_val = val; max_true_idx = j; } } if max_pred_idx == max_true_idx { correct += 1; } } let accuracy = (correct as f32) * 100.0 / (batch_size as f32); // Calculate loss let epsilon = 1e-15; let clipped_preds = predictions.clip(epsilon, 1.0 - epsilon)?; let log_preds = clipped_preds.log()?; let batch_loss = log_preds.mul(y)?.sum(1)?.mul_scalar(-1.0)?; Ok((accuracy, batch_loss.mean()?)) } } fn load_mnist_data(path: &str) -> MlResult<(Vec, usize)> { let mut file = File::open(Path::new(path)).map_err(|e| format!("Failed to open MNIST file: {}", e))?; let mut buffer = Vec::new(); file.read_to_end(&mut buffer) .map_err(|e| format!("Failed to read MNIST file: {}", e))?; // MNIST magic number check and size extraction let magic_number = u32::from_be_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); let num_items = u32::from_be_bytes([buffer[4], buffer[5], buffer[6], buffer[7]]) as usize; // Skip header (16 bytes for images, 8 bytes for labels) let data_start = if magic_number == 2051 { 16 } else { 8 }; let data = buffer[data_start..].to_vec(); Ok((data, num_items)) } async fn download_mnist_data() -> MlResult<()> { let base_url = "https://raw.githubusercontent.com/fgnt/mnist/master/"; let files = [ "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", ]; // Create data directory if it doesn't exist std::fs::create_dir_all("data/mnist") .map_err(|e| format!("Failed to create data directory: {}", e))?; for file in files.iter() { let url = format!("{}{}", base_url, file); let output_path = format!("data/mnist/{}", file.replace(".gz", "")); // Skip if file already exists if Path::new(&output_path).exists() { println!("File {} already exists, skipping download", output_path); continue; } println!("Downloading {} from mirror...", file); // Download file with retry logic let client = reqwest::Client::new(); let response = client .get(&url) .header("User-Agent", "Mozilla/5.0") // Add user agent to avoid potential blocks .send() .await .map_err(|e| format!("Failed to download {}: {}", file, e))?; if !response.status().is_success() { return Err(format!( "Failed to download {}: HTTP status {}", file, response.status() ) .into()); } let content = response .bytes() .await .map_err(|e| format!("Failed to read response for {}: {}", file, e))?; // Decompress gz file let mut decoder = GzDecoder::new(&content[..]); let mut decompressed = Vec::new(); decoder .read_to_end(&mut decompressed) .map_err(|e| format!("Failed to decompress {}: {}", file, e))?; // Write to file let mut output_file = File::create(&output_path) .map_err(|e| format!("Failed to create output file {}: {}", output_path, e))?; output_file .write_all(&decompressed) .map_err(|e| format!("Failed to write to {}: {}", output_path, e))?; println!("Successfully downloaded and extracted {}", file); } Ok(()) } #[tokio::main] async fn main() -> MlResult<()> { cetana::log::init().expect("Failed to initialize logger"); println!("MNIST Digit Classification Example\n"); // Download MNIST data if needed println!("Checking/Downloading MNIST dataset..."); download_mnist_data().await?; println!("Dataset ready!\n"); // Initialize device and model let device_manager = DeviceManager::new(); let device = device_manager.select_device(None)?; DeviceManager::set_default_device(device)?; let mut model = MnistClassifier::new()?; let config = MnistConfig::default(); // Load and prepare data let (train_images, num_train) = load_mnist_data("data/mnist/train-images-idx3-ubyte")?; let (train_labels, _) = load_mnist_data("data/mnist/train-labels-idx1-ubyte")?; // Prepare training data let x_train = Tensor::from_vec( train_images .iter() .map(|&x| (x as Float / 127.5) - 1.0) // Normalize to [-1, 1] .collect(), &[num_train, 1, 28, 28], )?; let mut y_train_data = vec![0.0; num_train * 10]; for (i, &label) in train_labels.iter().enumerate() { y_train_data[i * 10 + label as usize] = 1.0; } let y_train = Tensor::from_vec(y_train_data, &[num_train, 10])?; println!("Training Configuration:"); println!("Learning Rate: {}", config.learning_rate); println!("Epochs: {}", config.epochs); println!( "Early Stopping Patience: {}", config.early_stopping_patience ); println!( "Early Stopping Min Delta: {}", config.early_stopping_min_delta ); println!("\nTraining Progress:"); let start_time = Instant::now(); let mut best_loss = Float::MAX; let mut patience_counter = 0; let mut losses = Vec::new(); // Training loop with error handling for epoch in 0..config.epochs { match model.train_step(&x_train, &y_train, config.learning_rate) { Ok(loss) => { losses.push(loss); if (best_loss - loss) > config.early_stopping_min_delta { best_loss = loss; patience_counter = 0; } else { patience_counter += 1; if patience_counter >= config.early_stopping_patience { println!("\nEarly stopping triggered at epoch {}", epoch); break; } } if epoch % config.display_interval == 0 { let (train_accuracy, _) = model.evaluate(&x_train, &y_train)?; println!( "Epoch {}/{}: Loss = {:.6}, Accuracy = {:.1}%", epoch + 1, config.epochs, loss, train_accuracy ); } } Err(e) => { println!("Training error at epoch {}: {}", epoch + 1, e); break; } } } let training_time = start_time.elapsed(); println!("\nTraining Complete!"); println!("Training time: {:.2?}", training_time); // Final evaluation let (final_accuracy, final_loss) = model.evaluate(&x_train, &y_train)?; println!("\nFinal Results:"); println!("Accuracy: {:.1}%", final_accuracy); println!("Loss: {:.6}", final_loss); Ok(()) }