use zyx::{Tensor, ZyxError}; #[test] fn matmul() -> Result<(), ZyxError> { let x = Tensor::from([[2, 4, 3], [1, 5, 1]]); let y = Tensor::from([[2, 4], [3, 1], [5, 1]]); let z = x.dot(y)?; assert_eq!(z, [[31, 15], [22, 10]]); Ok(()) } #[test] fn pad_reduce() -> Result<(), ZyxError> { let mut x = Tensor::from([[2, 4, 3], [1, 5, 1]]); x = x.sum(1)?; x = x.pad_zeros([(0, 1)])?; assert_eq!(x, [9, 7, 0]); Ok(()) } #[test] fn permute_pad() -> Result<(), ZyxError> { let mut x = Tensor::from([[2, 4, 3], [1, 5, 1]]); x = x.pad_zeros([(1, 0)])?.t(); assert_eq!(x, [[0, 0], [2, 1], [4, 5], [3, 1]]); Ok(()) } #[test] fn expand_reduce() -> Result<(), ZyxError> { let mut x = Tensor::from([[2, 4, 3], [1, 5, 1]]); x = x.sum(1)?; let y = x.expand([2, 2])?; x = x.reshape([2, 1])?.expand([2, 2])?; Tensor::realize([&x, &y])?; assert_eq!(y, [[9, 7], [9, 7]]); assert_eq!(x, [[9, 9], [7, 7]]); Ok(()) } #[test] fn pad_reshape_expand() -> Result<(), ZyxError> { let mut x = Tensor::from([[2, 4, 3, 3, 4], [1, 2, 1, 5, 1]]); x = x.pad_zeros([(1, 0), (2, 1)])?; x = x.reshape([2, 1, 3, 5])?; x = x.expand([2, 2, 3, 5])?; assert_eq!( x, [ [ [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 2, 4]], [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 2, 4]] ], [ [[3, 3, 4, 0, 1], [2, 1, 5, 1, 0], [0, 0, 0, 0, 0]], [[3, 3, 4, 0, 1], [2, 1, 5, 1, 0], [0, 0, 0, 0, 0]] ] ] ); Ok(()) } #[test] fn pool() -> Result<(), ZyxError> { let mut x = Tensor::from((0..9).collect::>()).reshape((3, 3))?; //x = x.repeat([2, 2]); //println!("{x}"); //x = x.reshape([12, 3]); //println!("{x}"); x = x.pool([2, 2], 1, 1)?; assert_eq!( x, [ [[[0, 1], [3, 4]], [[1, 2], [4, 5]]], [[[3, 4], [6, 7]], [[4, 5], [7, 8]]] ] ); //println!("{x}"); Ok(()) } #[test] fn cumsum() -> Result<(), ZyxError> { let mut x = Tensor::from((0..9).collect::>()).reshape((3, 3))?; x = x.cumsum(1)?; assert_eq!(x, [[0, 1, 3], [3, 7, 12], [6, 13, 21]]); Ok(()) } #[test] fn arange() -> Result<(), ZyxError> { let x = Tensor::arange(0, 10, 2)?; //println!("{x}"); assert_eq!(x, [0, 2, 4, 6, 8]); Ok(()) } /*#[test] fn rand() { use zyx::DType; let x = Tensor::randn([10, 10], DType::F32).unwrap(); //Tensor::plot_graph([], "graph0"); //Tensor::realize([&x]).unwrap(); println!("{x}"); }*/ #[test] fn const_() -> Result<(), ZyxError> { let x = Tensor::from([[3f32, 4., 2.], [4., 3., 2.]]); let mut y = Tensor::constant(1) + x; //.get(1); println!("{y}'"); //Tensor::plot_graph([], "graph0"); //let c: Tensor = Tensor::constant(1f64 / std::f64::consts::E.log2()); //y = y.log2() * c.cast(y.dtype()); y = y.ln(); println!("{y}'"); Ok(()) } #[test] fn graph_shapes() -> Result<(), ZyxError> { let x = Tensor::constant(2); let y = x.expand([1, 1])?; println!("{y}"); Ok(()) } #[test] fn uni_matmul() -> Result<(), ZyxError> { //use zyx::DType; //let x = Tensor::rand([5, 5], DType::F32) * 2f32 + 3f32; //let y = Tensor::rand([5, 5], DType::F32) * 3f32 + 4f32; let x = Tensor::uniform([5, 5], -1f32..2f32)?; let y = Tensor::uniform([5, 5], -1f32..5f32)?; let z = x.dot(y)?; println!("{z}"); Ok(()) } #[test] fn cat() -> Result<(), ZyxError> { let a = Tensor::from([[1, 2], [3, 4]]); let b = Tensor::from([[5, 6], [7, 8]]); let c = Tensor::cat([&a, &b], 0)?; assert_eq!(c, [[1, 2], [3, 4], [5, 6], [7, 8]]); let c = Tensor::cat([&a, &b], 1)?; assert_eq!(c, [[1, 2, 5, 6], [3, 4, 7, 8]]); Ok(()) } #[test] fn matmul_1024() -> Result<(), ZyxError> { //let mut xy: Vec = Tensor::load("xy.safetensors").unwrap(); //let y = xy.pop().unwrap(); //let x = xy.pop().unwrap(); let mut xyz: Vec = Tensor::load("xyz.safetensors")?; let z = xyz.pop().unwrap(); let y = xyz.pop().unwrap(); let x = xyz.pop().unwrap(); println!("{:?}", x.shape()); println!("{:?}", y.shape()); let dataz: Vec = z.try_into()?; let zz = x.matmul(y)?; let datazz: Vec = zz.try_into()?; for (x, y) in dataz.iter().zip(datazz) { //println!("{x}, {y}"); assert!((x - y).abs() < 0.01); } //println!("{z}"); Ok(()) } /*#[test] fn softmax() { let x = Tensor::from([2f32, 4., 3.]); //let y = x.softmax([]); let y = x.max_kd([]); //println!("{y:?}"); let e = (x - y).exp(); //println!("{e:?}"); //panic!(); let y = &e / e.sum_kd([]); //Tensor::plot_graph([], "graph").unwrap(); println!("{y:?}"); }*/