#![feature(try_from)] #[macro_use] extern crate tensor_macros; use tensor_macros::traits::*; use std::convert::TryFrom; tensor!(T2345: 2 x 3 x 4 x 5); #[test] fn tensor_dims() { assert_eq!(T2345::::SIZE, 2 * 3 * 4 * 5); assert_eq!(T2345::::NDIM, 4); } tensor!(M23: 2 x 3); #[test] fn matrix_dims() { assert_eq!(M23::::ROWS, 2); assert_eq!(M23::::COLS, 3); } tensor!(V4: 4); #[test] fn col_vector_size() { assert_eq!(V4::::COLS, 4); } tensor!(V2Row: row 2); #[test] fn row_vector_size() { assert_eq!(V2Row::::ROWS, 2); } tensor!(T324: 3 x 2 x 4); #[test] fn dims() { assert_eq!(T324::::dims(), vec!(3, 2, 4)); let t324: T324 = Default::default(); assert_eq!(t324.get_dims(), vec!(3, 2, 4)); } #[test] fn try_from_vec() { let t324 = T324::::try_from(vec![ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, ]); let exp = T324([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, ]); assert_eq!(t324, Ok(exp)); } #[test] fn index() { let t324 = T324::::try_from(vec![ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, ]) .unwrap(); assert_eq!(t324[(0, 0, 0)], 0); assert_eq!(t324[(1, 1, 1)], 13); assert_eq!(t324[(2, 1, 3)], 23); assert_eq!(t324[15], 15); } tensor!(T243: 2 x 4 x 3); tensor!(M43: 4 x 3 x 1); tensor!(V2: 2 x 1); dot!(T243: 2 x 4 x 3 * M43: 4 x 3 x 1 => V2: 2 x 1); #[test] fn dot_test() { let l = T243::([ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, ]); let r = M43([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]); assert_eq!(l * r, V2([506.0, 1298.0])); } #[test] fn debug() { let t = T2345::try_from((0u8..120).collect::>()).unwrap(); let output = "0\t1\t2\t3\t4\t 5\t6\t7\t8\t9\t 10\t11\t12\t13\t14\t 15\t16\t17\t18\t19\t 20\t21\t22\t23\t24\t 25\t26\t27\t28\t29\t 30\t31\t32\t33\t34\t 35\t36\t37\t38\t39\t 40\t41\t42\t43\t44\t 45\t46\t47\t48\t49\t 50\t51\t52\t53\t54\t 55\t56\t57\t58\t59\t 60\t61\t62\t63\t64\t 65\t66\t67\t68\t69\t 70\t71\t72\t73\t74\t 75\t76\t77\t78\t79\t 80\t81\t82\t83\t84\t 85\t86\t87\t88\t89\t 90\t91\t92\t93\t94\t 95\t96\t97\t98\t99\t 100\t101\t102\t103\t104\t 105\t106\t107\t108\t109\t 110\t111\t112\t113\t114\t 115\t116\t117\t118\t119\t "; assert_eq!(format!("{:?}", t), output); } #[test] fn cwise() { let l = T243::try_from((0u64..24).collect::>()).unwrap(); let r = T243::from(5); assert_eq!( l.cwise_mul(l.cwise_mul(r)), T243::try_from((0u64..24).map(|x| x * x * 5).collect::>()).unwrap() ); } transpose!(T243: 2 x 4 x 3 => T243T); #[test] fn index_transpose() { let t = T243::try_from((0u8..24).collect::>()).unwrap(); assert_eq!(t[(1, 2, 2)], t.transpose()[(2, 2, 1)]); } #[test] fn debug_transpose() { let t = T243::try_from((0u8..24).collect::>()).unwrap(); let output = "0\t1\t2\t 3\t4\t5\t 6\t7\t8\t 9\t10\t11\t 12\t13\t14\t 15\t16\t17\t 18\t19\t20\t 21\t22\t23\t "; assert_eq!(format!("{:?}", t), output); let u = t.transpose(); let output = "0\t12\t 3\t15\t 6\t18\t 9\t21\t 1\t13\t 4\t16\t 7\t19\t 10\t22\t 2\t14\t 5\t17\t 8\t20\t 11\t23\t "; assert_eq!(format!("{:?}", u), output); } transpose!(M23: 2 x 3 => M32); tensor!(M33: 3 x 3); dot!(M32: 3 x 2 * M23: 2 x 3 => M33: 3 x 3); #[test] fn transpose_dot() { let t = M23::try_from((0u8..6).collect::>()).unwrap(); let u = t.transpose(); let v = M33([9, 12, 15, 12, 17, 22, 15, 22, 29]); assert_eq!(u * t, v); }