/****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*
 ******************************************************************************/

/**
 * \file
 * Simple binary operator functor types
 */

/******************************************************************************
 * Simple functor operators
 ******************************************************************************/

#pragma once

#include <utility>      // std::forward (used by BinaryFlip)

#include "../config.cuh"
#include "../util_type.cuh"

CUB_NAMESPACE_BEGIN


/**
 * \addtogroup UtilModule
 * @{
 */

/**
 * \brief Default equality functor
 */
struct Equality
{
    /// Boolean equality operator, returns (a == b)
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const
    {
        return a == b;
    }
};


/**
 * \brief Default inequality functor
 */
struct Inequality
{
    /// Boolean inequality operator, returns (a != b)
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const
    {
        return a != b;
    }
};


/**
 * \brief Inequality functor (wraps equality functor)
 *
 * \tparam EqualityOp Binary equality predicate whose result is negated.
 */
template <typename EqualityOp>
struct InequalityWrapper
{
    /// Wrapped equality operator
    EqualityOp op;

    /// Constructor (stores a copy of \p op)
    __host__ __device__ __forceinline__
    InequalityWrapper(EqualityOp op) : op(op) {}

    /// Boolean inequality operator, returns !op(a, b).
    /// Declared non-const, so the wrapped \p op is invoked through a
    /// non-const member and its call operator need not be const.
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b)
    {
        return !op(a, b);
    }
};


/**
 * \brief Default sum functor
 */
struct Sum
{
    /// Binary sum operator, returns a + b
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return a + b;
    }
};

/**
 * \brief Default difference functor
 */
struct Difference
{
    /// Binary difference operator, returns a - b
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return a - b;
    }
};

/**
 * \brief Default division functor
 */
struct Division
{
    /// Binary division operator, returns a / b
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return a / b;
    }
};


/**
 * \brief Default max functor
 */
struct Max
{
    /// Binary max operator, returns the larger of \p a and \p b as computed by CUB_MAX
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return CUB_MAX(a, b);
    }
};


/**
 * \brief Arg max functor (keeps the value and offset of the first occurrence
 * of the largest item)
 */
struct ArgMax
{
    /// Max operator over (offset, value) pairs: returns the pair with the larger
    /// \p value, preferring the pair with the smaller offset (\p key) in case of ties
    template <typename T, typename OffsetT>
    __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()(
        const KeyValuePair<OffsetT, T> &a,
        const KeyValuePair<OffsetT, T> &b) const
    {
        // Mooch BUG (device reduce argmax gk110 3.2 million random fp32)
        // The single-expression form below is intentionally kept commented out;
        // per the note above, the explicit if/return is used instead:
        // return ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;

        if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key)))
            return b;

        return a;
    }
};


/**
 * \brief Default min functor
 */
struct Min
{
    /// Binary min operator, returns the smaller of \p a and \p b as computed by CUB_MIN
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return CUB_MIN(a, b);
    }
};


/**
 * \brief Arg min functor (keeps the value and offset of the first occurrence
 * of the smallest item)
 */
struct ArgMin
{
    /// Min operator over (offset, value) pairs: returns the pair with the smaller
    /// \p value, preferring the pair with the smaller offset (\p key) in case of ties
    template <typename T, typename OffsetT>
    __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()(
        const KeyValuePair<OffsetT, T> &a,
        const KeyValuePair<OffsetT, T> &b) const
    {
        // Mooch BUG (device reduce argmax gk110 3.2 million random fp32)
        // Same workaround as ArgMax: the single-expression form below is kept
        // commented out in favor of the explicit if/return:
        // return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;

        if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key)))
            return b;

        return a;
    }
};


/**
 * \brief Default cast functor
 *
 * \tparam B Destination type of the cast.
 */
template <typename B>
struct CastOp
{
    /// Cast operator, returns (B) a (a C-style cast, deliberately kept as such)
    template <typename A>
    __host__ __device__ __forceinline__ B operator()(const A &a) const
    {
        return (B) a;
    }
};


/**
 * \brief Binary operator wrapper for switching non-commutative scan arguments
 *
 * \tparam ScanOp Binary scan operator that is invoked with its arguments swapped.
 */
template <typename ScanOp>
class SwizzleScanOp
{
private:

    /// Wrapped scan operator
    ScanOp scan_op;

public:

    /// Constructor (stores a copy of \p scan_op)
    __host__ __device__ __forceinline__
    SwizzleScanOp(ScanOp scan_op) : scan_op(scan_op) {}

    /// Switch the scan arguments: returns scan_op(b, a), invoked on local copies
    /// of \p b and \p a
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b)
    {
      T _a(a);
      T _b(b);

      return scan_op(_b, _a);
    }
};


