// Tests for `TensorMap::components_to_properties`. Each test moves one
// component axis into the properties axis, then checks the metadata
// (samples / components / properties) and the data of both the values
// and the gradients of the resulting block.

use metatensor::{Labels, TensorBlock, TensorMap};
use ndarray::{ArrayD, IxDyn};

mod utils;
use utils::example_labels;

/// A block with a single component axis. Moving it into the properties
/// should leave no components and produce the `components x properties`
/// product (2 x 3 = 6 properties). The gradient must follow the same
/// transformation.
#[test]
fn one_component() {
    let mut block = TensorBlock::new(
        ArrayD::from_elem(vec![3, 2, 3], 1.0),
        &example_labels(vec!["samples"], vec![[0], [1], [2]]),
        &[example_labels(vec!["components"], vec![[0], [1]])],
        &example_labels(vec!["properties"], vec![[0], [1], [2]]),
    )
    .unwrap();

    let gradient = TensorBlock::new(
        ArrayD::from_elem(vec![2, 2, 3], 11.0),
        &example_labels(vec!["sample", "parameter"], vec![[0, 2], [1, 2]]),
        &[example_labels(vec!["components"], vec![[0], [1]])],
        &example_labels(vec!["properties"], vec![[0], [1], [2]]),
    )
    .unwrap();
    block.add_gradient("parameter", gradient).unwrap();

    let tensor = TensorMap::new(Labels::single(), vec![block]).unwrap();
    let tensor = tensor.components_to_properties(&["components"]).unwrap();

    let block = tensor.block_by_id(0);
    assert_eq!(block.samples().names(), ["samples"]);
    assert_eq!(block.samples().count(), 3);
    assert_eq!(block.samples()[0], [0]);
    assert_eq!(block.samples()[1], [1]);
    assert_eq!(block.samples()[2], [2]);

    assert_eq!(block.components().len(), 0);

    // The moved component becomes the outer (slowest-varying) property dimension.
    assert_eq!(block.properties().names(), ["components", "properties"]);
    assert_eq!(block.properties().count(), 6);
    assert_eq!(block.properties()[0], [0, 0]);
    assert_eq!(block.properties()[1], [0, 1]);
    assert_eq!(block.properties()[2], [0, 2]);
    assert_eq!(block.properties()[3], [1, 0]);
    assert_eq!(block.properties()[4], [1, 1]);
    assert_eq!(block.properties()[5], [1, 2]);

    assert_eq!(block.values().as_array(), ArrayD::from_elem(vec![3, 6], 1.0));

    let gradient = block.gradient("parameter").unwrap();
    assert_eq!(gradient.samples().names(), ["sample", "parameter"]);
    assert_eq!(gradient.samples().count(), 2);
    assert_eq!(gradient.samples()[0], [0, 2]);
    assert_eq!(gradient.samples()[1], [1, 2]);

    // The gradient metadata must go through the same component -> property move.
    assert_eq!(gradient.components().len(), 0);
    assert_eq!(gradient.properties().names(), ["components", "properties"]);
    assert_eq!(gradient.properties().count(), 6);
    assert_eq!(gradient.properties()[0], [0, 0]);
    assert_eq!(gradient.properties()[1], [0, 1]);
    assert_eq!(gradient.properties()[2], [0, 2]);
    assert_eq!(gradient.properties()[3], [1, 0]);
    assert_eq!(gradient.properties()[4], [1, 1]);
    assert_eq!(gradient.properties()[5], [1, 2]);

    assert_eq!(gradient.values().as_array(), ArrayD::from_elem(vec![2, 6], 11.0));
}

