use approx::assert_abs_diff_eq; use ndarray::{arr1, arr2, aview1, stack, Array2, ArrayBase, ArrayView1, Axis, Dim, ViewRepr}; use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt}; use ndarray_stats::DeviationExt; use noisy_float::{checkers::FiniteChecker, NoisyFloat}; use rand_xoshiro::Xoshiro256Plus; use linfa_nn::{distance::*, CommonNearestNeighbour, LinearSearch, NearestNeighbour}; fn sort_by_dist<'a>( mut vec: Vec<(ArrayView1<'a, f64>, usize)>, pt: ArrayView1, ) -> Vec<(ArrayView1<'a, f64>, usize)> { vec.sort_by_key(|v| NoisyFloat::<_, FiniteChecker>::new(v.0.sq_l2_dist(&pt).unwrap())); vec } fn assert_query( output: Vec<(ArrayView1, usize)>, input_data: &Array2, exp_pos: Vec, ) { let (pts, pos): (Vec<_>, Vec<_>) = output.into_iter().unzip(); assert_eq!(pos, exp_pos); assert_abs_diff_eq!( stack(Axis(0), &pts).unwrap(), input_data.select(Axis(0), &exp_pos) ); } fn nn_test_empty(builder: &CommonNearestNeighbour) { let points = Array2::zeros((0, 2)); let nn = builder.from_batch(&points, L2Dist).unwrap(); let out = nn.k_nearest(aview1(&[0.0, 1.0]), 2).unwrap(); assert_eq!(out, Vec::<_>::new()); let out = nn.k_nearest(aview1(&[4.0, 4.0]), 3).unwrap(); assert_eq!(out, Vec::<_>::new()); let pt = aview1(&[6.0, 3.0]); let out = nn.within_range(pt, 9.0).unwrap(); assert_eq!(out, Vec::<_>::new()); } fn nn_test_error(builder: &CommonNearestNeighbour) { let points = Array2::::zeros((4, 0)); assert!(builder.from_batch(&points, L2Dist).is_err()); let points = arr2(&[[0.0, 2.0]]); assert!(builder .from_batch_with_leaf_size(&points, 0, L2Dist) .is_err()); let nn = builder.from_batch(&points, L2Dist).unwrap(); assert!(nn.k_nearest(aview1(&[]), 2).is_err()); assert!(nn.within_range(aview1(&[2.2, 4.4, 5.5]), 4.0).is_err()); } fn nn_test(builder: &CommonNearestNeighbour, sort_within_range: bool) { let points = arr2(&[[0.0, 2.0], [10.0, 4.0], [4.0, 5.0], [7.0, 1.0], [1.0, 7.2]]); let nn = builder.from_batch(&points, L2Dist).unwrap(); let out = nn.k_nearest(aview1(&[0.0, 1.0]), 2).unwrap(); assert_query(out, &points, vec![0, 2]); let out = nn.k_nearest(aview1(&[4.0, 4.0]), 3).unwrap(); assert_query(out, &points, vec![2, 3, 4]); let out = nn.k_nearest(aview1(&[4.0, 4.0]), 10).unwrap(); assert_query(out, &points, vec![2, 3, 4, 0, 1]); let pt = aview1(&[6.0, 3.0]); let mut out = nn.within_range(pt, 4.3).unwrap(); if sort_within_range { out = sort_by_dist(out, pt); } assert_query(out, &points, vec![3, 2, 1]); } fn nn_test_degenerate(builder: &CommonNearestNeighbour) { let points = arr2(&[[0.0, 2.0], [0.0, 2.0], [0.0, 2.0], [0.0, 2.0], [0.0, 2.0]]); let nn = builder.from_batch(&points, L2Dist).unwrap(); let out = nn .k_nearest(aview1(&[0.0, 1.0]), 2) .unwrap() .into_iter() .map(|(p, _)| p.reborrow()) .collect::>(); assert_abs_diff_eq!( stack(Axis(0), &out).unwrap(), arr2(&[[0.0, 2.0], [0.0, 2.0]]) ); let out = nn .k_nearest(aview1(&[4.0, 4.0]), 3) .unwrap() .into_iter() .map(|(p, _)| p.reborrow()) .collect::>(); assert_abs_diff_eq!( stack(Axis(0), &out).unwrap(), arr2(&[[0.0, 2.0], [0.0, 2.0], [0.0, 2.0]]) ); let pt = aview1(&[3.0, 2.0]); let out = nn .within_range(pt, 1.0) .unwrap() .into_iter() .map(|(p, _)| p.reborrow()) .collect::>(); assert_eq!( out, Vec::, Dim<[usize; 1]>>>::new() ); let out = nn .within_range(pt, 20.0) .unwrap() .into_iter() .map(|(p, _)| p.reborrow()) .collect::>(); assert_abs_diff_eq!(stack(Axis(0), &out).unwrap(), points); } fn assert_eq_queries(out1: Vec<(ArrayView1, usize)>, out2: Vec<(ArrayView1, usize)>) { let (pts1, pos1): (Vec<_>, Vec<_>) = out1.into_iter().unzip(); let (pts2, pos2): (Vec<_>, Vec<_>) = out2.into_iter().unzip(); assert_eq!(pos1, pos2); assert_abs_diff_eq!( stack(Axis(0), &pts1).unwrap(), stack(Axis(0), &pts2).unwrap(), ); } fn nn_test_random + Clone>( builder: &CommonNearestNeighbour, dist_fn: D, ) { let mut rng = Xoshiro256Plus::seed_from_u64(40); let n_points = 50000; let n_features = 3; let points = Array2::random_using((n_points, n_features), Uniform::new(-50., 50.), &mut rng); let linear = LinearSearch::new() .from_batch(&points, dist_fn.clone()) .unwrap(); let nn = builder.from_batch(&points, dist_fn).unwrap(); let pt = arr1(&[0., 0., 0.]); assert_eq_queries( nn.k_nearest(pt.view(), 5).unwrap(), linear.k_nearest(pt.view(), 5).unwrap(), ); assert_eq_queries( sort_by_dist(nn.within_range(pt.view(), 15.0).unwrap(), pt.view()), sort_by_dist(linear.within_range(pt.view(), 15.0).unwrap(), pt.view()), ); let pt = arr1(&[-3.4, 10., 0.95]); assert_eq_queries( nn.k_nearest(pt.view(), 30).unwrap(), linear.k_nearest(pt.view(), 30).unwrap(), ); assert_eq_queries( sort_by_dist(nn.within_range(pt.view(), 25.0).unwrap(), pt.view()), sort_by_dist(linear.within_range(pt.view(), 25.0).unwrap(), pt.view()), ); } macro_rules! nn_tests { ($mod:ident, $builder:ident, $sort:expr $(, $_u:ident)?) => { mod $mod { use super::*; #[test] fn empty() { nn_test_empty(&CommonNearestNeighbour::$builder); } #[test] fn error() { nn_test_error(&CommonNearestNeighbour::$builder); } #[test] fn normal() { nn_test(&CommonNearestNeighbour::$builder, $sort); } #[test] fn degenerate() { nn_test_degenerate(&CommonNearestNeighbour::$builder); } $( #[test] fn random_l2() { let $_u: () = (); nn_test_random(&CommonNearestNeighbour::$builder, L2Dist); } #[test] fn random_l1() { nn_test_random(&CommonNearestNeighbour::$builder, L1Dist); } )? } }; } nn_tests!(linear_search, LinearSearch, true); nn_tests!(kdtree, KdTree, false, _u); nn_tests!(balltree, BallTree, false, _u);