/**
 * \brief Reduce-by-segment functor.
 *
 * Given two cub::KeyValuePair inputs \p a and \p b and a
 * binary associative combining operator \p f(const T &x, const T &y),
 * an instance of this functor returns a cub::KeyValuePair whose \p key
 * field is <tt>a.key + b.key</tt>, and whose \p value field
 * is either <tt>b.value</tt> if <tt>b.key</tt> is non-zero, or
 * <tt>f(a.value, b.value)</tt> otherwise.
 *
 * ReduceBySegmentOp is an associative, non-commutative binary combining operator
 * for input sequences of cub::KeyValuePair pairings.  Such
 * sequences are typically used to represent a segmented set of values to be reduced
 * and a corresponding set of {0,1}-valued integer "head flags" demarcating the
 * first value of each segment.
 *
 */
template <typename ReductionOpT>    ///< Binary reduction operator to apply to values
struct ReduceBySegmentOp
{
    /// Wrapped reduction operator
    ReductionOpT op;

    /// Constructor
    __host__ __device__ __forceinline__ ReduceBySegmentOp() {}

    /// Constructor (stores a copy of \p op)
    __host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op) : op(op) {}

    /// Scan operator
    template <typename KeyValuePairT>       ///< KeyValuePair pairing of T (value) and OffsetT (head flag)
    __host__ __device__ __forceinline__ KeyValuePairT operator()(
        const KeyValuePairT &first,         ///< First partial reduction
        const KeyValuePairT &second)        ///< Second partial reduction
    {
        KeyValuePairT retval;
        retval.key = first.key + second.key;
        retval.value = (second.key) ?
                second.value :                          // The second partial reduction spans a segment reset, so its value aggregate becomes the running aggregate
                op(first.value, second.value);          // The second partial reduction does not span a reset, so accumulate both into the running aggregate
        return retval;
    }
};


/**
 * \brief Reduce-by-key functor.
 *
 * Combines two (key, value) partial reductions. The result carries the key of
 * \p second; when <tt>first.key == second.key</tt> its value is
 * <tt>op(first.value, second.value)</tt>, otherwise it is <tt>second.value</tt>.
 */
template <typename ReductionOpT>    ///< Binary reduction operator to apply to values
struct ReduceByKeyOp
{
    /// Wrapped reduction operator
    ReductionOpT op;

    /// Constructor
    __host__ __device__ __forceinline__ ReduceByKeyOp() {}

    /// Constructor (stores a copy of \p op)
    __host__ __device__ __forceinline__ ReduceByKeyOp(ReductionOpT op) : op(op) {}

    /// Scan operator
    template <typename KeyValuePairT>
    __host__ __device__ __forceinline__ KeyValuePairT operator()(
        const KeyValuePairT &first,       ///< First partial reduction
        const KeyValuePairT &second)      ///< Second partial reduction
    {
        KeyValuePairT retval = second;

        // Only accumulate when both partials belong to the same key
        if (first.key == second.key)
            retval.value = op(first.value, retval.value);

        return retval;
    }
};


/**
 * \brief Binary operator wrapper that swaps its arguments:
 * <tt>BinaryFlip<Op>(op)(t, u)</tt> returns <tt>op(u, t)</tt>.
 *
 * Arguments are perfectly forwarded to the wrapped operator. Note that the
 * constructor is <tt>__host__ __device__</tt> while the call operator is
 * <tt>__device__</tt>-only.
 *
 * \tparam BinaryOpT Wrapped binary operator type.
 */
template <typename BinaryOpT>
struct BinaryFlip
{
    /// Wrapped binary operator
    BinaryOpT binary_op;

    /// Constructor (stores a copy of \p binary_op)
    __device__ __host__ explicit BinaryFlip(BinaryOpT binary_op)
        : binary_op(binary_op)
    {}

    /// Invokes the wrapped operator with the arguments swapped
    template <typename T, typename U>
    __device__ auto operator()(T &&t, U &&u)
        -> decltype(binary_op(std::forward<U>(u), std::forward<T>(t)))
    {
        return binary_op(std::forward<U>(u), std::forward<T>(t));
    }
};

/// Convenience factory that deduces \p BinaryOpT and returns a BinaryFlip
/// wrapping a copy of \p binary_op
template <typename BinaryOpT>
__device__ __host__ BinaryFlip<BinaryOpT> MakeBinaryFlip(BinaryOpT binary_op)
{
    return BinaryFlip<BinaryOpT>(binary_op);
}

/** @} */       // end group UtilModule


CUB_NAMESPACE_END