use burn::{ data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, prelude::*, }; #[derive(Clone, Debug)] pub struct MnistBatcher { device: B::Device, } #[derive(Clone, Debug)] pub struct MnistBatch { pub images: Tensor, pub targets: Tensor, } impl MnistBatcher { pub fn new(device: B::Device) -> Self { Self { device } } } impl Batcher> for MnistBatcher { fn batch(&self, items: Vec) -> MnistBatch { let images = items .iter() .map(|item| Data::::from(item.image)) .map(|data| Tensor::::from_data(data.convert(), &self.device)) .map(|tensor| tensor.reshape([1, 28, 28])) // normalize: make between [0,1] and make the mean = 0 and std = 1 // values mean=0.1307,std=0.3081 were copied from Pytorch Mnist Example // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122 .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081) .collect(); let targets = items .iter() .map(|item| { Tensor::::from_data( Data::from([(item.label as i64).elem()]), &self.device, ) }) .collect(); let images = Tensor::cat(images, 0); let targets = Tensor::cat(targets, 0); MnistBatch { images, targets } } }