/******************************************************************************
* Copyright (c) 2011, Duane Merrill. All rights reserved.
* Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* Simple binary operator functor types
*/
/******************************************************************************
* Simple functor operators
******************************************************************************/
#pragma once
#include "../config.cuh"
#include "../util_type.cuh"
CUB_NAMESPACE_BEGIN
/**
* \addtogroup UtilModule
* @{
*/
/**
 * \brief Default equality functor
 *
 * Stateless comparator that defers to the element type's own operator==.
 */
struct Equality
{
    /// Returns whether \p lhs and \p rhs compare equal (lhs == rhs)
    template <typename T>
    __host__ __device__ __forceinline__
    bool operator()(const T &lhs, const T &rhs) const
    {
        const bool same = (lhs == rhs);
        return same;
    }
};
/**
 * \brief Default inequality functor
 *
 * Stateless comparator that defers to the element type's own operator!=.
 */
struct Inequality
{
    /// Returns whether \p lhs and \p rhs compare unequal (lhs != rhs)
    template <typename T>
    __host__ __device__ __forceinline__
    bool operator()(const T &lhs, const T &rhs) const
    {
        const bool differs = (lhs != rhs);
        return differs;
    }
};
/**
 * \brief Inequality functor (wraps equality functor)
 *
 * Adapts a user-supplied equality predicate into an inequality predicate by
 * logically negating the wrapped predicate's result.
 *
 * \tparam EqualityOp Binary predicate type whose result is negated with operator!
 */
template <typename EqualityOp>
struct InequalityWrapper
{
    /// Wrapped equality operator
    EqualityOp op;

    /// Constructor (stores a copy of \p op)
    __host__ __device__ __forceinline__
    InequalityWrapper(EqualityOp op) : op(op) {}

    /// Inequality operator, returns !op(a, b)
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b)
    {
        return !op(a, b);
    }

    /// Const-qualified overload, returns !op(a, b).
    ///
    /// Lets the wrapper be invoked through a const object or reference, like the
    /// other functors in this file. As a member template it is only instantiated
    /// when actually called on a const wrapper, so wrapped operators that expose
    /// only a non-const operator() keep working through the overload above.
    template <typename T>
    __host__ __device__ __forceinline__ bool operator()(const T &a, const T &b) const
    {
        return !op(a, b);
    }
};
/**
 * \brief Default sum functor
 *
 * Stateless functor that adds two values of the same type.
 */
struct Sum
{
    /// Binary addition: yields lhs + rhs, converted back to T
    template <typename T>
    __host__ __device__ __forceinline__
    T operator()(const T &lhs, const T &rhs) const
    {
        return lhs + rhs;
    }
};
/**
 * \brief Default difference functor
 *
 * Stateless functor that subtracts its second argument from its first.
 */
struct Difference
{
    /// Binary subtraction: yields minuend - subtrahend, converted back to T
    template <typename T>
    __host__ __device__ __forceinline__
    T operator()(const T &minuend, const T &subtrahend) const
    {
        return minuend - subtrahend;
    }
};
/**
 * \brief Default division functor
 *
 * Stateless functor that divides its first argument by its second. The
 * behavior for a zero divisor is whatever T's operator/ does; no check is made.
 */
struct Division
{
    /// Binary division operator, returns a / b
    template <typename T>
    __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const
    {
        return a / b;
    }
};
/**
 * \brief Default max functor
 *
 * Stateless functor that selects the larger of two values of the same type.
 */
struct Max
{
    /// Binary maximum operator, returning a T. The selection (including which
    /// operand wins on ties) is whatever the CUB_MAX macro defines.
    template <typename T>
    __host__ __device__ __forceinline__
    T operator()(const T &lhs, const T &rhs) const
    {
        return CUB_MAX(lhs, rhs);
    }
};
/**
 * \brief Arg max functor (keeps the value and offset of the first occurrence of the larger item)
 */
