#include #if THRUST_CPP_DIALECT >= 2014 #include #include #include #include #include #include #include #include #include // This test is an adaptation of TestInclusiveScanWithBigIndices from scan.cu. namespace { // Fake iterator that asserts // (a) it is written with a sequence and // (b) a defined maximum value is written at some point // // This allows us to test very large problem sizes without actually allocating // large amounts of memory that would exceed most devices' capacity. struct assert_sequence_iterator { using value_type = std::int64_t; using difference_type = std::int64_t; // Defined for thrust::iterator_traits: using pointer = value_type *; using reference = assert_sequence_iterator; // weird but convenient using iterator_category = typename thrust::detail::iterator_facade_category< thrust::device_system_tag, thrust::random_access_traversal_tag, value_type, reference>::type; std::int64_t expected{0}; std::int64_t max{0}; mutable thrust::device_ptr found_max{nullptr}; mutable thrust::device_ptr unexpected_value{nullptr}; // Should be called on the first iterator generated. This needs to be done // explicitly from the host. void initialize_shared_state() { found_max = thrust::device_malloc(1); unexpected_value = thrust::device_malloc(1); *found_max = false; *unexpected_value = false; } // Should be called only once on the initialized iterator. This needs to be // done explicitly from the host. void free_shared_state() const { thrust::device_free(found_max); thrust::device_free(unexpected_value); found_max = nullptr; unexpected_value = nullptr; } __host__ __device__ assert_sequence_iterator operator+(difference_type i) const { return clone(expected + i); } __host__ __device__ reference operator[](difference_type i) const { return clone(expected + i); } // Some weirdness, this iterator acts like its own reference __device__ assert_sequence_iterator operator=(value_type val) { if (val != expected) { printf("Error: expected %lld, got %lld\n", expected, val); *unexpected_value = true; } else if (val == max) { *found_max = true; } return *this; } private: __host__ __device__ assert_sequence_iterator clone(value_type new_expected) const { return {new_expected, max, found_max, unexpected_value}; } }; // output mixin that generates assert_sequence_iterators. // Must be paired with validate_assert_sequence_iterators mixin to free // shared state. struct assert_sequence_output { struct output_type { using iterator = assert_sequence_iterator; iterator iter; explicit output_type(iterator &&it) : iter{std::move(it)} { iter.initialize_shared_state(); } ~output_type() { iter.free_shared_state(); } iterator begin() { return iter; } }; template static output_type generate_output(std::size_t num_values, InputType &) { using value_type = typename assert_sequence_iterator::value_type; assert_sequence_iterator it{1, static_cast(num_values), nullptr, nullptr}; return output_type{std::move(it)}; } }; struct validate_assert_sequence_iterators { using output_t = assert_sequence_output::output_type; template static void compare_outputs(EventType &e, output_t const &, output_t const &test) { testing::async::mixin::compare_outputs::detail::basic_event_validation(e); ASSERT_EQUAL(*test.iter.unexpected_value, false); ASSERT_EQUAL(*test.iter.found_max, true); } }; //------------------------------------------------------------------------------ // Overloads without custom binary operators use thrust::plus<>, so use // constant input iterator to generate the output sequence: struct default_bin_op_overloads { using postfix_args_type = std::tuple< // List any extra arg overloads: std::tuple<> // - no extra args >; static postfix_args_type generate_postfix_args() { return std::tuple>{}; } }; struct default_bin_op_invoker : testing::async::mixin::input::constant_iterator_1 , assert_sequence_output , default_bin_op_overloads , testing::async::mixin::invoke_reference::noop , testing::async::inclusive_scan::mixin::invoke_async::simple , validate_assert_sequence_iterators { static std::string description() { return "test large array indices with default binary operator"; } }; } // end anon namespace void test_large_indices_default_scan_op() { // Test problem sizes around signed/unsigned int max: testing::async::test_policy_overloads::run(1ll << 30); testing::async::test_policy_overloads::run(1ll << 31); testing::async::test_policy_overloads::run(1ll << 32); testing::async::test_policy_overloads::run(1ll << 33); } DECLARE_UNITTEST(test_large_indices_default_scan_op); namespace { //------------------------------------------------------------------------------ // Generate the output sequence using counting iterators and thrust::max<> for // custom operator overloads. struct custom_bin_op_overloads { using postfix_args_type = std::tuple< // List any extra arg overloads: std::tuple> // - custom binary op >; static postfix_args_type generate_postfix_args() { return postfix_args_type{std::make_tuple(thrust::maximum<>{})}; } }; struct custom_bin_op_invoker : testing::async::mixin::input::counting_iterator_from_1 , assert_sequence_output , custom_bin_op_overloads , testing::async::mixin::invoke_reference::noop , testing::async::inclusive_scan::mixin::invoke_async::simple , validate_assert_sequence_iterators { static std::string description() { return "test large array indices with custom binary operator"; } }; } // end anon namespace void test_large_indices_custom_scan_op() { // Test problem sizes around signed/unsigned int max: testing::async::test_policy_overloads::run(1ll << 30); testing::async::test_policy_overloads::run(1ll << 31); testing::async::test_policy_overloads::run(1ll << 32); testing::async::test_policy_overloads::run(1ll << 33); } DECLARE_UNITTEST(test_large_indices_custom_scan_op); #endif // C++14