// Copyright 2021 Daniel Philip Watson // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #[cfg(target_arch = "x86")] use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; use std::convert::TryFrom; use simdprune::*; fn slow_prune(input: &[T], mask: i32) -> Vec { let mut idx = 0; let mut expected = input.to_vec(); expected.retain(|_| { let flag = (mask as u32 >> idx) & 1 == 0; idx += 1; flag }); expected } fn test_128(func: unsafe fn(__m128i, i32) -> __m128i, length: usize) where T: Default + Copy + std::cmp::PartialEq + std::fmt::Debug + TryFrom, >::Error: std::fmt::Debug, { let mut buf = vec![T::default(); length]; let input: Vec<_> = (0..length).map(|x| T::try_from(x).unwrap()).collect(); for mask in 0..1 << length { unsafe { let input_vec = _mm_loadu_si128(input.as_ptr().cast()); _mm_storeu_si128(buf.as_mut_ptr().cast(), func(input_vec, mask)); let expected = slow_prune(&input, mask); let result = &buf[..mask.count_zeros() as usize - (32 - length)]; assert_eq!(expected, result, "\n mask: {:#0w$b}", mask, w = length + 2); } } } fn test_256(func: unsafe fn(__m256i, i32) -> __m256i, length: usize) where T: Default + Copy + std::cmp::PartialEq + std::fmt::Debug + TryFrom, >::Error: std::fmt::Debug, { let mut buf = vec![T::default(); length]; let input: Vec<_> = (0..length).map(|x| T::try_from(x).unwrap()).collect(); for mask in 0..1 << length { unsafe { let input_vec = _mm256_loadu_si256(input.as_ptr().cast()); _mm256_storeu_si256(buf.as_mut_ptr().cast(), func(input_vec, mask)); let expected = slow_prune(&input, mask); let result = &buf[..mask.count_zeros() as usize - (32 - length)]; assert_eq!(expected, result, "\n mask: {:#0w$b}", mask, w = length + 2); } } } #[test] #[cfg(feature = "large_tables")] #[ignore] // expensive fn test_8() { test_128::(prune_epi8, 16); } #[test] #[ignore] // expensive fn test_thin_8() { test_128::(thinprune_epi8, 16); } #[test] #[ignore] // expensive fn test_skinny_8() { test_128::(skinnyprune_epi8, 16); } #[test] fn test_16() { test_128::(prune_epi16, 8); } #[test] fn test_32() { test_128::(prune_epi32, 4); } #[test] fn test256_32() { test_256::(prune256_epi32, 8); }