//! Solve the least squares problem `|b - Ax|` with a multi-column right-hand side `b`.
//!
//! Each case is exercised for every combination of C (`ac`/`bc`) and Fortran
//! (`af`/`bf`) memory layouts of `a` and `b`, for `f32`, `f64`, `c32` and `c64`.

use approx::AbsDiffEq;
use ndarray::*;
use ndarray_linalg::*;

/// A is square. `x = A^{-1} b`, `|b - Ax| = 0`
///
/// Checks that `least_squares` reports full rank, a (numerically) zero residual
/// for every column of `b`, and a solution satisfying `Ax == b`.
fn test_exact<T: Scalar + Lapack>(a: Array2<T>, b: Array2<T>) {
    assert_eq!(a.layout().unwrap().size(), (3, 3));
    assert_eq!(b.layout().unwrap().size(), (3, 2));

    let result = a.least_squares(&b).unwrap();
    // unpack result
    let x: Array2<T> = result.solution;
    let residual_l2_square: Array1<T::Real> = result.residual_sum_of_squares.unwrap();

    // must be full-rank
    assert_eq!(result.rank, 3);

    // one solution column and one residual per column of `b`; without the
    // length check an empty residual array would make the loop below vacuous
    assert_eq!(x.dim(), (3, 2));
    assert_eq!(residual_l2_square.len(), 2);

    // |b - Ax| == 0
    for &residual in &residual_l2_square {
        assert!(residual < T::real(1.0e-4));
    }

    // b == Ax
    let ax = a.dot(&x);
    assert_close_l2!(&b, &ax, T::real(1.0e-4));
}

/// Generates the four layout variants of the exact (square) test for `$scalar`.
/// A fixed seed keeps the random inputs reproducible.
macro_rules! impl_exact {
    ($scalar:ty) => {
        paste::item! {
            #[test]
            fn [<least_squares_ $scalar _exact_ac_bc>]() {
                let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
                let a: Array2<$scalar> = random_using((3, 3), &mut rng);
                let b: Array2<$scalar> = random_using((3, 2), &mut rng);
                test_exact(a, b)
            }

            #[test]
            fn [<least_squares_ $scalar _exact_ac_bf>]() {
                let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
                let a: Array2<$scalar> = random_using((3, 3), &mut rng);
                let b: Array2<$scalar> = random_using((3, 2).f(), &mut rng);
                test_exact(a, b)
            }

            #[test]
            fn [<least_squares_ $scalar _exact_af_bc>]() {
                let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
                let a: Array2<$scalar> = random_using((3, 3).f(), &mut rng);
                let b: Array2<$scalar> = random_using((3, 2), &mut rng);
                test_exact(a, b)
            }

            #[test]
            fn [<least_squares_ $scalar _exact_af_bf>]() {
                let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5);
                let a: Array2<$scalar> = random_using((3, 3).f(), &mut rng);
                let b: Array2<$scalar> = random_using((3, 2).f(), &mut rng);
                test_exact(a, b)
            }
        }
    };
}

impl_exact!(f32);
impl_exact!(f64);
impl_exact!(c32);
impl_exact!(c64);

/// Tall `A` (#column < #row) case.
/// The linear problem is overdetermined, so in general `|b - Ax| > 0`.
fn test_overdetermined(a: Array2, bs: Array2) where T::Real: AbsDiffEq, { assert_eq!(a.layout().unwrap().size(), (4, 3)); assert_eq!(bs.layout().unwrap().size(), (4, 2)); let result = a.least_squares(&bs).unwrap(); // unpack result let xs = result.solution; let residual_l2_square = result.residual_sum_of_squares.unwrap(); // Must be full-rank assert_eq!(result.rank, 3); for j in 0..2 { let b = bs.index_axis(Axis(1), j); let x = xs.index_axis(Axis(1), j); let residual = &b - &a.dot(&x); let residual_l2_sq = residual_l2_square[j]; assert!(residual_l2_sq.abs_diff_eq(&residual.norm_l2().powi(2), T::real(1.0e-4))); // `|residual| < |b|` assert!(residual.norm_l2() < b.norm_l2()); } } macro_rules! impl_overdetermined { ($scalar:ty) => { paste::item! { #[test] fn []() { let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); let a: Array2<$scalar> = random_using((4, 3), &mut rng); let b: Array2<$scalar> = random_using((4, 2), &mut rng); test_overdetermined(a, b) } #[test] fn []() { let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); let a: Array2<$scalar> = random_using((4, 3).f(), &mut rng); let b: Array2<$scalar> = random_using((4, 2), &mut rng); test_overdetermined(a, b) } #[test] fn []() { let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); let a: Array2<$scalar> = random_using((4, 3), &mut rng); let b: Array2<$scalar> = random_using((4, 2).f(), &mut rng); test_overdetermined(a, b) } #[test] fn []() { let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); let a: Array2<$scalar> = random_using((4, 3).f(), &mut rng); let b: Array2<$scalar> = random_using((4, 2).f(), &mut rng); test_overdetermined(a, b) } } }; } impl_overdetermined!(f32); impl_overdetermined!(f64); impl_overdetermined!(c32); impl_overdetermined!(c64); /// #column > #row case. 
/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique fn test_underdetermined(a: Array2, b: Array2) { assert_eq!(a.layout().unwrap().size(), (3, 4)); assert_eq!(b.layout().unwrap().size(), (3, 2)); let result = a.least_squares(&b).unwrap(); assert_eq!(result.rank, 3); assert!(result.residual_sum_of_squares.is_none()); // b == Ax let x = result.solution; let ax = a.dot(&x); assert_close_l2!(&b, &ax, T::real(1.0e-4)); } macro_rules! impl_underdetermined { ($scalar:ty) => { paste::item! { #[test] fn []() { let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); let a: Array2<$scalar> = random_using((3, 4), &mut rng); let b: Array2<$scalar> = random_using((3, 2), &mut rng); test_underdetermined(a, b) } #[test] fn []() { let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); let a: Array2<$scalar> = random_using((3, 4).f(), &mut rng); let b: Array2<$scalar> = random_using((3, 2), &mut rng); test_underdetermined(a, b) } #[test] fn []() { let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); let a: Array2<$scalar> = random_using((3, 4), &mut rng); let b: Array2<$scalar> = random_using((3, 2).f(), &mut rng); test_underdetermined(a, b) } #[test] fn []() { let mut rng = rand_pcg::Mcg128Xsl64::new(0xcafef00dd15ea5e5); let a: Array2<$scalar> = random_using((3, 4).f(), &mut rng); let b: Array2<$scalar> = random_using((3, 2).f(), &mut rng); test_underdetermined(a, b) } } }; } impl_underdetermined!(f32); impl_underdetermined!(f64); impl_underdetermined!(c32); impl_underdetermined!(c64);