#[cfg(feature = "std")]
mod tests {
    //! Record-compatibility tests: a model is recorded with one struct layout and loaded back
    //! with a different layout (added/removed field, reordered fields) to check which file
    //! recorders tolerate each change.

    use burn::{
        module::Module,
        nn,
        record::{
            BinFileRecorder, DefaultFileRecorder, FileRecorder, FullPrecisionSettings,
            PrettyJsonFileRecorder, RecorderError,
        },
    };
    use burn_core as burn;
    use burn_ndarray::NdArrayDevice;
    use burn_tensor::backend::Backend;
    use std::path::PathBuf;

    type TestBackend = burn_ndarray::NdArray<f32>;

    /// Baseline layout that every other model layout is compared against.
    #[derive(Module, Debug)]
    pub struct Model<B: Backend> {
        single_const: f32,
        linear1: nn::Linear<B>,
        array_const: [usize; 2],
        linear2: nn::Linear<B>,
    }

    /// [`Model`] plus a trailing optional field.
    #[derive(Module, Debug)]
    pub struct ModelNewOptionalField<B: Backend> {
        single_const: f32,
        linear1: nn::Linear<B>,
        array_const: [usize; 2],
        linear2: nn::Linear<B>,
        new_field: Option<usize>,
    }

    /// [`Model`] plus a trailing constant field.
    #[derive(Module, Debug)]
    pub struct ModelNewConstantField<B: Backend> {
        single_const: f32,
        linear1: nn::Linear<B>,
        array_const: [usize; 2],
        linear2: nn::Linear<B>,
        new_field: usize,
    }

    /// The fields of [`Model`] declared in a different order.
    #[derive(Module, Debug)]
    pub struct ModelNewFieldOrders<B: Backend> {
        array_const: [usize; 2],
        linear2: nn::Linear<B>,
        single_const: f32,
        linear1: nn::Linear<B>,
    }

