use rstest::rstest; use arr_rs::prelude::*; #[rstest( arrs, axis, expected, case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6])], None, array!(i32, [1, 2, 3, 4, 5, 6])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [5, 6])], None, array!(i32, [1, 2, 3, 4, 5, 6])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5, 6]])], Some(0), array!(i32, [[1, 2], [3, 4], [5, 6]])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5, 6]]), array!(i32, [[7, 8]])], Some(0), array!(i32, [[1, 2], [3, 4], [5, 6], [7, 8]])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5], [6]])], Some(1), array!(i32, [[1, 2, 5], [3, 4, 6]])), case(vec![array!(i32, [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]), array!(i32, [[[1, 2], [3, 4]]])], Some(0), array!(i32, [[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]])), case(vec![array!(i32, [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]), array!(i32, [[[1, 2], [3, 4]]]), array!(i32, [[[5, 6], [7, 8]]])], Some(0), array!(i32, [[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]], [[5, 6], [7, 8]]])), case(vec![array!(i32, [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]), array!(i32, [[[1, 2], [3, 4]], [[1, 2], [3, 4]]])], Some(1), array!(i32, [[[1, 2], [3, 4], [1, 2], [3, 4]], [[1, 2], [3, 4], [1, 2], [3, 4]]])), case(vec![array!(i32, [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]), array!(i32, [[[1, 2], [3, 4]], [[1, 2], [3, 4]]])], Some(2), array!(i32, [[[1, 2, 1, 2], [3, 4, 3, 4]], [[1, 2, 1, 2], [3, 4, 3, 4]]])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [5, 6])], Some(0), Err(ArrayError::ConcatenateShapeMismatch)), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5, 6]])], Some(1), Err(ArrayError::ConcatenateShapeMismatch)), )] fn test_concatenate(arrs: Vec, ArrayError>>, axis: Option, expected: Result, ArrayError>) { let arrs = arrs.iter().map(|a| a.as_ref().unwrap().clone()).collect(); assert_eq!(expected, Array::concatenate(arrs, axis)) } #[rstest( arrs, axis, expected, case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6])], None, array!(i32, [[1, 2, 3], [4, 5, 6]])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5, 6], [7, 8]])], None, array!(i32, [[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), case(vec![array!(i32, [[[1, 2], [3, 4]]]), array!(i32, [[[5, 6], [7, 8]]])], Some(0), array!(i32, [[[[1, 2], [3, 4]]], [[[5, 6], [7, 8]]]])), case(vec![array!(i32, [[[1, 2], [3, 4]]]), array!(i32, [[[5, 6], [7, 8]]]), array!(i32, [[[9, 10], [11, 12]]])], Some(0), array!(i32, [[[[1, 2], [3, 4]]], [[[5, 6], [7, 8]]], [[[9, 10], [11, 12]]]])), case(vec![array!(i32, [[[1, 2], [3, 4]]]), array!(i32, [[[5, 6], [7, 8]]])], Some(1), array!(i32, [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]])), case(vec![array!(i32, [[[1, 2], [3, 4]]]), array!(i32, [[[5, 6], [7, 8]]])], Some(2), array!(i32, [[[[1, 2], [5, 6]], [[3, 4], [7, 8]]]])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5, 6]])], None, Err(ArrayError::ParameterError { param: "arrs", message: "all input arrays must have the same shape" })), )] fn test_stack(arrs: Vec, ArrayError>>, axis: Option, expected: Result, ArrayError>) { let arrs = arrs.iter().map(|a| a.as_ref().unwrap().clone()).collect(); assert_eq!(expected, Array::stack(arrs, axis)) } #[rstest( arrs, expected, case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6])], array!(i32, [[1, 2, 3], [4, 5, 6]])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5, 6], [7, 8], [9, 10]])], array!(i32, [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])), case(vec![array!(i32, [[1]]), array!(i32, [[2]]), array!(i32, [[3]])], array!(i32, [[1], [2], [3]])), case(vec![array!(i32, [[1, 2, 3], [1, 2, 3]]), Array::empty()], Err(ArrayError::ConcatenateShapeMismatch)), case(vec![array!(i32, [[1, 2, 3], [1, 2, 3]]), array!(i32, [[5, 6]])], Err(ArrayError::ConcatenateShapeMismatch)), )] fn test_vstack(arrs: Vec, ArrayError>>, expected: Result, ArrayError>) { let arrs = arrs.iter().map(|a| a.as_ref().unwrap().clone()).collect(); assert_eq!(expected, Array::vstack(arrs)) } #[rstest( arrs, expected, case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6])], array!(i32, [1, 2, 3, 4, 5, 6])), case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6]), array!(i32, [7, 8])], array!(i32, [1, 2, 3, 4, 5, 6, 7, 8])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5, 6], [7, 8]])], array!(i32, [[1, 2, 5, 6], [3, 4, 7, 8]])), case(vec![array!(i32, [[1]]), array!(i32, [[2]]), array!(i32, [[3]])], array!(i32, [[1, 2, 3]])), case(vec![array!(i32, [[1, 2], [1, 2]]), array!(i32, [5, 6])], Err(ArrayError::ConcatenateShapeMismatch)), case(vec![array!(i32, [[1, 2], [1, 2]]), array!(i32, [[5], [6]])], Err(ArrayError::ConcatenateShapeMismatch)), )] fn test_hstack(arrs: Vec, ArrayError>>, expected: Result, ArrayError>) { let arrs = arrs.iter().map(|a| a.as_ref().unwrap().clone()).collect(); assert_eq!(expected, Array::hstack(arrs)) } #[rstest( arrs, expected, case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6])], array!(i32, [[[1, 4], [2, 5], [3, 6]]])), case(vec![array!(i32, [[1], [2], [3]]), array!(i32, [[4], [5], [6]])], array!(i32, [[[1, 4]], [[2, 5]], [[3, 6]]])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5, 6], [7, 8]])], array!(i32, [[[1, 5], [2, 6]], [[3, 7], [4, 8]]])), case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6]), array!(i32, [7, 8])], Err(ArrayError::ConcatenateShapeMismatch)), )] fn test_dstack(arrs: Vec, ArrayError>>, expected: Result, ArrayError>) { let arrs = arrs.iter().map(|a| a.as_ref().unwrap().clone()).collect(); assert_eq!(expected, Array::dstack(arrs)) } #[rstest( arrs, expected, case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6])], array!(i32, [[1, 4], [2, 5], [3, 6]])), case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6]), array!(i32, [7, 8, 9])], array!(i32, [[1, 4, 7], [2, 5, 8], [3, 6, 9]])), case(vec![array!(i32, [1, 2, 3]), array!(i32, [[4], [5], [6]])], array!(i32, [[1, 4], [2, 5], [3, 6]])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [5, 6])], array!(i32, [[1, 2, 5], [3, 4, 6]])), case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5])], Err(ArrayError::ParameterError { param: "arrs", message: "all input arrays must have the same first dimension" })), )] fn test_column_stack(arrs: Vec, ArrayError>>, expected: Result, ArrayError>) { let arrs = arrs.iter().map(|a| a.as_ref().unwrap().clone()).collect(); assert_eq!(expected, Array::column_stack(arrs)) } #[rstest( arrs, expected, case(vec![array!(i32, [1, 2, 3]), array!(i32, [4, 5, 6])], array!(i32, [[1, 2, 3], [4, 5, 6]])), case(vec![array!(i32, [[1, 2], [3, 4]]), array!(i32, [[5, 6], [7, 8], [9, 10]])], array!(i32, [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])), case(vec![array!(i32, [[1]]), array!(i32, [[2]]), array!(i32, [[3]])], array!(i32, [[1], [2], [3]])), case(vec![array!(i32, [[1, 2, 3], [1, 2, 3]]), Array::empty()], Err(ArrayError::ConcatenateShapeMismatch)), case(vec![array!(i32, [[1, 2, 3], [1, 2, 3]]), array!(i32, [[5, 6]])], Err(ArrayError::ConcatenateShapeMismatch)), )] fn test_row_stack(arrs: Vec, ArrayError>>, expected: Result, ArrayError>) { let arrs = arrs.iter().map(|a| a.as_ref().unwrap().clone()).collect(); assert_eq!(expected, Array::row_stack(arrs)) }