/*
    Appellation: tensor
    Contrib: FL03
*/
//! Integration tests for the core `Tensor` API: construction helpers,
//! raw-parts construction, scalar constructors, element access, indexing,
//! higher-rank layouts and reductions.
#![cfg(test)]
extern crate acme_tensor as acme;

use acme::prelude::{IntoShape, Tensor};

/// `zeros_like` must yield a distinct tensor (different id) sharing the
/// source's shape, size and strides.
#[test]
fn test_tensor() {
    let shape = (2, 2);
    let a = Tensor::<f64>::ones(shape);
    let b = a.zeros_like();

    assert_ne!(a.id(), b.id());
    assert_eq!(a.shape(), b.shape());
    assert_eq!(a.size(), b.size());
    assert_eq!(a.strides(), b.strides());
}

/// A tensor built from a raw pointer + shape + default strides should equal
/// the same data built through `from_shape_vec`.
#[test]
#[ignore = "buffer ownership semantics of `from_raw_parts` are unverified (possible double free)"]
fn test_raw_tensor() {
    let shape = (2, 2).into_shape();
    let stride = shape.default_strides();
    let mut data = vec![1f64, 2f64, 3f64, 4f64];

    let exp = Tensor::<f64>::from_shape_vec(shape.clone(), data.clone());

    // SAFETY: `ptr` points into `data`, which is neither moved nor reallocated
    // while the block runs, and holds exactly `shape.size()` (= 4) elements.
    // NOTE(review): if `from_raw_parts` takes ownership of the buffer, both `a`
    // and `data` would free it on drop — TODO confirm before un-ignoring.
    unsafe {
        let ptr = data.as_mut_ptr();
        let a = Tensor::<f64>::from_raw_parts(ptr, shape, stride);
        assert_eq!(a, exp);
    }
}

/// The `num::One` / `num::Zero` constructors produce scalar tensors.
#[test]
fn test_scalar_tensor() {
    use num::{One, Zero};

    let one = Tensor::<f64>::one();
    let zero = Tensor::<f64>::zero();

    assert!(one.is_scalar());
    assert!(zero.is_scalar());
}

/// `first` / `last` return the first and last stored elements, independent of
/// trailing unit dimensions in the shape.
#[test]
fn test_first_and_last() {
    // With 9 samples starting at 0 this `linspace` yields 0.0..=8.0 (the last
    // element asserted below is 8.0, not 9.0).
    let shape = (3, 3);
    let tensor = Tensor::linspace(0f64, 9f64, 9).reshape(shape).unwrap();

    assert_eq!(tensor.first(), Some(&0f64));
    assert_eq!(tensor.last(), Some(&8f64));

    // A trailing axis of length 1 must not change which elements are first/last.
    let shape = (3, 3, 1);
    let tensor = Tensor::linspace(0f64, 9f64, 9).reshape(shape).unwrap();

    assert_eq!(tensor.first(), Some(&0f64));
    assert_eq!(tensor.last(), Some(&8f64));
}

/// Indexing accepts arrays, array references and the shape's final position.
#[test]
fn test_index() {
    let shape = (2, 3).into_shape();
    let n = shape.size();
    let a = Tensor::<f64>::linspace(0f64, n as f64, n)
        .reshape(shape.clone())
        .unwrap();

    assert_eq!(a[[0, 0]], 0f64);
    assert_eq!(a[&[0, 1]], 1f64);
    assert_eq!(a[shape.get_final_position()], 5f64);
}

/// Same layout invariants as `test_tensor`, but for a rank-4 tensor.
#[test]
fn test_higher_dim() {
    let shape = (2, 2, 2, 2);
    let a = Tensor::<f64>::ones(shape);
    let b = a.zeros_like();

    assert_ne!(a.id(), b.id());
    assert_eq!(a.shape(), b.shape());
    assert_eq!(a.size(), b.size());
    assert_eq!(a.strides(), b.strides());
    // One stride per axis.
    assert_eq!(a.strides().len(), 4);
}

/// Summing a 2x2 tensor of ones gives 4.
#[test]
fn test_sum() {
    let shape = (2, 2).into_shape();
    let a = Tensor::fill(shape, 1f64);

    assert_eq!(a.sum(), 4.0);
}

/// The product of a 2x2 tensor of twos is 2^4 = 16.
#[test]
fn test_product() {
    let shape = (2, 2).into_shape();
    let a = Tensor::fill(shape, 2f64);

    assert_eq!(a.product(), 16.0);
}