/// A block with two component axes, of which only the first
/// (`component_1`) is moved into the properties. `component_2` must stay
/// as the only component. Values and gradient data must be reordered so
/// that property index `k` maps to `(component_1 = k / 2, property = k % 2)`.
#[test]
fn multiple_components() {
    // shape: (samples = 2, component_1 = 2, component_2 = 3, properties = 2)
    let data = ArrayD::from_shape_vec(
        vec![2, 2, 3, 2],
        vec![
            // sample 0, component_1 = 0
            1.0, 1.0, 2.0, 2.0, 3.0, 3.0,
            // sample 0, component_1 = 1
            4.0, 4.0, 5.0, 5.0, 6.0, 6.0,
            // sample 1, component_1 = 0
            -1.0, 1.0, -2.0, 2.0, -3.0, 3.0,
            // sample 1, component_1 = 1
            -4.0, 4.0, -5.0, 5.0, -6.0, 6.0,
        ],
    )
    .unwrap();

    let components = [
        example_labels(vec!["component_1"], vec![[0], [1]]),
        example_labels(vec!["component_2"], vec![[0], [1], [2]]),
    ];
    let properties = example_labels(vec!["properties"], vec![[0], [1]]);

    let mut block = TensorBlock::new(
        data,
        &example_labels(vec!["samples"], vec![[0], [1]]),
        &components,
        &properties,
    )
    .unwrap();

    // Each gradient entry stores its own row-major flat index
    // (sample * 12 + component_1 * 6 + component_2 * 2 + property). A
    // constant fill would only check the shape; these values also
    // expose any mistake in how the gradient data is permuted.
    let gradient_data = ArrayD::from_shape_fn(vec![3, 2, 3, 2], |index: IxDyn| {
        (index[0] * 12 + index[1] * 6 + index[2] * 2 + index[3]) as f64
    });

    let gradient = TensorBlock::new(
        gradient_data,
        &example_labels(vec!["sample", "parameter"], vec![[0, 2], [0, 3], [1, 2]]),
        &components,
        &properties,
    )
    .unwrap();
    block.add_gradient("parameter", gradient).unwrap();

    let tensor = TensorMap::new(Labels::single(), vec![block]).unwrap();
    let tensor = tensor.components_to_properties(&["component_1"]).unwrap();

    let block = tensor.block_by_id(0);
    assert_eq!(block.samples().names(), ["samples"]);
    assert_eq!(block.samples().count(), 2);
    assert_eq!(block.samples()[0], [0]);
    assert_eq!(block.samples()[1], [1]);

    assert_eq!(block.components().len(), 1);
    assert_eq!(block.components()[0].names(), ["component_2"]);
    assert_eq!(block.components()[0].count(), 3);
    assert_eq!(block.components()[0][0], [0]);
    assert_eq!(block.components()[0][1], [1]);
    assert_eq!(block.components()[0][2], [2]);

    assert_eq!(block.properties().names(), ["component_1", "properties"]);
    assert_eq!(block.properties().count(), 4);
    assert_eq!(block.properties()[0], [0, 0]);
    assert_eq!(block.properties()[1], [0, 1]);
    assert_eq!(block.properties()[2], [1, 0]);
    assert_eq!(block.properties()[3], [1, 1]);

    // shape: (samples = 2, component_2 = 3, (component_1, properties) = 4)
    let expected = ArrayD::from_shape_vec(
        vec![2, 3, 4],
        vec![
            1.0, 1.0, 4.0, 4.0,
            2.0, 2.0, 5.0, 5.0,
            3.0, 3.0, 6.0, 6.0,
            -1.0, 1.0, -4.0, 4.0,
            -2.0, 2.0, -5.0, 5.0,
            -3.0, 3.0, -6.0, 6.0,
        ],
    )
    .unwrap();
    assert_eq!(block.values().as_array(), expected);

    let gradient = block.gradient("parameter").unwrap();
    assert_eq!(gradient.samples().names(), ["sample", "parameter"]);
    assert_eq!(gradient.samples().count(), 3);
    assert_eq!(gradient.samples()[0], [0, 2]);
    assert_eq!(gradient.samples()[1], [0, 3]);
    assert_eq!(gradient.samples()[2], [1, 2]);

    // The gradient metadata must go through the same component -> property move.
    assert_eq!(gradient.components().len(), 1);
    assert_eq!(gradient.components()[0].names(), ["component_2"]);
    assert_eq!(gradient.components()[0].count(), 3);
    assert_eq!(gradient.properties().names(), ["component_1", "properties"]);
    assert_eq!(gradient.properties().count(), 4);
    assert_eq!(gradient.properties()[0], [0, 0]);
    assert_eq!(gradient.properties()[1], [0, 1]);
    assert_eq!(gradient.properties()[2], [1, 0]);
    assert_eq!(gradient.properties()[3], [1, 1]);

    // Property index `k` of the transformed gradient comes from
    // (component_1 = k / 2, property = k % 2) of the original gradient.
    let expected_gradient = ArrayD::from_shape_fn(vec![3, 3, 4], |index: IxDyn| {
        let (sample, component_2, merged) = (index[0], index[1], index[2]);
        let (component_1, property) = (merged / 2, merged % 2);
        (sample * 12 + component_1 * 6 + component_2 * 2 + property) as f64
    });
    assert_eq!(gradient.values().as_array(), expected_gradient);
}