/******************************************************************************
 * Copyright (c) 2011, Duane Merrill.  All rights reserved.
 * Copyright (c) 2011-2018, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

/**
 * \file
 * cub::BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block
 */

#pragma once

#include <stdint.h>

#include "../thread/thread_reduce.cuh"
#include "../thread/thread_scan.cuh"
#include "../block/block_scan.cuh"
#include "../block/radix_rank_sort_operations.cuh"
#include "../config.cuh"
#include "../util_ptx.cuh"
#include "../util_type.cuh"

CUB_NAMESPACE_BEGIN

/**
 * \brief Radix ranking algorithm, the algorithm used to implement stable ranking of the
 * keys from a single tile. Note that different ranking algorithms require different
 * initial arrangements of keys to function properly.
 */
enum RadixRankAlgorithm
{
    /** Ranking using the BlockRadixRank algorithm with MEMOIZE_OUTER_SCAN == false. It
     * uses thread-private histograms, and thus uses more shared memory. Requires blocked
     * arrangement of keys. Does not support count callbacks. */
    RADIX_RANK_BASIC,

    /** Ranking using the BlockRadixRank algorithm with MEMOIZE_OUTER_SCAN ==
     * true. Similar to RADIX_RANK_BASIC, it requires blocked arrangement of
     * keys and does not support count callbacks. */
    RADIX_RANK_MEMOIZE,

    /** Ranking using the BlockRadixRankMatch algorithm. It uses warp-private
     * histograms and matching for ranking the keys in a single warp. Therefore,
     * it uses less shared memory compared to RADIX_RANK_BASIC. It requires
     * warp-striped key arrangement and supports count callbacks. */
    RADIX_RANK_MATCH,

    /** Ranking using the BlockRadixRankMatchEarlyCounts algorithm with
     * MATCH_ALGORITHM == WARP_MATCH_ANY. An alternative implementation of
     * match-based ranking that computes bin counts early. Because of this, it
     * works better with onesweep sorting, which requires bin counts for
     * decoupled look-back. Assumes warp-striped key arrangement and supports
     * count callbacks. */
    RADIX_RANK_MATCH_EARLY_COUNTS_ANY,

    /** Ranking using the BlockRadixRankMatchEarlyCounts algorithm with
     * MATCH_ALGORITHM == WARP_MATCH_ATOMIC_OR. It uses extra space in shared
     * memory to generate warp match masks using atomicOr(). This is faster when
     * there are few matches, but can lead to slowdowns if the number of
     * matching keys among warp lanes is high. Assumes warp-striped key
     * arrangement and supports count callbacks. */
    RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR
};
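// Note on key arrangement (for illustration): the match-based variants above
// consume keys in a warp-striped arrangement, whereas the BlockRadixRank
// variants consume a blocked arrangement. A minimal sketch of a warp-striped
// tile load, assuming <cub/block/block_load.cuh> is included and the thread
// block is one-dimensional:
//
//     unsigned int keys[KEYS_PER_THREAD];
//     cub::LoadDirectWarpStriped(threadIdx.x, d_keys_in + tile_offset, keys);
//
// A blocked arrangement would instead be produced by, e.g.,
// cub::LoadDirectBlocked.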
/** Empty callback implementation */
template <int BINS_PER_THREAD>
struct BlockRadixRankEmptyCallback
{
    __device__ __forceinline__ void operator()(int (&bins)[BINS_PER_THREAD]) {}
};
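// For illustration only: a sketch of a user-defined counts callback with the
// same shape as BlockRadixRankEmptyCallback. The names StoreBinsCallback and
// d_bin_counts are hypothetical; the match-based ranking variants invoke the
// callback with one tile-wide key count per bin tracked by the calling thread.
//
//     template <int BINS_PER_THREAD>
//     struct StoreBinsCallback
//     {
//         int *d_bin_counts;  // hypothetical per-digit output buffer
//
//         __device__ __forceinline__ void operator()(int (&bins)[BINS_PER_THREAD])
//         {
//             #pragma unroll
//             for (int track = 0; track < BINS_PER_THREAD; ++track)
//                 d_bin_counts[threadIdx.x * BINS_PER_THREAD + track] = bins[track];
//         }
//     };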
/**
 * \brief BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
 * \ingroup BlockModule
 *
 * \tparam BLOCK_DIM_X              The thread block length in threads along the X dimension
 * \tparam RADIX_BITS               The number of radix bits per digit place
 * \tparam IS_DESCENDING            Whether or not the sorted-order is high-to-low
 * \tparam MEMOIZE_OUTER_SCAN       [optional] Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure (default: true for architectures SM35 and newer, false otherwise).  See BlockScanAlgorithm::BLOCK_SCAN_RAKING_MEMOIZE for more details.
 * \tparam INNER_SCAN_ALGORITHM     [optional] The cub::BlockScanAlgorithm algorithm to use (default: cub::BLOCK_SCAN_WARP_SCANS)
 * \tparam SMEM_CONFIG              [optional] Shared memory bank mode (default: \p cudaSharedMemBankSizeFourByte)
 * \tparam BLOCK_DIM_Y              [optional] The thread block length in threads along the Y dimension (default: 1)
 * \tparam BLOCK_DIM_Z              [optional] The thread block length in threads along the Z dimension (default: 1)
 * \tparam PTX_ARCH                 [optional] \ptxversion
 *
 * \par Overview
 * BlockRadixRank computes, for each key in a tile of keys partitioned across
 * the thread block, that key's stable rank (i.e., its output position) with
 * respect to the digit extracted at the current radix place.
 * - Keys must be in a form suitable for radix ranking (i.e., unsigned bits).
 * - \blocked
 *
 * \par Performance Considerations
 * - \granularity
 *
 * \par Examples
 * \par
 * - Example 1: Simple radix rank of 32-bit integer keys
 *      \code
 *      #include <cub/cub.cuh>
 *
 *      template <int BLOCK_THREADS>
 *      __global__ void ExampleKernel(...)
 *      {
 *          // Specialize BlockRadixRank for a 1D block of BLOCK_THREADS threads,
 *          // ranking 5 bits per digit place, in ascending order
 *          typedef cub::BlockRadixRank<BLOCK_THREADS, 5, false> BlockRadixRank;
 *
 *          // Allocate shared memory for BlockRadixRank
 *          __shared__ typename BlockRadixRank::TempStorage temp_storage;
 *
 *          ...
 *      }
 *      \endcode
 *
 * \par Re-using dynamically allocated shared memory
 * The following example under the examples/block folder illustrates usage of
 * dynamically shared memory with BlockReduce and how to re-purpose
 * the same memory region:
 * example_block_reduce_dyn_smem.cu
 *
 * This example can be easily adapted to the storage required by BlockRadixRank.
 */
template <
    int                     BLOCK_DIM_X,
    int                     RADIX_BITS,
    bool                    IS_DESCENDING,
    bool                    MEMOIZE_OUTER_SCAN      = (CUB_PTX_ARCH >= 350) ? true : false,
    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    cudaSharedMemConfig     SMEM_CONFIG             = cudaSharedMemBankSizeFourByte,
    int                     BLOCK_DIM_Y             = 1,
    int                     BLOCK_DIM_Z             = 1,
    int                     PTX_ARCH                = CUB_PTX_ARCH>
class BlockRadixRank
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    // Integer type for digit counters (to be packed into words of type PackedCounters)
    typedef unsigned short DigitCounter;

    // Integer type for packing DigitCounters into columns of shared memory banks
    using PackedCounter =
      cub::detail::conditional_t<SMEM_CONFIG == cudaSharedMemBankSizeEightByte,
                                 unsigned long long,
                                 unsigned int>;

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        RADIX_DIGITS                = 1 << RADIX_BITS,

        LOG_WARP_THREADS            = CUB_LOG_WARP_THREADS(PTX_ARCH),
        WARP_THREADS                = 1 << LOG_WARP_THREADS,
        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        BYTES_PER_COUNTER           = sizeof(DigitCounter),
        LOG_BYTES_PER_COUNTER       = Log2<BYTES_PER_COUNTER>::VALUE,

        PACKING_RATIO               = static_cast<int>(sizeof(PackedCounter) / sizeof(DigitCounter)),
        LOG_PACKING_RATIO           = Log2<PACKING_RATIO>::VALUE,

        LOG_COUNTER_LANES           = CUB_MAX((int(RADIX_BITS) - int(LOG_PACKING_RATIO)), 0), // Always at least one lane
        COUNTER_LANES               = 1 << LOG_COUNTER_LANES,

        // The number of packed counters per thread (plus one for padding)
        PADDED_COUNTER_LANES        = COUNTER_LANES + 1,
        RAKING_SEGMENT              = PADDED_COUNTER_LANES,
    };

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:

    /// BlockScan type
    typedef BlockScan<
            PackedCounter,
            BLOCK_DIM_X,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z,
            PTX_ARCH>
        BlockScan;

    /// Shared memory storage layout type for BlockRadixRank
    struct __align__(16) _TempStorage
    {
        union Aliasable
        {
            DigitCounter            digit_counters[PADDED_COUNTER_LANES][BLOCK_THREADS][PACKING_RATIO];
            PackedCounter           raking_grid[BLOCK_THREADS][RAKING_SEGMENT];

        } aliasable;

        // Storage for scanning local ranks
        typename BlockScan::TempStorage block_scan;
    };

    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;

    /// Copy of raking segment, promoted to registers
    PackedCounter cached_segment[RAKING_SEGMENT];

    /******************************************************************************
     * Utility methods
     ******************************************************************************/

    /**
     * Internal storage allocator
     */
    __device__ __forceinline__ _TempStorage& PrivateStorage()
    {
        __shared__ _TempStorage private_storage;
        return private_storage;
    }
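    // For illustration: with the 16-bit DigitCounter and a 32-bit PackedCounter
    // (cudaSharedMemBankSizeFourByte), PACKING_RATIO == 2. For RADIX_BITS == 5
    // (RADIX_DIGITS == 32), LOG_COUNTER_LANES == 4, so digit d maps to counter
    // lane (d & 15) and packed sub-counter (d >> 4). Each thread thus owns a
    // raking segment of COUNTER_LANES + 1 PackedCounters holding all 32 of its
    // digit counters, which the raking upsweep/downsweep below scans as
    // ordinary integers.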
    /**
     * Performs upsweep raking reduction, returning the aggregate
     */
    __device__ __forceinline__ PackedCounter Upsweep()
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];
        PackedCounter *raking_ptr;

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data into registers
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                cached_segment[i] = smem_raking_ptr[i];
            }
            raking_ptr = cached_segment;
        }
        else
        {
            raking_ptr = smem_raking_ptr;
        }

        return internal::ThreadReduce<RAKING_SEGMENT>(raking_ptr, Sum());
    }

    /// Performs exclusive downsweep raking scan
    __device__ __forceinline__ void ExclusiveDownsweep(
        PackedCounter raking_partial)
    {
        PackedCounter *smem_raking_ptr = temp_storage.aliasable.raking_grid[linear_tid];

        PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ?
            cached_segment :
            smem_raking_ptr;

        // Exclusive raking downsweep scan
        internal::ThreadScanExclusive<RAKING_SEGMENT>(raking_ptr, raking_ptr, Sum(), raking_partial);

        if (MEMOIZE_OUTER_SCAN)
        {
            // Copy data back to smem
            #pragma unroll
            for (int i = 0; i < RAKING_SEGMENT; i++)
            {
                smem_raking_ptr[i] = cached_segment[i];
            }
        }
    }

    /**
     * Reset shared memory digit counters
     */
    __device__ __forceinline__ void ResetCounters()
    {
        // Reset shared memory digit counters
        #pragma unroll
        for (int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++)
        {
            *((PackedCounter*) temp_storage.aliasable.digit_counters[LANE][linear_tid]) = 0;
        }
    }

    /**
     * Block-scan prefix callback
     */
    struct PrefixCallBack
    {
        __device__ __forceinline__ PackedCounter operator()(PackedCounter block_aggregate)
        {
            PackedCounter block_prefix = 0;

            // Propagate totals in packed fields
            #pragma unroll
            for (int PACKED = 1; PACKED < PACKING_RATIO; PACKED++)
            {
                block_prefix += block_aggregate << (sizeof(DigitCounter) * 8 * PACKED);
            }

            return block_prefix;
        }
    };
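    // For illustration: the block-wide ExclusiveSum in ScanCounters below scans
    // each packed 16-bit field independently. The prefix callback above seeds
    // that scan with (block_aggregate << 16) when PACKING_RATIO == 2, so every
    // counter in sub-counter field 1 is additionally offset by the block-wide
    // total of field 0. Concatenating the fields therefore yields an exclusive
    // scan over all RADIX_DIGITS counters in increasing-digit order.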
    /**
     * Scan shared memory digit counters.
     */
    __device__ __forceinline__ void ScanCounters()
    {
        // Upsweep scan
        PackedCounter raking_partial = Upsweep();

        // Compute exclusive sum
        PackedCounter exclusive_partial;
        PrefixCallBack prefix_call_back;
        BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back);

        // Downsweep scan with exclusive partial
        ExclusiveDownsweep(exclusive_partial);
    }

public:

    /// \smemstorage{BlockScan}
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{

    /**
     * \brief Collective constructor using a private static allocation of shared memory as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank()
    :
        temp_storage(PrivateStorage()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}

    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRank(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}


    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /**
     * \brief Rank keys.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        DigitExtractorT digit_extractor)                    ///< [in] The digit extractor
    {
        DigitCounter    thread_prefixes[KEYS_PER_THREAD];   // For each key, the count of previous keys in this tile having the same digit
        DigitCounter*   digit_counters[KEYS_PER_THREAD];    // For each key, the byte-offset of its corresponding digit counter in smem

        // Reset shared memory digit counters
        ResetCounters();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Get digit
            unsigned int digit = digit_extractor.Digit(keys[ITEM]);

            // Get sub-counter
            unsigned int sub_counter = digit >> LOG_COUNTER_LANES;

            // Get counter lane
            unsigned int counter_lane = digit & (COUNTER_LANES - 1);

            if (IS_DESCENDING)
            {
                sub_counter = PACKING_RATIO - 1 - sub_counter;
                counter_lane = COUNTER_LANES - 1 - counter_lane;
            }

            // Pointer to smem digit counter
            digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane][linear_tid][sub_counter];

            // Load thread-exclusive prefix
            thread_prefixes[ITEM] = *digit_counters[ITEM];

            // Store inclusive prefix
            *digit_counters[ITEM] = thread_prefixes[ITEM] + 1;
        }

        CTA_SYNC();

        // Scan shared memory counters
        ScanCounters();

        CTA_SYNC();

        // Extract the local ranks of each key
        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // Add in thread block exclusive prefix
            ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM];
        }
    }
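    // Worked example (for illustration): consider a toy configuration of
    // 4 threads with 1 key each and extracted digits {2, 0, 2, 1} in thread
    // order. Every thread_prefix is 0, since no thread holds two keys with the
    // same digit. ScanCounters() then turns each thread's counter into the
    // block-wide exclusive offset of its (digit, thread) slot: thread 1
    // (digit 0) reads back 0, thread 3 (digit 1) reads 1, thread 0 (digit 2)
    // reads 2, and thread 2 (digit 2) reads 3, giving ranks {2, 0, 3, 1},
    // a stable permutation by digit.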
    /**
     * \brief Rank keys.  For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        DigitExtractorT digit_extractor,                    ///< [in] The digit extractor
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        // Rank keys
        RankKeys(keys, ranks, digit_extractor);

        // Get the inclusive and exclusive digit totals corresponding to the calling thread.
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                // Obtain ex/inclusive digit counts.  (Unfortunately these all reside in the
                // first counter column, resulting in unavoidable bank conflicts.)
                unsigned int counter_lane   = (bin_idx & (COUNTER_LANES - 1));
                unsigned int sub_counter    = bin_idx >> (LOG_COUNTER_LANES);

                exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane][0][sub_counter];
            }
        }
    }
};

/**
 * Radix-rank using match.any
 */
template <
    int                     BLOCK_DIM_X,
    int                     RADIX_BITS,
    bool                    IS_DESCENDING,
    BlockScanAlgorithm      INNER_SCAN_ALGORITHM    = BLOCK_SCAN_WARP_SCANS,
    int                     BLOCK_DIM_Y             = 1,
    int                     BLOCK_DIM_Z             = 1,
    int                     PTX_ARCH                = CUB_PTX_ARCH>
class BlockRadixRankMatch
{
private:

    /******************************************************************************
     * Type definitions and constants
     ******************************************************************************/

    typedef int32_t    RankT;
    typedef int32_t    DigitCounterT;

    enum
    {
        // The thread block size in threads
        BLOCK_THREADS               = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,

        RADIX_DIGITS                = 1 << RADIX_BITS,

        LOG_WARP_THREADS            = CUB_LOG_WARP_THREADS(PTX_ARCH),
        WARP_THREADS                = 1 << LOG_WARP_THREADS,
        WARPS                       = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,

        PADDED_WARPS                = ((WARPS & 0x1) == 0) ?
                                        WARPS + 1 :
                                        WARPS,

        COUNTERS                    = PADDED_WARPS * RADIX_DIGITS,
        RAKING_SEGMENT              = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,
        PADDED_RAKING_SEGMENT       = ((RAKING_SEGMENT & 0x1) == 0) ?
                                        RAKING_SEGMENT + 1 :
                                        RAKING_SEGMENT,
    };

public:

    enum
    {
        /// Number of bin-starting offsets tracked per thread
        BINS_TRACKED_PER_THREAD     = CUB_MAX(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
    };

private:

    /// BlockScan type
    typedef BlockScan<
            DigitCounterT,
            BLOCK_THREADS,
            INNER_SCAN_ALGORITHM,
            BLOCK_DIM_Y,
            BLOCK_DIM_Z,
            PTX_ARCH>
        BlockScanT;

    /// Shared memory storage layout type for BlockRadixRank
    struct __align__(16) _TempStorage
    {
        typename BlockScanT::TempStorage            block_scan;

        union __align__(16) Aliasable
        {
            volatile DigitCounterT                  warp_digit_counters[RADIX_DIGITS][PADDED_WARPS];
            DigitCounterT                           raking_grid[BLOCK_THREADS][PADDED_RAKING_SEGMENT];

        } aliasable;
    };

    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    /// Shared storage reference
    _TempStorage &temp_storage;

    /// Linear thread-id
    unsigned int linear_tid;

public:

    /// \smemstorage{BlockScan}
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************//**
     * \name Collective constructors
     *********************************************************************/
    //@{

    /**
     * \brief Collective constructor using the specified memory allocation as temporary storage.
     */
    __device__ __forceinline__ BlockRadixRankMatch(
        TempStorage &temp_storage)             ///< [in] Reference to memory allocation having layout type TempStorage
    :
        temp_storage(temp_storage.Alias()),
        linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
    {}

    //@}  end member group
    /******************************************************************//**
     * \name Raking
     *********************************************************************/
    //@{

    /** \brief Computes the count of keys for each digit value, and calls the
     * callback with the array of key counts.
     *
     * @tparam CountsCallback The callback type. It should implement an instance
     * overload of operator()(int (&bins)[BINS_TRACKED_PER_THREAD]), where bins
     * is an array of key counts for each digit value distributed in blocked
     * arrangement among the threads of the thread block. Key counts can be
     * used to update other data structures in global or shared
     * memory. Depending on the implementation of the ranking algorithm
     * (see BlockRadixRankMatchEarlyCounts), key counts may become available
     * early; therefore, they are returned through a callback rather than a
     * separate output parameter of RankKeys().
     */
    template <int KEYS_PER_THREAD, typename CountsCallback>
    __device__ __forceinline__ void CallBack(CountsCallback callback)
    {
        int bins[BINS_TRACKED_PER_THREAD];
        // Get count for each digit
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;
            const int TILE_ITEMS = KEYS_PER_THREAD * BLOCK_THREADS;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                {
                    bin_idx = RADIX_DIGITS - bin_idx - 1;
                    bins[track] = (bin_idx > 0 ?
                        temp_storage.aliasable.warp_digit_counters[bin_idx - 1][0] : TILE_ITEMS) -
                        temp_storage.aliasable.warp_digit_counters[bin_idx][0];
                }
                else
                {
                    bins[track] = (bin_idx < RADIX_DIGITS - 1 ?
                        temp_storage.aliasable.warp_digit_counters[bin_idx + 1][0] : TILE_ITEMS) -
                        temp_storage.aliasable.warp_digit_counters[bin_idx][0];
                }
            }
        }
        callback(bins);
    }

    /**
     * \brief Rank keys.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT,
        typename        CountsCallback>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile
        DigitExtractorT digit_extractor,                    ///< [in] The digit extractor
        CountsCallback  callback)
    {
        // Initialize shared digit counters
        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = 0;

        CTA_SYNC();

        // Each warp will strip-mine its section of input, one strip at a time

        volatile DigitCounterT  *digit_counters[KEYS_PER_THREAD];
        uint32_t                warp_id         = linear_tid >> LOG_WARP_THREADS;
        uint32_t                lane_mask_lt    = LaneMaskLt();

        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
        {
            // My digit
            uint32_t digit = digit_extractor.Digit(keys[ITEM]);

            if (IS_DESCENDING)
                digit = RADIX_DIGITS - digit - 1;

            // Mask of peers who have same digit as me
            uint32_t peer_mask = MatchAny<RADIX_BITS>(digit);

            // Pointer to smem digit counter for this key
            digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit][warp_id];

            // Number of occurrences in previous strips
            DigitCounterT warp_digit_prefix = *digit_counters[ITEM];

            // Warp-sync
            WARP_SYNC(0xFFFFFFFF);

            // Number of peers having same digit as me
            int32_t digit_count = __popc(peer_mask);

            // Number of lower-ranked peers having same digit seen so far
            int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt);

            if (peer_digit_prefix == 0)
            {
                // First thread for each digit updates the shared warp counter
                *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count);
            }

            // Warp-sync
            WARP_SYNC(0xFFFFFFFF);

            // Number of prior keys having same digit
            ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
        }

        CTA_SYNC();

        // Scan warp counters

        DigitCounterT scan_counters[PADDED_RAKING_SEGMENT];

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid][ITEM];

        BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters);

        #pragma unroll
        for (int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
            temp_storage.aliasable.raking_grid[linear_tid][ITEM] = scan_counters[ITEM];

        CTA_SYNC();

        if (!std::is_same<
                CountsCallback,
                BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>>::value)
        {
            CallBack<KEYS_PER_THREAD>(callback);
        }

        // Seed ranks with counter values from previous warps
        #pragma unroll
        for (int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
            ranks[ITEM] += *digit_counters[ITEM];
    }
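    // Worked example (for illustration) of the matching arithmetic above:
    // suppose lanes 0, 3, and 5 of a warp hold keys with the same digit, so
    // MatchAny returns peer_mask == 0b101001 for each of them. Then
    // digit_count == 3, and for lane 5, peer_digit_prefix ==
    // __popc(0b101001 & 0b011111) == 2, i.e., lane 5's key ranks after the two
    // peer keys in lower lanes. Only lane 0 (peer_digit_prefix == 0) updates
    // the shared warp counter for the digit.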
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],
        int             (&ranks)[KEYS_PER_THREAD],
        DigitExtractorT digit_extractor)
    {
        RankKeys(keys, ranks, digit_extractor,
                 BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>());
    }

    /**
     * \brief Rank keys.  For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
     */
    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT,
        typename        CountsCallback>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        DigitExtractorT digit_extractor,                    ///< [in] The digit extractor
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD],  ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
        CountsCallback  callback)
    {
        RankKeys(keys, ranks, digit_extractor, callback);

        // Get exclusive count for each digit
        #pragma unroll
        for (int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
        {
            int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;

            if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
            {
                if (IS_DESCENDING)
                    bin_idx = RADIX_DIGITS - bin_idx - 1;

                exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx][0];
            }
        }
    }

    template <
        typename        UnsignedBits,
        int             KEYS_PER_THREAD,
        typename        DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits    (&keys)[KEYS_PER_THREAD],           ///< [in] Keys for this tile
        int             (&ranks)[KEYS_PER_THREAD],          ///< [out] For each key, the local rank within the tile (out parameter)
        DigitExtractorT digit_extractor,
        int             (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD]) ///< [out] The exclusive prefix sum for the digits [(threadIdx.x * BINS_TRACKED_PER_THREAD) ... (threadIdx.x * BINS_TRACKED_PER_THREAD) + BINS_TRACKED_PER_THREAD - 1]
    {
        RankKeys(keys, ranks, digit_extractor, exclusive_digit_prefix,
                 BlockRadixRankEmptyCallback<BINS_TRACKED_PER_THREAD>());
    }
};

enum WarpMatchAlgorithm
{
    WARP_MATCH_ANY,
    WARP_MATCH_ATOMIC_OR
};
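// For illustration: WARP_MATCH_ANY relies on ballot-based matching of digit
// values, conceptually equivalent to the following sketch (assuming all 32
// lanes of the warp participate):
//
//     unsigned mask = 0xFFFFFFFFu;
//     for (int b = 0; b < RADIX_BITS; ++b)
//     {
//         unsigned vote = __ballot_sync(0xFFFFFFFFu, (digit >> b) & 1);
//         mask &= ((digit >> b) & 1) ? vote : ~vote;
//     }
//     // mask now has bit i set iff lane i holds the same digit value
//
// WARP_MATCH_ATOMIC_OR instead builds the same mask in shared memory with
// atomicOr(), as implemented in BlockRadixRankMatchEarlyCounts below.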
/**
 * Radix-rank using matching which computes the counts of keys for each digit
 * value early, at the expense of doing more work. This may be useful e.g. for
 * decoupled look-back, where it reduces the time other thread blocks need to
 * wait for digit counts to become available.
 */
template <int BLOCK_DIM_X,
          int RADIX_BITS,
          bool IS_DESCENDING,
          BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
          WarpMatchAlgorithm MATCH_ALGORITHM = WARP_MATCH_ANY,
          int NUM_PARTS = 1>
struct BlockRadixRankMatchEarlyCounts
{
    // constants
    enum
    {
        BLOCK_THREADS = BLOCK_DIM_X,
        RADIX_DIGITS = 1 << RADIX_BITS,
        BINS_PER_THREAD = (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS,
        BINS_TRACKED_PER_THREAD = BINS_PER_THREAD,
        FULL_BINS = BINS_PER_THREAD * BLOCK_THREADS == RADIX_DIGITS,
        WARP_THREADS = CUB_PTX_WARP_THREADS,
        BLOCK_WARPS = BLOCK_THREADS / WARP_THREADS,
        WARP_MASK = ~0,
        NUM_MATCH_MASKS = MATCH_ALGORITHM == WARP_MATCH_ATOMIC_OR ? BLOCK_WARPS : 0,
        // Guard against declaring zero-sized array:
        MATCH_MASKS_ALLOC_SIZE = NUM_MATCH_MASKS < 1 ? 1 : NUM_MATCH_MASKS,
    };

    // types
    typedef cub::BlockScan<int, BLOCK_THREADS, INNER_SCAN_ALGORITHM> BlockScan;

    // temporary storage
    struct TempStorage
    {
        union
        {
            int warp_offsets[BLOCK_WARPS][RADIX_DIGITS];
            int warp_histograms[BLOCK_WARPS][RADIX_DIGITS][NUM_PARTS];
        };

        int match_masks[MATCH_MASKS_ALLOC_SIZE][RADIX_DIGITS];

        typename BlockScan::TempStorage prefix_tmp;
    };

    TempStorage& temp_storage;

    // internal ranking implementation
    template <typename UnsignedBits,
              int KEYS_PER_THREAD,
              typename DigitExtractorT,
              typename CountsCallback>
    struct BlockRadixRankMatchInternal
    {
        TempStorage& s;
        DigitExtractorT digit_extractor;
        CountsCallback callback;
        int warp;
        int lane;

        __device__ __forceinline__ int Digit(UnsignedBits key)
        {
            int digit = digit_extractor.Digit(key);
            return IS_DESCENDING ? RADIX_DIGITS - 1 - digit : digit;
        }

        __device__ __forceinline__ int ThreadBin(int u)
        {
            int bin = threadIdx.x * BINS_PER_THREAD + u;
            return IS_DESCENDING ? RADIX_DIGITS - 1 - bin : bin;
        }

        __device__ __forceinline__
        void ComputeHistogramsWarp(UnsignedBits (&keys)[KEYS_PER_THREAD])
        {
            //int* warp_offsets = &s.warp_offsets[warp][0];
            int (&warp_histograms)[RADIX_DIGITS][NUM_PARTS] = s.warp_histograms[warp];
            // compute warp-private histograms
            #pragma unroll
            for (int bin = lane; bin < RADIX_DIGITS; bin += WARP_THREADS)
            {
                #pragma unroll
                for (int part = 0; part < NUM_PARTS; ++part)
                {
                    warp_histograms[bin][part] = 0;
                }
            }
            if (MATCH_ALGORITHM == WARP_MATCH_ATOMIC_OR)
            {
                int* match_masks = &s.match_masks[warp][0];
                #pragma unroll
                for (int bin = lane; bin < RADIX_DIGITS; bin += WARP_THREADS)
                {
                    match_masks[bin] = 0;
                }
            }
            WARP_SYNC(WARP_MASK);

            // compute private per-part histograms
            int part = lane % NUM_PARTS;
            #pragma unroll
            for (int u = 0; u < KEYS_PER_THREAD; ++u)
            {
                atomicAdd(&warp_histograms[Digit(keys[u])][part], 1);
            }

            // sum different parts;
            // no extra work is necessary if NUM_PARTS == 1
            if (NUM_PARTS > 1)
            {
                WARP_SYNC(WARP_MASK);
                // TODO: handle RADIX_DIGITS % WARP_THREADS != 0 if it becomes necessary
                const int WARP_BINS_PER_THREAD = RADIX_DIGITS / WARP_THREADS;
                int bins[WARP_BINS_PER_THREAD];
                #pragma unroll
                for (int u = 0; u < WARP_BINS_PER_THREAD; ++u)
                {
                    int bin = lane + u * WARP_THREADS;
                    bins[u] = internal::ThreadReduce(warp_histograms[bin], Sum());
                }
                CTA_SYNC();

                // store the resulting histogram in shared memory
                int* warp_offsets = &s.warp_offsets[warp][0];
                #pragma unroll
                for (int u = 0; u < WARP_BINS_PER_THREAD; ++u)
                {
                    int bin = lane + u * WARP_THREADS;
                    warp_offsets[bin] = bins[u];
                }
            }
        }

        __device__ __forceinline__
        void ComputeOffsetsWarpUpsweep(int (&bins)[BINS_PER_THREAD])
        {
            // sum up warp-private histograms
            #pragma unroll
            for (int u = 0; u < BINS_PER_THREAD; ++u)
            {
                bins[u] = 0;
                int bin = ThreadBin(u);
                if (FULL_BINS || (bin >= 0 && bin < RADIX_DIGITS))
                {
                    #pragma unroll
                    for (int j_warp = 0; j_warp < BLOCK_WARPS; ++j_warp)
                    {
                        int warp_offset = s.warp_offsets[j_warp][bin];
                        s.warp_offsets[j_warp][bin] = bins[u];
                        bins[u] += warp_offset;
                    }
                }
            }
        }

        __device__ __forceinline__
        void ComputeOffsetsWarpDownsweep(int (&offsets)[BINS_PER_THREAD])
        {
            #pragma unroll
            for (int u = 0; u < BINS_PER_THREAD; ++u)
            {
                int bin = ThreadBin(u);
                if (FULL_BINS || (bin >= 0 && bin < RADIX_DIGITS))
                {
                    int digit_offset = offsets[u];
                    #pragma unroll
                    for (int j_warp = 0; j_warp < BLOCK_WARPS; ++j_warp)
                    {
                        s.warp_offsets[j_warp][bin] += digit_offset;
                    }
                }
            }
        }
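        // For illustration of the ranking step below: each lane first obtains
        // bin_mask, the mask of warp peers holding the same digit (via
        // atomicOr'd match masks, or MatchAny in the WARP_MATCH_ANY variant).
        // The highest peer lane acts as leader: its inclusive peer count
        // equals the total number of matching keys, so it atomically advances
        // warp_offsets[bin] by that count, and the old offset is broadcast to
        // the peers with SHFL_IDX_SYNC. Each peer then takes
        // rank = old offset + (inclusive peer count) - 1.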
        __device__ __forceinline__
        void ComputeRanksItem(
            UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD],
            Int2Type<WARP_MATCH_ATOMIC_OR>)
        {
            // compute key ranks
            int lane_mask = 1 << lane;
            int* warp_offsets = &s.warp_offsets[warp][0];
            int* match_masks = &s.match_masks[warp][0];
            #pragma unroll
            for (int u = 0; u < KEYS_PER_THREAD; ++u)
            {
                int bin = Digit(keys[u]);
                int* p_match_mask = &match_masks[bin];
                atomicOr(p_match_mask, lane_mask);
                WARP_SYNC(WARP_MASK);
                int bin_mask = *p_match_mask;
                int leader = (WARP_THREADS - 1) - __clz(bin_mask);
                int warp_offset = 0;
                int popc = __popc(bin_mask & LaneMaskLe());
                if (lane == leader)
                {
                    // atomic is a bit faster
                    warp_offset = atomicAdd(&warp_offsets[bin], popc);
                }
                warp_offset = SHFL_IDX_SYNC(warp_offset, leader, bin_mask);
                if (lane == leader) *p_match_mask = 0;
                WARP_SYNC(WARP_MASK);
                ranks[u] = warp_offset + popc - 1;
            }
        }

        __device__ __forceinline__
        void ComputeRanksItem(
            UnsignedBits (&keys)[KEYS_PER_THREAD], int (&ranks)[KEYS_PER_THREAD],
            Int2Type<WARP_MATCH_ANY>)
        {
            // compute key ranks
            int* warp_offsets = &s.warp_offsets[warp][0];
            #pragma unroll
            for (int u = 0; u < KEYS_PER_THREAD; ++u)
            {
                int bin = Digit(keys[u]);
                int bin_mask = MatchAny<RADIX_BITS>(bin);
                int leader = (WARP_THREADS - 1) - __clz(bin_mask);
                int warp_offset = 0;
                int popc = __popc(bin_mask & LaneMaskLe());
                if (lane == leader)
                {
                    // atomic is a bit faster
                    warp_offset = atomicAdd(&warp_offsets[bin], popc);
                }
                warp_offset = SHFL_IDX_SYNC(warp_offset, leader, bin_mask);
                ranks[u] = warp_offset + popc - 1;
            }
        }

        __device__ __forceinline__
        void RankKeys(
            UnsignedBits (&keys)[KEYS_PER_THREAD],
            int (&ranks)[KEYS_PER_THREAD],
            int (&exclusive_digit_prefix)[BINS_PER_THREAD])
        {
            ComputeHistogramsWarp(keys);

            CTA_SYNC();
            int bins[BINS_PER_THREAD];
            ComputeOffsetsWarpUpsweep(bins);
            callback(bins);

            BlockScan(s.prefix_tmp).ExclusiveSum(bins, exclusive_digit_prefix);

            ComputeOffsetsWarpDownsweep(exclusive_digit_prefix);
            CTA_SYNC();
            ComputeRanksItem(keys, ranks, Int2Type<MATCH_ALGORITHM>());
        }

        __device__ __forceinline__ BlockRadixRankMatchInternal(
            TempStorage& temp_storage, DigitExtractorT digit_extractor, CountsCallback callback)
            : s(temp_storage), digit_extractor(digit_extractor),
              callback(callback), warp(threadIdx.x / WARP_THREADS), lane(LaneId())
        {}
    };

    __device__ __forceinline__ BlockRadixRankMatchEarlyCounts(
        TempStorage& temp_storage) : temp_storage(temp_storage) {}

    /**
     * \brief Rank keys.  For the lower \p RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
     */
    template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT,
              typename CountsCallback>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits (&keys)[KEYS_PER_THREAD],
        int (&ranks)[KEYS_PER_THREAD],
        DigitExtractorT digit_extractor,
        int (&exclusive_digit_prefix)[BINS_PER_THREAD],
        CountsCallback callback)
    {
        BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, DigitExtractorT, CountsCallback>
            internal(temp_storage, digit_extractor, callback);
        internal.RankKeys(keys, ranks, exclusive_digit_prefix);
    }

    template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits (&keys)[KEYS_PER_THREAD],
        int (&ranks)[KEYS_PER_THREAD],
        DigitExtractorT digit_extractor,
        int (&exclusive_digit_prefix)[BINS_PER_THREAD])
    {
        typedef BlockRadixRankEmptyCallback<BINS_PER_THREAD> CountsCallback;
        BlockRadixRankMatchInternal<UnsignedBits, KEYS_PER_THREAD, DigitExtractorT, CountsCallback>
            internal(temp_storage, digit_extractor, CountsCallback());
        internal.RankKeys(keys, ranks, exclusive_digit_prefix);
    }

    template <typename UnsignedBits, int KEYS_PER_THREAD, typename DigitExtractorT>
    __device__ __forceinline__ void RankKeys(
        UnsignedBits (&keys)[KEYS_PER_THREAD],
        int (&ranks)[KEYS_PER_THREAD],
        DigitExtractorT digit_extractor)
    {
        int exclusive_digit_prefix[BINS_PER_THREAD];
        RankKeys(keys, ranks, digit_extractor, exclusive_digit_prefix);
    }
};

CUB_NAMESPACE_END