struct ArgMax
{
    /// Arg-max combining operator over (offset, value) pairs.
    ///
    /// Returns \p b when its value is greater than a's, or when the values
    /// compare equal and b has the smaller key (offset); otherwise returns \p a.
    /// The result is therefore the pair with the larger value, preferring the
    /// smaller offset in case of ties.
    template <typename T, typename OffsetT>
    __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()(
        const KeyValuePair<OffsetT, T> &a,
        const KeyValuePair<OffsetT, T> &b) const
    {
        // Mooch BUG (device reduce argmax gk110 3.2 million random fp32)
        // return ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;
        // NOTE(review): the ternary form above is left commented out and the
        // if/return sequence below is used instead, presumably as a workaround
        // for the bug noted above; keep this shape when editing.
        if ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key)))
            return b;
        return a;
    }
};
/**
 * \brief Default min functor
 *
 * Stateless functor that selects the smaller of two values of the same type.
 */
struct Min
{
    /// Binary minimum operator, returning a T. The selection (including which
    /// operand wins on ties) is whatever the CUB_MIN macro defines.
    template <typename T>
    __host__ __device__ __forceinline__
    T operator()(const T &lhs, const T &rhs) const
    {
        return CUB_MIN(lhs, rhs);
    }
};
/**
 * \brief Arg min functor (keeps the value and offset of the first occurrence of the smallest item)
 */
struct ArgMin
{
    /// Arg-min combining operator over (offset, value) pairs.
    ///
    /// Returns \p b when its value is less than a's, or when the values compare
    /// equal and b has the smaller key (offset); otherwise returns \p a. The
    /// result is therefore the pair with the smaller value, preferring the
    /// smaller offset in case of ties.
    template <typename T, typename OffsetT>
    __host__ __device__ __forceinline__ KeyValuePair<OffsetT, T> operator()(
        const KeyValuePair<OffsetT, T> &a,
        const KeyValuePair<OffsetT, T> &b) const
    {
        // Mooch BUG (device reduce argmax gk110 3.2 million random fp32)
        // return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;
        // NOTE(review): the bug note above names argmax; this functor appears to
        // apply the same if/return workaround as ArgMax. Keep this shape when
        // editing rather than reverting to the commented-out ternary.
        if ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key)))
            return b;
        return a;
    }
};
/**
 * \brief Default cast functor
 *
 * \tparam B Destination type every input is explicitly converted to
 */
template <typename B>
struct CastOp
{
    /// Explicitly converts \p value to B (C-style cast semantics)
    template <typename A>
    __host__ __device__ __forceinline__
    B operator()(const A &value) const
    {
        // Functional-cast notation with a single operand is equivalent to (B) value
        return B(value);
    }
};
/**
 * \brief Binary operator wrapper for switching non-commutative scan arguments
 *
 * Calling the wrapper with (a, b) evaluates the wrapped operator as
 * scan_op(b, a).
 *
 * \tparam ScanOp Wrapped binary scan operator type
 */
template <typename ScanOp>
class SwizzleScanOp
{
    /// Wrapped scan operator (class members are private by default)
    ScanOp scan_op;

public:
    /// Constructor (stores a copy of \p scan_op)
    __host__ __device__ __forceinline__
    SwizzleScanOp(ScanOp scan_op) : scan_op(scan_op) {}

    /// Applies the wrapped operator with its two arguments swapped
    template <typename T>
    __host__ __device__ __forceinline__
    T operator()(const T &a, const T &b)
    {
        // The wrapped operator receives non-const local copies rather than the
        // const references; presumably so operators taking non-const parameters
        // still compile — TODO confirm.
        T first_copy(a);
        T second_copy(b);
        return scan_op(second_copy, first_copy);
    }
};
/**
 * \brief Reduce-by-segment functor.
 *
 * Given two cub::KeyValuePair inputs \p a and \p b and a binary associative
 * combining operator \p f(const T &x, const T &y), an instance of this functor
 * returns a cub::KeyValuePair whose \p key field is a.key + b.key, and whose
 * \p value field is either b.value if b.key is non-zero, or f(a.value, b.value)
 * otherwise.
 *
 * ReduceBySegmentOp is an associative, non-commutative binary combining
 * operator for input sequences of cub::KeyValuePair pairings. Such sequences
 * are typically used to represent a segmented set of values to be reduced,
 * together with a corresponding set of {0,1}-valued integer "head flags" that
 * mark the first value of each segment.
 *
 * \tparam ReductionOpT Binary reduction operator to apply to values
 */
