use ndarray_stats::errors::{MultiInputError, ShapeMismatch}; use ndarray_stats::DeviationExt; use approx::assert_abs_diff_eq; use ndarray::{array, Array1}; use num_bigint::BigInt; use num_traits::Float; use std::f64; #[test] fn test_count_eq() -> Result<(), MultiInputError> { let a = array![0., 0.]; let b = array![1., 0.]; let c = array![0., 1.]; let d = array![1., 1.]; assert_eq!(a.count_eq(&a)?, 2); assert_eq!(a.count_eq(&b)?, 1); assert_eq!(a.count_eq(&c)?, 1); assert_eq!(a.count_eq(&d)?, 0); Ok(()) } #[test] fn test_count_neq() -> Result<(), MultiInputError> { let a = array![0., 0.]; let b = array![1., 0.]; let c = array![0., 1.]; let d = array![1., 1.]; assert_eq!(a.count_neq(&a)?, 0); assert_eq!(a.count_neq(&b)?, 1); assert_eq!(a.count_neq(&c)?, 1); assert_eq!(a.count_neq(&d)?, 2); Ok(()) } #[test] fn test_sq_l2_dist() -> Result<(), MultiInputError> { let a = array![0., 1., 4., 2.]; let b = array![1., 1., 2., 4.]; assert_eq!(a.sq_l2_dist(&b)?, 9.); Ok(()) } #[test] fn test_l2_dist() -> Result<(), MultiInputError> { let a = array![0., 1., 4., 2.]; let b = array![1., 1., 2., 4.]; assert_eq!(a.l2_dist(&b)?, 3.); Ok(()) } #[test] fn test_l1_dist() -> Result<(), MultiInputError> { let a = array![0., 1., 4., 2.]; let b = array![1., 1., 2., 4.]; assert_eq!(a.l1_dist(&b)?, 5.); Ok(()) } #[test] fn test_linf_dist() -> Result<(), MultiInputError> { let a = array![0., 0.]; let b = array![1., 0.]; let c = array![1., 2.]; assert_eq!(a.linf_dist(&a)?, 0.); assert_eq!(a.linf_dist(&b)?, 1.); assert_eq!(b.linf_dist(&a)?, 1.); assert_eq!(a.linf_dist(&c)?, 2.); assert_eq!(c.linf_dist(&a)?, 2.); Ok(()) } #[test] fn test_mean_abs_err() -> Result<(), MultiInputError> { let a = array![1., 1.]; let b = array![3., 5.]; assert_eq!(a.mean_abs_err(&a)?, 0.); assert_eq!(a.mean_abs_err(&b)?, 3.); assert_eq!(b.mean_abs_err(&a)?, 3.); Ok(()) } #[test] fn test_mean_sq_err() -> Result<(), MultiInputError> { let a = array![1., 1.]; let b = array![3., 5.]; assert_eq!(a.mean_sq_err(&a)?, 0.); assert_eq!(a.mean_sq_err(&b)?, 10.); assert_eq!(b.mean_sq_err(&a)?, 10.); Ok(()) } #[test] fn test_root_mean_sq_err() -> Result<(), MultiInputError> { let a = array![1., 1.]; let b = array![3., 5.]; assert_eq!(a.root_mean_sq_err(&a)?, 0.); assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 10.0.sqrt()); assert_abs_diff_eq!(b.root_mean_sq_err(&a)?, 10.0.sqrt()); Ok(()) } #[test] fn test_peak_signal_to_noise_ratio() -> Result<(), MultiInputError> { let a = array![1., 1.]; assert!(a.peak_signal_to_noise_ratio(&a, 1.)?.is_infinite()); let a = array![1., 2., 3., 4., 5., 6., 7.]; let b = array![1., 3., 3., 4., 6., 7., 8.]; let maxv = 8.; let expected = 20. * Float::log10(maxv) - 10. * Float::log10(a.mean_sq_err(&b)?); let actual = a.peak_signal_to_noise_ratio(&b, maxv)?; assert_abs_diff_eq!(actual, expected); Ok(()) } #[test] fn test_deviations_with_n_by_m_ints() -> Result<(), MultiInputError> { let a = array![[0, 1], [4, 2]]; let b = array![[1, 1], [2, 4]]; assert_eq!(a.count_eq(&a)?, 4); assert_eq!(a.count_neq(&a)?, 0); assert_eq!(a.sq_l2_dist(&b)?, 9); assert_eq!(a.l2_dist(&b)?, 3.); assert_eq!(a.l1_dist(&b)?, 5); assert_eq!(a.linf_dist(&b)?, 2); assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); assert_abs_diff_eq!(a.peak_signal_to_noise_ratio(&b, 4)?, 8.519374645445623); Ok(()) } #[test] fn test_deviations_with_empty_receiver() { let a: Array1 = array![]; let b: Array1 = array![1.]; assert_eq!(a.count_eq(&b), Err(MultiInputError::EmptyInput)); assert_eq!(a.count_neq(&b), Err(MultiInputError::EmptyInput)); assert_eq!(a.sq_l2_dist(&b), Err(MultiInputError::EmptyInput)); assert_eq!(a.l2_dist(&b), Err(MultiInputError::EmptyInput)); assert_eq!(a.l1_dist(&b), Err(MultiInputError::EmptyInput)); assert_eq!(a.linf_dist(&b), Err(MultiInputError::EmptyInput)); assert_eq!(a.mean_abs_err(&b), Err(MultiInputError::EmptyInput)); assert_eq!(a.mean_sq_err(&b), Err(MultiInputError::EmptyInput)); assert_eq!(a.root_mean_sq_err(&b), Err(MultiInputError::EmptyInput)); assert_eq!( a.peak_signal_to_noise_ratio(&b, 0.), Err(MultiInputError::EmptyInput) ); } #[test] fn test_deviations_do_not_panic_if_nans() -> Result<(), MultiInputError> { let a: Array1 = array![1., f64::NAN, 3., f64::NAN]; let b: Array1 = array![1., f64::NAN, 3., 4.]; assert_eq!(a.count_eq(&b)?, 2); assert_eq!(a.count_neq(&b)?, 2); assert!(a.sq_l2_dist(&b)?.is_nan()); assert!(a.l2_dist(&b)?.is_nan()); assert!(a.l1_dist(&b)?.is_nan()); assert_eq!(a.linf_dist(&b)?, 0.); assert!(a.mean_abs_err(&b)?.is_nan()); assert!(a.mean_sq_err(&b)?.is_nan()); assert!(a.root_mean_sq_err(&b)?.is_nan()); assert!(a.peak_signal_to_noise_ratio(&b, 0.)?.is_nan()); Ok(()) } #[test] fn test_deviations_with_empty_argument() { let a: Array1 = array![1.]; let b: Array1 = array![]; let shape_mismatch_err = MultiInputError::ShapeMismatch(ShapeMismatch { first_shape: a.shape().to_vec(), second_shape: b.shape().to_vec(), }); let expected_err_usize = Err(shape_mismatch_err.clone()); let expected_err_f64 = Err(shape_mismatch_err); assert_eq!(a.count_eq(&b), expected_err_usize); assert_eq!(a.count_neq(&b), expected_err_usize); assert_eq!(a.sq_l2_dist(&b), expected_err_f64); assert_eq!(a.l2_dist(&b), expected_err_f64); assert_eq!(a.l1_dist(&b), expected_err_f64); assert_eq!(a.linf_dist(&b), expected_err_f64); assert_eq!(a.mean_abs_err(&b), expected_err_f64); assert_eq!(a.mean_sq_err(&b), expected_err_f64); assert_eq!(a.root_mean_sq_err(&b), expected_err_f64); assert_eq!(a.peak_signal_to_noise_ratio(&b, 0.), expected_err_f64); } #[test] fn test_deviations_with_non_copyable() -> Result<(), MultiInputError> { let a: Array1 = array![0.into(), 1.into(), 4.into(), 2.into()]; let b: Array1 = array![1.into(), 1.into(), 2.into(), 4.into()]; assert_eq!(a.count_eq(&a)?, 4); assert_eq!(a.count_neq(&a)?, 0); assert_eq!(a.sq_l2_dist(&b)?, 9.into()); assert_eq!(a.l2_dist(&b)?, 3.); assert_eq!(a.l1_dist(&b)?, 5.into()); assert_eq!(a.linf_dist(&b)?, 2.into()); assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); assert_abs_diff_eq!( a.peak_signal_to_noise_ratio(&b, 4.into())?, 8.519374645445623 ); Ok(()) } #[test] fn test_deviation_computation_for_mixed_ownership() { // It's enough to check that the code compiles! let a = array![0., 0.]; let b = array![1., 0.]; let _ = a.count_eq(&b.view()); let _ = a.count_neq(&b.view()); let _ = a.l2_dist(&b.view()); let _ = a.sq_l2_dist(&b.view()); let _ = a.l1_dist(&b.view()); let _ = a.linf_dist(&b.view()); let _ = a.mean_abs_err(&b.view()); let _ = a.mean_sq_err(&b.view()); let _ = a.root_mean_sq_err(&b.view()); let _ = a.peak_signal_to_noise_ratio(&b.view(), 10.); }