extern crate easy_ml; #[cfg(test)] mod tensors { use easy_ml::tensors::Tensor; #[test] #[rustfmt::skip] fn from_fn_test() { let tensor = Tensor::from_fn([("a", 3), ("b", 2)], |[i, j]| format!("{:?}x{:?}", i, j)); assert_eq!( tensor, Tensor::from( [("a", 3), ("b", 2)], vec![ "0x0", "0x1", "1x0", "1x1", "2x0", "2x1" ] ).map(|str| str.to_string()) ) } #[test] fn indexing_test() { let tensor = Tensor::from([("x", 2), ("y", 2)], vec![1, 2, 3, 4]); let xy = tensor.index_by(["x", "y"]); let yx = tensor.index_by(["y", "x"]); assert_eq!(xy.get([0, 0]), 1); assert_eq!(xy.get([0, 1]), 2); assert_eq!(xy.get([1, 0]), 3); assert_eq!(xy.get([1, 1]), 4); assert_eq!(yx.get([0, 0]), 1); assert_eq!(yx.get([0, 1]), 3); assert_eq!(yx.get([1, 0]), 2); assert_eq!(yx.get([1, 1]), 4); use easy_ml::tensors::views::{DataLayout, TensorRef}; assert_eq!(xy.data_layout(), DataLayout::Linear(["x", "y"])); assert_eq!(yx.data_layout(), DataLayout::Linear(["x", "y"])); } #[test] fn higher_dimensional_indexing_test() { let tensor = Tensor::from([("a", 3), ("b", 3), ("c", 3)], (0..27).collect()); let tensor = tensor.map_with_index(|index, _| index); let abc = tensor.index_by(["a", "b", "c"]); assert_eq!(abc.shape(), [("a", 3), ("b", 3), ("c", 3)]); // Our first index is for a, our second index is for b, and our third index is for c assert_eq!(abc.get([0, 0, 0]), [0, 0, 0]); assert_eq!(abc.get([0, 0, 1]), [0, 0, 1]); assert_eq!(abc.get([0, 0, 2]), [0, 0, 2]); assert_eq!(abc.get([0, 1, 0]), [0, 1, 0]); assert_eq!(abc.get([0, 1, 1]), [0, 1, 1]); assert_eq!(abc.get([0, 1, 2]), [0, 1, 2]); assert_eq!(abc.get([0, 2, 0]), [0, 2, 0]); assert_eq!(abc.get([0, 2, 1]), [0, 2, 1]); assert_eq!(abc.get([0, 2, 2]), [0, 2, 2]); assert_eq!(abc.get([1, 0, 0]), [1, 0, 0]); assert_eq!(abc.get([1, 0, 1]), [1, 0, 1]); assert_eq!(abc.get([1, 0, 2]), [1, 0, 2]); assert_eq!(abc.get([1, 1, 0]), [1, 1, 0]); assert_eq!(abc.get([1, 1, 1]), [1, 1, 1]); assert_eq!(abc.get([1, 1, 2]), [1, 1, 2]); assert_eq!(abc.get([1, 2, 0]), [1, 2, 0]); assert_eq!(abc.get([1, 2, 1]), [1, 2, 1]); assert_eq!(abc.get([1, 2, 2]), [1, 2, 2]); assert_eq!(abc.get([2, 0, 0]), [2, 0, 0]); assert_eq!(abc.get([2, 0, 1]), [2, 0, 1]); assert_eq!(abc.get([2, 0, 2]), [2, 0, 2]); assert_eq!(abc.get([2, 1, 0]), [2, 1, 0]); assert_eq!(abc.get([2, 1, 1]), [2, 1, 1]); assert_eq!(abc.get([2, 1, 2]), [2, 1, 2]); assert_eq!(abc.get([2, 2, 0]), [2, 2, 0]); assert_eq!(abc.get([2, 2, 1]), [2, 2, 1]); assert_eq!(abc.get([2, 2, 2]), [2, 2, 2]); let cba = tensor.index_by(["c", "b", "a"]); assert_eq!(cba.shape(), [("c", 3), ("b", 3), ("a", 3)]); // Our first index is for c, our second index is for b, and our third index is for a assert_eq!(cba.get([0, 0, 0]), [0, 0, 0]); assert_eq!(cba.get([0, 0, 1]), [1, 0, 0]); assert_eq!(cba.get([0, 0, 2]), [2, 0, 0]); assert_eq!(cba.get([0, 1, 0]), [0, 1, 0]); assert_eq!(cba.get([0, 1, 1]), [1, 1, 0]); assert_eq!(cba.get([0, 1, 2]), [2, 1, 0]); assert_eq!(cba.get([0, 2, 0]), [0, 2, 0]); assert_eq!(cba.get([0, 2, 1]), [1, 2, 0]); assert_eq!(cba.get([0, 2, 2]), [2, 2, 0]); assert_eq!(cba.get([1, 0, 0]), [0, 0, 1]); assert_eq!(cba.get([1, 0, 1]), [1, 0, 1]); assert_eq!(cba.get([1, 0, 2]), [2, 0, 1]); assert_eq!(cba.get([1, 1, 0]), [0, 1, 1]); assert_eq!(cba.get([1, 1, 1]), [1, 1, 1]); assert_eq!(cba.get([1, 1, 2]), [2, 1, 1]); assert_eq!(cba.get([1, 2, 0]), [0, 2, 1]); assert_eq!(cba.get([1, 2, 1]), [1, 2, 1]); assert_eq!(cba.get([1, 2, 2]), [2, 2, 1]); assert_eq!(cba.get([2, 0, 0]), [0, 0, 2]); assert_eq!(cba.get([2, 0, 1]), [1, 0, 2]); assert_eq!(cba.get([2, 0, 2]), [2, 0, 2]); assert_eq!(cba.get([2, 1, 0]), [0, 1, 2]); assert_eq!(cba.get([2, 1, 1]), [1, 1, 2]); assert_eq!(cba.get([2, 1, 2]), [2, 1, 2]); assert_eq!(cba.get([2, 2, 0]), [0, 2, 2]); assert_eq!(cba.get([2, 2, 1]), [1, 2, 2]); assert_eq!(cba.get([2, 2, 2]), [2, 2, 2]); let cab = tensor.index_by(["c", "a", "b"]); assert_eq!(cab.shape(), [("c", 3), ("a", 3), ("b", 3)]); // Our first index is for c, our second index is for a, and our third index is for b assert_eq!(cab.get([0, 0, 0]), [0, 0, 0]); assert_eq!(cab.get([0, 0, 1]), [0, 1, 0]); assert_eq!(cab.get([0, 0, 2]), [0, 2, 0]); assert_eq!(cab.get([0, 1, 0]), [1, 0, 0]); assert_eq!(cab.get([0, 1, 1]), [1, 1, 0]); assert_eq!(cab.get([0, 1, 2]), [1, 2, 0]); assert_eq!(cab.get([0, 2, 0]), [2, 0, 0]); assert_eq!(cab.get([0, 2, 1]), [2, 1, 0]); assert_eq!(cab.get([0, 2, 2]), [2, 2, 0]); assert_eq!(cab.get([1, 0, 0]), [0, 0, 1]); assert_eq!(cab.get([1, 0, 1]), [0, 1, 1]); assert_eq!(cab.get([1, 0, 2]), [0, 2, 1]); assert_eq!(cab.get([1, 1, 0]), [1, 0, 1]); assert_eq!(cab.get([1, 1, 1]), [1, 1, 1]); assert_eq!(cab.get([1, 1, 2]), [1, 2, 1]); assert_eq!(cab.get([1, 2, 0]), [2, 0, 1]); assert_eq!(cab.get([1, 2, 1]), [2, 1, 1]); assert_eq!(cab.get([1, 2, 2]), [2, 2, 1]); assert_eq!(cab.get([2, 0, 0]), [0, 0, 2]); assert_eq!(cab.get([2, 0, 1]), [0, 1, 2]); assert_eq!(cab.get([2, 0, 2]), [0, 2, 2]); assert_eq!(cab.get([2, 1, 0]), [1, 0, 2]); assert_eq!(cab.get([2, 1, 1]), [1, 1, 2]); assert_eq!(cab.get([2, 1, 2]), [1, 2, 2]); assert_eq!(cab.get([2, 2, 0]), [2, 0, 2]); assert_eq!(cab.get([2, 2, 1]), [2, 1, 2]); assert_eq!(cab.get([2, 2, 2]), [2, 2, 2]); let bca = tensor.index_by(["b", "c", "a"]); assert_eq!(bca.shape(), [("b", 3), ("c", 3), ("a", 3)]); // Our first index is for b, our second index is for c, and our third index is for a assert_eq!(bca.get([0, 0, 0]), [0, 0, 0]); assert_eq!(bca.get([0, 0, 1]), [1, 0, 0]); assert_eq!(bca.get([0, 0, 2]), [2, 0, 0]); assert_eq!(bca.get([0, 1, 0]), [0, 0, 1]); assert_eq!(bca.get([0, 1, 1]), [1, 0, 1]); assert_eq!(bca.get([0, 1, 2]), [2, 0, 1]); assert_eq!(bca.get([0, 2, 0]), [0, 0, 2]); assert_eq!(bca.get([0, 2, 1]), [1, 0, 2]); assert_eq!(bca.get([0, 2, 2]), [2, 0, 2]); assert_eq!(bca.get([1, 0, 0]), [0, 1, 0]); assert_eq!(bca.get([1, 0, 1]), [1, 1, 0]); assert_eq!(bca.get([1, 0, 2]), [2, 1, 0]); assert_eq!(bca.get([1, 1, 0]), [0, 1, 1]); assert_eq!(bca.get([1, 1, 1]), [1, 1, 1]); assert_eq!(bca.get([1, 1, 2]), [2, 1, 1]); assert_eq!(bca.get([1, 2, 0]), [0, 1, 2]); assert_eq!(bca.get([1, 2, 1]), [1, 1, 2]); assert_eq!(bca.get([1, 2, 2]), [2, 1, 2]); assert_eq!(bca.get([2, 0, 0]), [0, 2, 0]); assert_eq!(bca.get([2, 0, 1]), [1, 2, 0]); assert_eq!(bca.get([2, 0, 2]), [2, 2, 0]); assert_eq!(bca.get([2, 1, 0]), [0, 2, 1]); assert_eq!(bca.get([2, 1, 1]), [1, 2, 1]); assert_eq!(bca.get([2, 1, 2]), [2, 2, 1]); assert_eq!(bca.get([2, 2, 0]), [0, 2, 2]); assert_eq!(bca.get([2, 2, 1]), [1, 2, 2]); assert_eq!(bca.get([2, 2, 2]), [2, 2, 2]); use easy_ml::tensors::views::{DataLayout, TensorRef}; assert_eq!(abc.data_layout(), DataLayout::Linear(["a", "b", "c"])); assert_eq!(cba.data_layout(), DataLayout::Linear(["a", "b", "c"])); assert_eq!(cab.data_layout(), DataLayout::Linear(["a", "b", "c"])); assert_eq!(bca.data_layout(), DataLayout::Linear(["a", "b", "c"])); } #[test] #[should_panic] fn repeated_name() { Tensor::from([("x", 2), ("x", 2)], vec![1, 2, 3, 4]); } #[test] #[should_panic] fn wrong_size() { Tensor::from([("x", 2), ("y", 3)], vec![1, 2, 3, 4]); } #[test] #[should_panic] fn bad_indexing() { let tensor = Tensor::from([("x", 2), ("y", 2)], vec![1, 2, 3, 4]); tensor.index_by(["x", "x"]); } #[test] #[rustfmt::skip] fn transpose_more_dimensions() { let tensor = Tensor::from( [("batch", 2), ("y", 10), ("x", 10), ("color", 1)], vec![ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]); let transposed = tensor.transpose(["batch", "x", "y", "color"]); assert_eq!( transposed, Tensor::from([("batch", 2), ("y", 10), ("x", 10), ("color", 1)], vec![ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ]) ); } #[test] fn check_iterators() { #[rustfmt::skip] let tensor = Tensor::from([("row", 3), ("column", 2)], vec![ 1, 2, 3, 4, 5, 6 ]); let mut row_column_iterator = tensor.iter_reference(); assert_eq!(row_column_iterator.next(), Some(&1)); assert_eq!(row_column_iterator.next(), Some(&2)); assert_eq!(row_column_iterator.next(), Some(&3)); assert_eq!(row_column_iterator.next(), Some(&4)); assert_eq!(row_column_iterator.next(), Some(&5)); assert_eq!(row_column_iterator.next(), Some(&6)); assert_eq!(row_column_iterator.next(), None); } #[test] fn check_iterators_with_index() { #[rustfmt::skip] let tensor = Tensor::from([("row", 3), ("column", 2)], vec![ 1, 2, 3, 4, 5, 6 ]); let row_column = tensor.index(); let mut iterator = row_column.iter_reference().with_index(); assert_eq!(iterator.next(), Some(([0, 0], &1))); assert_eq!(iterator.next(), Some(([0, 1], &2))); assert_eq!(iterator.next(), Some(([1, 0], &3))); assert_eq!(iterator.next(), Some(([1, 1], &4))); assert_eq!(iterator.next(), Some(([2, 0], &5))); assert_eq!(iterator.next(), Some(([2, 1], &6))); assert_eq!(iterator.next(), None); } #[test] fn check_transposition() { let mut tensor = Tensor::from([("row", 4), ("column", 1)], vec![1, 2, 3, 4]); tensor.transpose_mut(["column", "row"]); assert_eq!( tensor, Tensor::from([("row", 1), ("column", 4)], vec![1, 2, 3, 4]) ); let mut tensor = Tensor::from([("row", 1), ("column", 4)], vec![1, 2, 3, 4]); tensor.transpose_mut(["column", "row"]); assert_eq!( tensor, Tensor::from([("row", 4), ("column", 1)], vec![1, 2, 3, 4]) ); #[rustfmt::skip] let mut tensor = Tensor::from([("row", 3), ("column", 3)], vec![ 1, 2, 3, 4, 5, 6, 7, 8, 9 ]); tensor.transpose_mut(["column", "row"]); #[rustfmt::skip] assert_eq!( tensor, Tensor::from( [("row", 3), ("column", 3)], vec![ 1, 4, 7, 2, 5, 8, 3, 6, 9 ] ) ); #[rustfmt::skip] let mut tensor = Tensor::from([("r", 2), ("c", 3)], vec![ 1, 2, 3, 4, 5, 6 ]); tensor.transpose_mut(["c", "r"]); #[rustfmt::skip] assert_eq!( tensor, Tensor::from([("r", 3), ("c", 2)], vec![ 1, 4, 2, 5, 3, 6 ]) ); #[rustfmt::skip] let mut tensor = Tensor::from([("a", 3), ("b", 2)], vec![ 1, 2, 3, 4, 5, 6 ]); tensor.transpose_mut(["b", "a"]); #[rustfmt::skip] assert_eq!( tensor, Tensor::from([("a", 2), ("b", 3)], vec![ 1, 3, 5, 2, 4, 6 ]) ); #[rustfmt::skip] let tensor = Tensor::from([("row", 3), ("column", 3)], vec![ 1, 2, 3, 4, 5, 6, 7, 8, 9 ]); #[rustfmt::skip] assert_eq!( tensor.transpose(["column", "row"]), Tensor::from( [("row", 3), ("column", 3)], vec![ 1, 4, 7, 2, 5, 8, 3, 6, 9 ] ) ); } #[test] fn check_reorder() { let mut tensor = Tensor::from([("row", 4), ("column", 1)], vec![1, 2, 3, 4]); tensor.reorder_mut(["column", "row"]); assert_eq!( tensor, Tensor::from([("column", 1), ("row", 4)], vec![1, 2, 3, 4]) ); let mut tensor = Tensor::from([("row", 1), ("column", 4)], vec![1, 2, 3, 4]); tensor.reorder_mut(["column", "row"]); assert_eq!( tensor, Tensor::from([("column", 4), ("row", 1)], vec![1, 2, 3, 4]) ); #[rustfmt::skip] let mut tensor = Tensor::from([("row", 3), ("column", 3)], vec![ 1, 2, 3, 4, 5, 6, 7, 8, 9 ]); tensor.reorder_mut(["column", "row"]); assert_eq!( tensor, Tensor::from( [("column", 3), ("row", 3)], vec![1, 4, 7, 2, 5, 8, 3, 6, 9,] ) ); #[rustfmt::skip] let mut tensor = Tensor::from([("r", 2), ("c", 3)], vec![ 1, 2, 3, 4, 5, 6 ]); tensor.reorder_mut(["c", "r"]); assert_eq!( tensor, Tensor::from([("c", 3), ("r", 2)], vec![1, 4, 2, 5, 3, 6,]) ); #[rustfmt::skip] let mut tensor = Tensor::from([("a", 3), ("b", 2)], vec![ 1, 2, 3, 4, 5, 6 ]); tensor.reorder_mut(["b", "a"]); assert_eq!( tensor, Tensor::from([("b", 2), ("a", 3)], vec![1, 3, 5, 2, 4, 6,]) ); #[rustfmt::skip] let tensor = Tensor::from([("row", 3), ("column", 3)], vec![ 1, 2, 3, 4, 5, 6, 7, 8, 9 ]); assert_eq!( tensor.reorder(["column", "row"]), Tensor::from( [("column", 3), ("row", 3)], vec![1, 4, 7, 2, 5, 8, 3, 6, 9,] ) ); } #[test] fn test_reshaping() { let tensor = Tensor::from([("everything", 20)], (0..20).collect()); let mut five_by_four = tensor.reshape_owned([("fives", 5), ("fours", 4)]); #[rustfmt::skip] assert_eq!( Tensor::from([("fives", 5), ("fours", 4)], vec![ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ]), five_by_four ); five_by_four.reshape_mut([("twos", 2), ("tens", 10)]); #[rustfmt::skip] assert_eq!( Tensor::from([("twos", 2), ("tens", 10)], vec![ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ]), five_by_four ); let flattened = five_by_four.reshape_owned([("data", 20)]); assert_eq!(flattened, Tensor::from([("data", 20)], (0..20).collect())); } #[test] #[should_panic] fn invalid_reshape() { let mut square = Tensor::from([("r", 2), ("c", 2)], (0..4).collect()); square.reshape_mut([("not", 3), ("square", 1)]); } #[test] fn check_data_layout_tensor() { use easy_ml::tensors::views::{DataLayout, TensorRef}; let tensor = Tensor::from([("b", 3), ("r", 3), ("c", 3)], (0..27).collect()); assert_eq!(tensor.data_layout(), DataLayout::Linear(["b", "r", "c"])); let tensor = Tensor::from([("r", 2), ("c", 2)], (0..4).collect()); assert_eq!(tensor.data_layout(), DataLayout::Linear(["r", "c"])); let tensor = Tensor::from([("a", 3)], (0..3).collect()); assert_eq!(tensor.data_layout(), DataLayout::Linear(["a"])); } #[test] fn check_data_layout_non_linear_tensor_views() { use easy_ml::tensors::views::{DataLayout, TensorRef}; let tensor = Tensor::from([("b", 3), ("r", 3), ("c", 3)], (0..27).collect()); assert_eq!(tensor.data_layout(), DataLayout::Linear(["b", "r", "c"])); assert_eq!( tensor .range([("b", 0..2)]) .unwrap() .source_ref() .data_layout(), DataLayout::NonLinear ); assert_eq!( tensor .mask([("c", 0..2)]) .unwrap() .source_ref() .data_layout(), DataLayout::NonLinear ); assert_eq!( tensor.select([("b", 1)]).source_ref().data_layout(), DataLayout::NonLinear ); assert_eq!( tensor.expand([(2, "x")]).source_ref().data_layout(), DataLayout::NonLinear ); } #[test] fn check_data_layout_tensor_access() { use easy_ml::tensors::indexing::TensorAccess; use easy_ml::tensors::views::{DataLayout, TensorRef, TensorRename, TensorView}; let tensor = Tensor::from([("b", 3), ("r", 3), ("c", 3)], (0..27).collect()); assert_eq!(tensor.data_layout(), DataLayout::Linear(["b", "r", "c"])); assert_eq!( tensor.index_by(["b", "r", "c"]).data_layout(), DataLayout::Linear(["b", "r", "c"]) ); assert_eq!( tensor.index_by(["c", "r", "b"]).data_layout(), DataLayout::Linear(["b", "r", "c"]) ); assert_eq!( tensor.index_by(["r", "c", "b"]).data_layout(), DataLayout::Linear(["b", "r", "c"]) ); assert_eq!( tensor.index_by(["r", "b", "c"]).data_layout(), DataLayout::Linear(["b", "r", "c"]) ); assert_eq!( tensor.index_by(["c", "b", "r"]).data_layout(), DataLayout::Linear(["b", "r", "c"]) ); // Each time we transpose we expect the data layout we get back to be correct, which // we can verify by using it as the index order. If the data layout is correct then // returning to big endian order means we iterate through the tensor as the 0..27 it was // defined with. let transposed = tensor.transpose_view(["b", "r", "c"]); assert_eq!( transposed.source_ref().data_layout(), DataLayout::Linear(["b", "r", "c"]) ); assert_eq!( (0..27).collect::>(), transposed .index_by(["b", "r", "c"]) .iter() .collect::>() ); // We can avoid manually passing the linear data layout if we use from_memory_order assert_eq!( (0..27).collect::>(), TensorAccess::from_memory_order(transposed.source_ref()) .unwrap() .iter() .collect::>() ); // Alternative way to 'transpose', should match TensorTranspose exactly let also_transposed = TensorView::from(TensorRename::from( tensor.index_by(["b", "r", "c"]), ["b", "r", "c"], )); assert_eq!(transposed, also_transposed); assert_eq!( also_transposed.source_ref().data_layout(), DataLayout::Linear(["b", "r", "c"]) ); let transposed = tensor.transpose_view(["c", "r", "b"]); assert_eq!( transposed.source_ref().data_layout(), DataLayout::Linear(["c", "r", "b"]) ); assert_eq!( (0..27).collect::>(), transposed .index_by(["c", "r", "b"]) .iter() .collect::>() ); assert_eq!( (0..27).collect::>(), TensorAccess::from_memory_order(transposed.source_ref()) .unwrap() .iter() .collect::>() ); let also_transposed = TensorView::from(TensorRename::from( tensor.index_by(["c", "r", "b"]), ["b", "r", "c"], )); assert_eq!(transposed, also_transposed); assert_eq!( also_transposed.source_ref().data_layout(), DataLayout::Linear(["c", "r", "b"]) ); let transposed = tensor.transpose_view(["r", "c", "b"]); assert_eq!( transposed.source_ref().data_layout(), DataLayout::Linear(["c", "b", "r"]) ); assert_eq!( (0..27).collect::>(), transposed .index_by(["c", "b", "r"]) .iter() .collect::>() ); assert_eq!( (0..27).collect::>(), TensorAccess::from_memory_order(transposed.source_ref()) .unwrap() .iter() .collect::>() ); let also_transposed = TensorView::from(TensorRename::from( tensor.index_by(["r", "c", "b"]), ["b", "r", "c"], )); assert_eq!(transposed, also_transposed); assert_eq!( also_transposed.source_ref().data_layout(), DataLayout::Linear(["c", "b", "r"]) ); let transposed = tensor.transpose_view(["r", "b", "c"]); assert_eq!( transposed.source_ref().data_layout(), DataLayout::Linear(["r", "b", "c"]) ); assert_eq!( (0..27).collect::>(), transposed .index_by(["r", "b", "c"]) .iter() .collect::>() ); assert_eq!( (0..27).collect::>(), TensorAccess::from_memory_order(transposed.source_ref()) .unwrap() .iter() .collect::>() ); let also_transposed = TensorView::from(TensorRename::from( tensor.index_by(["r", "b", "c"]), ["b", "r", "c"], )); assert_eq!(transposed, also_transposed); assert_eq!( also_transposed.source_ref().data_layout(), DataLayout::Linear(["r", "b", "c"]) ); let transposed = tensor.transpose_view(["c", "b", "r"]); assert_eq!( transposed.source_ref().data_layout(), DataLayout::Linear(["r", "c", "b"]) ); assert_eq!( (0..27).collect::>(), transposed .index_by(["r", "c", "b"]) .iter() .collect::>() ); assert_eq!( (0..27).collect::>(), TensorAccess::from_memory_order(transposed.source_ref()) .unwrap() .iter() .collect::>() ); let also_transposed = TensorView::from(TensorRename::from( tensor.index_by(["c", "b", "r"]), ["b", "r", "c"], )); assert_eq!(transposed, also_transposed); assert_eq!( also_transposed.source_ref().data_layout(), DataLayout::Linear(["r", "c", "b"]) ); } #[test] fn check_data_layout_linear_tensor_views() { use easy_ml::tensors::views::{DataLayout, TensorRef}; let tensor = Tensor::from([("b", 3), ("r", 3), ("c", 3)], (0..27).collect()); assert_eq!(tensor.data_layout(), DataLayout::Linear(["b", "r", "c"])); assert_eq!( tensor .rename_view(["a", "q", "b"]) .source_ref() .data_layout(), DataLayout::Linear(["a", "q", "b"]) ); } #[test] fn display_and_indexing_for_reordering() { use easy_ml::tensors::views::TensorView; let tensor = Tensor::from([("a", 2), ("b", 3), ("c", 4)], (0..(2 * 3 * 4)).collect()); assert_eq!( tensor.iter().collect::>(), (0..(2 * 3 * 4)).collect::>() ); assert_eq!( r#"D = 3 ("a", 2), ("b", 3), ("c", 4) [ 0, 1, 2, 3 4, 5, 6, 7 8, 9, 10, 11 12, 13, 14, 15 16, 17, 18, 19 20, 21, 22, 23 ]"#, tensor.to_string(), ); let reordered = tensor.index_by(["b", "c", "a"]); // reordering the 3D tensor should yield a tensor that still displays with the leftmost // dimension as the largest group, the second dimension as rows and the final dimension as // columns assert_eq!( reordered.iter().collect::>(), vec![ 0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23 ] ); assert_eq!( r#"D = 3 ("b", 3), ("c", 4), ("a", 2) [ 0, 12 1, 13 2, 14 3, 15 4, 16 5, 17 6, 18 7, 19 8, 20 9, 21 10, 22 11, 23 ] Data Layout = Linear(["a", "b", "c"])"#, reordered.to_string(), ); // To transpose our way back, make biggest dimension a, since as from above that's a // stride of 12. Make next dimension b. since that's a stride of 4, then make smallest // dimension c since that's a stride of 1 (this aligns with data layout too). let transposed = TensorView::from(reordered).transpose(["a", "b", "c"]); assert_eq!( transposed.iter().collect::>(), (0..(2 * 3 * 4)).collect::>() ); assert_eq!( r#"D = 3 ("b", 2), ("c", 3), ("a", 4) [ 0, 1, 2, 3 4, 5, 6, 7 8, 9, 10, 11 12, 13, 14, 15 16, 17, 18, 19 20, 21, 22, 23 ]"#, transposed.to_string(), ); } #[test] fn test_identity_constructor() { let identity = Tensor::diagonal([("a", 3), ("b", 3)], 1.0); #[rustfmt::skip] assert_eq!( identity, Tensor::from([("a", 3), ("b", 3)], vec![ 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0 ]) ); } #[test] fn test_owned_iterator_of_mut_source() { use easy_ml::tensors::indexing::TensorOwnedIterator; let mut tensor = Tensor::from_fn([("x", 2), ("y", 2)], |[x, y]| x + y); let mut_tensor = &mut tensor; let owned_iter = TensorOwnedIterator::from(mut_tensor); let drained = owned_iter.collect::>(); assert_eq!(drained, vec![0, 1, 1, 2]); // original tensor is now drained of its values so should be set to 0 as that's the // default value substituted. assert_eq!(tensor, Tensor::empty([("x", 2), ("y", 2)], 0)); } }