    #[test]
    fn deserialize_with_new_optional_field_works_with_default_file_recorder() {
        deserialize_with_new_optional_field(
            "default",
            DefaultFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_removed_optional_field_works_with_default_file_recorder() {
        deserialize_with_removed_optional_field(
            "default",
            DefaultFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_new_constant_field_works_with_default_file_recorder() {
        deserialize_with_new_constant_field(
            "default",
            DefaultFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_removed_constant_field_works_with_default_file_recorder() {
        deserialize_with_removed_constant_field(
            "default",
            DefaultFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_new_field_order_works_with_default_file_recorder() {
        deserialize_with_new_field_order(
            "default",
            DefaultFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_new_optional_field_works_with_pretty_json() {
        deserialize_with_new_optional_field(
            "pretty-json",
            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_removed_optional_field_works_with_pretty_json() {
        deserialize_with_removed_optional_field(
            "pretty-json",
            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_new_constant_field_works_with_pretty_json() {
        deserialize_with_new_constant_field(
            "pretty-json",
            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_removed_constant_field_works_with_pretty_json() {
        deserialize_with_removed_constant_field(
            "pretty-json",
            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_new_field_order_works_with_pretty_json() {
        deserialize_with_new_field_order(
            "pretty-json",
            PrettyJsonFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    /// The bin recorder is expected to fail when loading a record that lacks the new
    /// optional field.
    #[test]
    #[should_panic]
    fn deserialize_with_new_optional_field_doesnt_works_with_bin_file_recorder() {
        deserialize_with_new_optional_field("bin", BinFileRecorder::<FullPrecisionSettings>::new())
            .unwrap();
    }

    #[test]
    fn deserialize_with_removed_optional_field_works_with_bin_file_recorder() {
        deserialize_with_removed_optional_field(
            "bin",
            BinFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    #[test]
    fn deserialize_with_new_constant_field_works_with_bin_file_recorder() {
        deserialize_with_new_constant_field("bin", BinFileRecorder::<FullPrecisionSettings>::new())
            .unwrap();
    }

    #[test]
    fn deserialize_with_removed_constant_field_works_with_bin_file_recorder() {
        deserialize_with_removed_constant_field(
            "bin",
            BinFileRecorder::<FullPrecisionSettings>::new(),
        )
        .unwrap();
    }

    /// Despite its name, this test expects the bin recorder to FAIL (`should_panic`) when the
    /// field order changes; the name is kept so existing test filters keep matching.
    #[test]
    #[should_panic]
    fn deserialize_with_new_field_order_works_with_bin_file_recorder() {
        deserialize_with_new_field_order("bin", BinFileRecorder::<FullPrecisionSettings>::new())
            .unwrap();
    }

    /// Path of a scratch file named `filename` inside the OS temporary directory.
    fn file_path(filename: String) -> PathBuf {
        std::env::temp_dir().join(filename)
    }

    #[test]
    fn test_tensor_serde() {
        let tensor: burn_tensor::Tensor<TestBackend, 1> =
            burn_tensor::Tensor::ones([1], &NdArrayDevice::default());
        let encoded = serde_json::to_string(&tensor).unwrap();
        let decoded: burn_tensor::Tensor<TestBackend, 1> =
            serde_json::from_str(&encoded).unwrap();

        assert_eq!(tensor.into_data(), decoded.into_data());
    }

    /// Builds the 20x20 linear layer used by every model in this module.
    fn linear(device: &NdArrayDevice) -> nn::Linear<TestBackend> {
        nn::LinearConfig::new(20, 20).init::<TestBackend>(device)
    }

    /// Builds a baseline [`Model`] with the constant values shared by every scenario.
    fn model(device: &NdArrayDevice) -> Model<TestBackend> {
        Model {
            single_const: 32.0,
            linear1: linear(device),
            array_const: [2, 2],
            linear2: linear(device),
        }
    }

    /// Records `saved` with `recorder` to a scratch file named `{scenario}-{name}` in the temp
    /// directory, then loads it back using the record type of the `Loaded` module layout.
    ///
    /// The scratch file is removed afterwards whether or not the load succeeded.
    ///
    /// # Errors
    /// Returns the recorder's error when the stored record cannot be loaded as
    /// `Loaded`'s record.
    ///
    /// # Panics
    /// Panics if recording fails; only the load step is under test.
    fn record_then_load<R, Saved, Loaded>(
        scenario: &str,
        name: &str,
        recorder: R,
        saved: Saved,
        device: &NdArrayDevice,
    ) -> Result<<Loaded as Module<TestBackend>>::Record, RecorderError>
    where
        R: FileRecorder<TestBackend>,
        Saved: Module<TestBackend>,
        Loaded: Module<TestBackend>,
    {
        let path = file_path(format!("{scenario}-{name}"));

        recorder
            .record(saved.into_record(), path.clone())
            .unwrap();
        let loaded =
            recorder.load::<<Loaded as Module<TestBackend>>::Record>(path.clone(), device);

        // Best-effort cleanup. File recorders store the record under `path` with their own
        // `file_extension()` applied, so removing only the bare path would leave the file
        // behind; both candidate names are tried. Errors are ignored on purpose: cleanup must
        // not change the test verdict.
        let extension = <R as FileRecorder<TestBackend>>::file_extension();
        for candidate in [path.with_extension(extension), path] {
            std::fs::remove_file(candidate).ok();
        }

        loaded
    }

    /// Records a [`Model`] and loads it as [`ModelNewOptionalField`] (one extra `Option` field).
    fn deserialize_with_new_optional_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
    where
        R: FileRecorder<TestBackend>,
    {
        let device = NdArrayDevice::default();
        record_then_load::<_, _, ModelNewOptionalField<TestBackend>>(
            "deserialize_with_new_optional_field",
            name,
            recorder,
            model(&device),
            &device,
        )?;
        Ok(())
    }

    /// Records a [`ModelNewOptionalField`] (with `new_field: None`) and loads it as [`Model`].
    fn deserialize_with_removed_optional_field<R>(
        name: &str,
        recorder: R,
    ) -> Result<(), RecorderError>
    where
        R: FileRecorder<TestBackend>,
    {
        let device = NdArrayDevice::default();
        let saved = ModelNewOptionalField::<TestBackend> {
            single_const: 32.0,
            linear1: linear(&device),
            array_const: [2, 2],
            linear2: linear(&device),
            new_field: None,
        };
        record_then_load::<_, _, Model<TestBackend>>(
            "deserialize_with_removed_optional_field",
            name,
            recorder,
            saved,
            &device,
        )?;
        Ok(())
    }

    /// Records a [`Model`] and loads it as [`ModelNewConstantField`] (one extra constant field).
    fn deserialize_with_new_constant_field<R>(name: &str, recorder: R) -> Result<(), RecorderError>
    where
        R: FileRecorder<TestBackend>,
    {
        let device = NdArrayDevice::default();
        record_then_load::<_, _, ModelNewConstantField<TestBackend>>(
            "deserialize_with_new_constant_field",
            name,
            recorder,
            model(&device),
            &device,
        )?;
        Ok(())
    }

    /// Records a [`ModelNewConstantField`] (with `new_field: 0`) and loads it as [`Model`].
    fn deserialize_with_removed_constant_field<R>(
        name: &str,
        recorder: R,
    ) -> Result<(), RecorderError>
    where
        R: FileRecorder<TestBackend>,
    {
        let device = NdArrayDevice::default();
        let saved = ModelNewConstantField::<TestBackend> {
            single_const: 32.0,
            array_const: [2, 2],
            linear1: linear(&device),
            linear2: linear(&device),
            new_field: 0,
        };
        record_then_load::<_, _, Model<TestBackend>>(
            "deserialize_with_removed_constant_field",
            name,
            recorder,
            saved,
            &device,
        )?;
        Ok(())
    }

    /// Records a [`Model`] and loads it as [`ModelNewFieldOrders`] (same fields, new order).
    fn deserialize_with_new_field_order<R>(name: &str, recorder: R) -> Result<(), RecorderError>
    where
        R: FileRecorder<TestBackend>,
    {
        let device = NdArrayDevice::default();
        record_then_load::<_, _, ModelNewFieldOrders<TestBackend>>(
            "deserialize_with_new_field_order",
            name,
            recorder,
            model(&device),
            &device,
        )?;
        Ok(())
    }
}