#pragma once #include #if THRUST_CPP_DIALECT >= 2014 #include #include #include #include #include #include #include #include #include #include #include // clang-format off // This file contains a set of mix-in classes that define an algorithm // definition for use with test_policy_overloads. The algorithm // definition describes the details of a thrust::async algorithm invocation: // // - Input type and initialization // - Output type and initialization (supports in-place, too) // - Postfix arguments that define the algorithm's overload set // - Abstracted invocation of the async algorithm // - Abstracted invocation of a reference algorithm // - Validation of async vs. reference output // - A description string. // // This definition is used by test_policy_overloads to test each overload // against a reference while injecting a variety of execution policies. This // validates that each overload behaves correctly according to some reference. // // Since much of the algorithm definition is generic and may be reused in // multiple tests with slight changes, a mix-in system is used to simplify // the creation of algorithm definitions. The following namespace hierarchy is // used to organize these generic components: // // * testing::async::mixin:: // ** ::input - Input types/values (device vectors, counting iterators, etc) // ** ::output - Output types/values (device vectors, inplace device vectors, // discard iterators, etc) // ** ::postfix_args - Algorithm specific overload sets // ** ::invoke_reference - Algorithm specific reference invocation // ** ::invoke_async - Algorithm specific async algo invocation // ** ::compare_outputs - Compare output values. // // Each algorithm should define its own `mixins.h` header to declare algorithm // specific mixins (e.g. postfix_args, invoke_reference, and invoke_async) // in a testing::async::::mixins namespace structure. // // For example, the test.async.exclusive_scan.basic test uses the following // algorithm definition from mix-ins: // // ``` // #include // #include // #include // template > // struct basic_invoker // : testing::async::mixin::input::device_vector // , testing::async::mixin::output::device_vector // , testing::async::exclusive_scan::mixin::postfix_args:: // all_overloads // , testing::async::exclusive_scan::mixin::invoke_reference:: // host_synchronous // , testing::async::exclusive_scan::mixin::invoke_async::basic // , testing::async::mixin::compare_outputs::assert_equal_quiet // { // static std::string description() // { // return "basic invocation with device vectors"; // } // }; // // ... // // testing::async::test_policy_overloads>::run(num_values); // ``` // // The basic_invoker class expands to something similar to the following: // // ``` // template > // struct basic_invoker // { // public: // // static std::string description() // { // return "basic invocation with device vectors"; // } // // //------------------------------------------------------------------------- // // testing::async::mixin::input::device_vector // // // // input_type must provide idiomatic definitions of: // // - `using iterator = ...;` // // - `iterator begin() const { ... }` // // - `iterator end() const { ... }` // // - `size_t size() const { ... }` // using input_type = thrust::device_vector; // // // Generate an instance of the input: // static input_type generate_input(std::size_t num_values) // { // input_type input(num_values); // thrust::sequence(input.begin(), input.end(), 25, 3); // return input; // } // // //------------------------------------------------------------------------- // // testing::async::mixin::output::device_vector // // // // output_type must provide idiomatic definitions of: // // - `using iterator = ...;` // // - `iterator begin() { ... }` // using output_type = thrust::device_vector; // // // Generate an instance of the output: // // Might be more complicated, eg. fancy iterators, etc // static output_type generate_output(std::size_t num_values) // { // return output_type(num_values); // } // // //------------------------------------------------------------------------- // // testing::async::exclusive_scan::mixin::postfix_args::all_overloads // using postfix_args_type = std::tuple< // List any extra arg overloads: // std::tuple<>, // - no extra args // std::tuple, // - initial_value // std::tuple // - initial_value, binary_op // >; // // // Create instances of the extra arguments to use when invoking the // // algorithm: // static postfix_args_type generate_postfix_args() // { // return postfix_args_type{ // std::tuple<>{}, // no extra args // std::make_tuple(initial_value_type{42}), // initial_value // // initial_value, binary_op: // std::make_tuple(initial_value_Type{57}, alternate_binary_op{}) // }; // } // // //------------------------------------------------------------------------- // // // testing::async::exclusive_scan::mixin::invoke_reference::host_synchronous // // // // Invoke a reference implementation for a single overload as described by // // postfix_tuple. This tuple contains instances of any trailing arguments // // to pass to the algorithm. The tuple/index_sequence pattern is used to // // support a "no extra args" overload, since the parameter pack expansion // // will do exactly what we want in all cases. // template // static void invoke_reference(input_type const &input, // output_type &output, // PostfixArgTuple &&postfix_tuple, // std::index_sequence) // { // // Create host versions of the input/output: // thrust::host_vector host_input(input.cbegin(), // input.cend()); // thrust::host_vector host_output(host_input.size()); // // // Run host synchronous algorithm to generate reference. // thrust::exclusive_scan(host_input.cbegin(), // host_input.cend(), // host_output.begin(), // std::get( // THRUST_FWD(postfix_tuple))...); // // // Copy back to device. // output = host_output; // } // // //------------------------------------------------------------------------- // // testing::async::mixin::exclusive_scan::mixin::invoke_async::basic // // // // Invoke the async algorithm for a single overload as described by // // the prefix and postfix tuples. These tuples contains instances of any // // additional arguments to pass to the algorithm. The tuple/index_sequence // // pattern is used to support the "no extra args" overload, since the // // parameter pack expansion will do exactly what we want in all cases. // // Prefix args are included here (but not for invoke_reference) to allow // // the test framework to change the execution policy. // // This method must return an event or future. // template // static auto invoke_async(PrefixArgTuple &&prefix_tuple, // std::index_sequence, // input_type const &input, // output_type &output, // PostfixArgTuple &&postfix_tuple, // std::index_sequence) // { // output.resize(input.size()); // auto e = thrust::async::exclusive_scan( // std::get(THRUST_FWD(prefix_tuple))..., // input.cbegin(), // input.cend(), // output.begin(), // std::get(THRUST_FWD(postfix_tuple))...); // return e; // } // // //------------------------------------------------------------------------- // // testing::async::mixin::compare_outputs::assert_equal_quiet // // // // Wait on and validate the event/future (usually with TEST_EVENT_WAIT / // // TEST_FUTURE_VALUE_RETRIEVAL), then check that the reference output // // matches the testing output. // template // static void compare_outputs(EventType &e, // output_type const &ref, // output_type const &test) // { // TEST_EVENT_WAIT(e); // ASSERT_EQUAL_QUIET(ref, test); // } // }; // ``` // // Similar invokers with slight tweaks are used in other // async/exclusive_scan/*.cu tests. // clang-format on namespace testing { namespace async { namespace mixin { //------------------------------------------------------------------------------ namespace input { template struct device_vector { using input_type = thrust::device_vector; static input_type generate_input(std::size_t num_values) { input_type input(num_values); thrust::sequence(input.begin(), input.end(), static_cast(1), static_cast(1)); return input; } }; template struct counting_iterator_from_0 { struct input_type { using iterator = thrust::counting_iterator; std::size_t num_values; iterator begin() const { return iterator{static_cast(0)}; } iterator cbegin() const { return iterator{static_cast(0)}; } iterator end() const { return iterator{static_cast(num_values)}; } iterator cend() const { return iterator{static_cast(num_values)}; } std::size_t size() const { return num_values; } }; static input_type generate_input(std::size_t num_values) { return {num_values}; } }; template struct counting_iterator_from_1 { struct input_type { using iterator = thrust::counting_iterator; std::size_t num_values; iterator begin() const { return iterator{static_cast(1)}; } iterator cbegin() const { return iterator{static_cast(1)}; } iterator end() const { return iterator{static_cast(1 + num_values)}; } iterator cend() const { return iterator{static_cast(1 + num_values)}; } std::size_t size() const { return num_values; } }; static input_type generate_input(std::size_t num_values) { return {num_values}; } }; template struct constant_iterator_1 { struct input_type { using iterator = thrust::constant_iterator; std::size_t num_values; iterator begin() const { return iterator{static_cast(1)}; } iterator cbegin() const { return iterator{static_cast(1)}; } iterator end() const { return iterator{static_cast(1)} + num_values; } iterator cend() const { return iterator{static_cast(1)} + num_values; } std::size_t size() const { return num_values; } }; static input_type generate_input(std::size_t num_values) { return {num_values}; } }; } // namespace input //------------------------------------------------------------------------------ namespace output { template struct device_vector { using output_type = thrust::device_vector; template static output_type generate_output(std::size_t num_values, InputType& /* unused */) { return output_type(num_values); } }; template struct device_vector_reuse_input { using output_type = thrust::device_vector&; template static output_type generate_output(std::size_t /*num_values*/, InputType& input) { return input; } }; struct discard_iterator { struct output_type { using iterator = thrust::discard_iterator<>; iterator begin() const { return thrust::make_discard_iterator(); } iterator cbegin() const { return thrust::make_discard_iterator(); } }; template static output_type generate_output(std::size_t /* num_values */, InputType& /* input */) { return output_type{}; } }; } // namespace output //------------------------------------------------------------------------------ namespace postfix_args { /* Defined per algorithm. Example: * * // Defines several overloads: * // algorithm([policy,] input, output) // no postfix args * // algorithm([policy,] input, output, initial_value) * // algorithm([policy,] input, output, initial_value, binary_op) * template > * struct all_overloads * { * using postfix_args_type = std::tuple< // List any extra arg overloads: * std::tuple<>, // - no extra args * std::tuple, // - initial_value * std::tuple // - initial_value, binary_op * >; * * static postfix_args_type generate_postfix_args() * { * return postfix_args_type{ * std::tuple<>{}, // no extra args * std::make_tuple(initial_value_type{42}), // initial_value * // initial_value, binary_op: * std::make_tuple(initial_value_Type{57}, alternate_binary_op{}) * } * }; * */ } //------------------------------------------------------------------------------ namespace invoke_reference { /* Defined per algorithm. Example: * * template * struct host_synchronous * { * template * static void invoke_reference(InputType const& input, * OutputType& output, * PostfixArgTuple&& postfix_tuple, * std::index_sequence) * { * // Create host versions of the input/output: * thrust::host_vector host_input(input.cbegin(), * input.cend()); * thrust::host_vector host_output(host_input.size()); * * // Run host synchronous algorithm to generate reference. * // Be sure to call a backend that doesn't use the same underlying * // implementation. * thrust::exclusive_scan(host_input.cbegin(), * host_input.cend(), * host_output.begin(), * std::get( * THRUST_FWD(postfix_tuple))...); * * // Copy back to device. * output = host_output; * } * }; * */ // Used to save time when testing unverifiable invocations (discard_iterators) struct noop { template static void invoke_reference(Ts&&...) {} }; } // namespace invoke_reference //------------------------------------------------------------------------------ namespace invoke_async { /* Defined per algorithm. Example: * * struct basic * { * template * static auto invoke_async(PrefixArgTuple&& prefix_tuple, * std::index_sequence, * InputType const& input, * OutputType& output, * PostfixArgTuple&& postfix_tuple, * std::index_sequence) * { * auto e = thrust::async::exclusive_scan( * std::get(THRUST_FWD(prefix_tuple))..., * input.cbegin(), * input.cend(), * output.begin(), * std::get(THRUST_FWD(postfix_tuple))...); * return e; * } * }; */ } // namespace invoke_async //------------------------------------------------------------------------------ namespace compare_outputs { namespace detail { void basic_event_validation(thrust::device_event& e) { TEST_EVENT_WAIT(e); } template void basic_event_validation(thrust::device_future& f) { TEST_FUTURE_VALUE_RETRIEVAL(f); } } // namespace detail struct assert_equal { template static void compare_outputs(EventType& e, OutputType const& ref, OutputType const& test) { detail::basic_event_validation(e); ASSERT_EQUAL(ref, test); } }; struct assert_almost_equal { template static void compare_outputs(EventType& e, OutputType const& ref, OutputType const& test) { detail::basic_event_validation(e); ASSERT_ALMOST_EQUAL(ref, test); } }; // Does an 'almost_equal' comparison for floating point types. Since fp // addition is non-associative, this is sometimes necessary. struct assert_almost_equal_if_fp { private: template static void compare_outputs_impl(EventType& e, OutputType const& ref, OutputType const& test, std::false_type /* is_floating_point */) { detail::basic_event_validation(e); ASSERT_EQUAL(ref, test); } template static void compare_outputs_impl(EventType& e, OutputType const& ref, OutputType const& test, std::true_type /* is_floating_point */) { detail::basic_event_validation(e); ASSERT_ALMOST_EQUAL(ref, test); } public: template static void compare_outputs(EventType& e, OutputType const& ref, OutputType const& test) { using value_type = typename OutputType::value_type; compare_outputs_impl(e, ref, test, std::is_floating_point{}); } }; struct assert_equal_quiet { template static void compare_outputs(EventType& e, OutputType const& ref, OutputType const& test) { detail::basic_event_validation(e); ASSERT_EQUAL_QUIET(ref, test); } }; // Does an 'almost_equal' comparison for floating point types, since fp // addition is non-associative struct assert_almost_equal_if_fp_quiet { private: template static void compare_outputs_impl(EventType& e, OutputType const& ref, OutputType const& test, std::false_type /* is_floating_point */) { detail::basic_event_validation(e); ASSERT_EQUAL_QUIET(ref, test); } template static void compare_outputs_impl(EventType& e, OutputType const& ref, OutputType const& test, std::true_type /* is_floating_point */) { detail::basic_event_validation(e); ASSERT_ALMOST_EQUAL(ref, test); } public: template static void compare_outputs(EventType& e, OutputType const& ref, OutputType const& test) { using value_type = typename OutputType::value_type; compare_outputs_impl(e, ref, test, std::is_floating_point{}); } }; // Used to save time when testing unverifiable invocations (discard_iterators). // Just does basic validation of the future/event. struct noop { template static void compare_outputs(EventType &e, Ts&&...) { detail::basic_event_validation(e); } }; } // namespace compare_outputs } // namespace mixin } // namespace async } // namespace testing #endif // C++14