template <typename ReductionOpT>
struct ReduceBySegmentOp
{
    /// Wrapped reduction operator
    ReductionOpT op;

    /// Default constructor (\p op is default-initialized)
    __host__ __device__ __forceinline__ ReduceBySegmentOp() {}

    /// Constructor (stores a copy of \p op)
    __host__ __device__ __forceinline__ ReduceBySegmentOp(ReductionOpT op) : op(op) {}

    /// Scan operator combining two partial reductions.
    ///
    /// \tparam KeyValuePairT KeyValuePair pairing of T (value) and OffsetT (head flag)
    template <typename KeyValuePairT>
    __host__ __device__ __forceinline__ KeyValuePairT operator()(
        const KeyValuePairT &first,   ///< First partial reduction
        const KeyValuePairT &second)  ///< Second partial reduction
    {
        KeyValuePairT combined;

        // Head-flag counts of the two partials simply add up.
        combined.key = first.key + second.key;

        // If the second partial spans a segment reset (non-zero flag count), its
        // value aggregate becomes the running aggregate; otherwise both partials
        // lie in the same segment and are accumulated with op.
        combined.value = (second.key)
                             ? second.value
                             : op(first.value, second.value);

        return combined;
    }
};
/**
 * \brief Reduce-by-key functor.
 *
 * Combines two key/value partials. The result carries the second partial's key.
 * Its value is op(first.value, second.value) when the two keys compare equal;
 * otherwise it is simply the second partial.
 *
 * \tparam ReductionOpT Binary reduction operator to apply to values
 */
template <typename ReductionOpT>
struct ReduceByKeyOp
{
    /// Wrapped reduction operator
    ReductionOpT op;

    /// Default constructor (\p op is default-initialized)
    __host__ __device__ __forceinline__ ReduceByKeyOp() {}

    /// Constructor (stores a copy of \p op)
    __host__ __device__ __forceinline__ ReduceByKeyOp(ReductionOpT op) : op(op) {}

    /// Scan operator combining two partial reductions.
    ///
    /// \tparam KeyValuePairT Type exposing \p key and \p value members
    template <typename KeyValuePairT>
    __host__ __device__ __forceinline__ KeyValuePairT operator()(
        const KeyValuePairT &first,   ///< First partial reduction
        const KeyValuePairT &second)  ///< Second partial reduction
    {
        // Different keys: the second partial starts a new run and passes through.
        if (!(first.key == second.key))
            return second;

        // Same key: fold the first partial's value into a copy of the second.
        KeyValuePairT merged = second;
        merged.value = op(first.value, merged.value);
        return merged;
    }
};
/**
 * \brief Binary operator adaptor that calls the wrapped operator with its two
 * arguments reversed.
 *
 * Calling flip(x, y) evaluates binary_op(y, x), perfectly forwarding both
 * arguments. The constructor is host/device, but the call operator is
 * device-only.
 *
 * \tparam BinaryOpT Wrapped binary operator type
 */
template <class BinaryOpT>
struct BinaryFlip
{
    /// Wrapped binary operator
    BinaryOpT binary_op;

    /// Constructor (stores a copy of \p binary_op)
    __device__ __host__ explicit BinaryFlip(BinaryOpT binary_op)
        : binary_op(binary_op)
    {}

    /// Invokes binary_op(rhs, lhs); the return type is deduced from that call.
    template <typename T, typename U>
    __device__ auto operator()(T &&lhs, U &&rhs)
        -> decltype(binary_op(std::forward<U>(rhs), std::forward<T>(lhs)))
    {
        return binary_op(std::forward<U>(rhs), std::forward<T>(lhs));
    }
};
/**
 * \brief Factory that deduces \p BinaryOpT and returns a BinaryFlip wrapping a
 * copy of \p binary_op.
 */
template <class BinaryOpT>
__device__ __host__ BinaryFlip<BinaryOpT> MakeBinaryFlip(BinaryOpT binary_op)
{
    return BinaryFlip<BinaryOpT>{binary_op};
}
/** @} */ // end group UtilModule
CUB_NAMESPACE_END