/****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ /** * \file * cub::AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide select. */ #pragma once #include #include "single_pass_scan_operators.cuh" #include "../block/block_load.cuh" #include "../block/block_store.cuh" #include "../block/block_scan.cuh" #include "../block/block_exchange.cuh" #include "../block/block_discontinuity.cuh" #include "../config.cuh" #include "../grid/grid_queue.cuh" #include "../iterator/cache_modified_input_iterator.cuh" CUB_NAMESPACE_BEGIN /****************************************************************************** * Tuning policy types ******************************************************************************/ /** * Parameterizable tuning policy type for AgentSelectIf */ template < int _BLOCK_THREADS, ///< Threads per thread block int _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) BlockLoadAlgorithm _LOAD_ALGORITHM, ///< The BlockLoad algorithm to use CacheLoadModifier _LOAD_MODIFIER, ///< Cache load modifier for reading input elements BlockScanAlgorithm _SCAN_ALGORITHM> ///< The BlockScan algorithm to use struct AgentSelectIfPolicy { enum { BLOCK_THREADS = _BLOCK_THREADS, ///< Threads per thread block ITEMS_PER_THREAD = _ITEMS_PER_THREAD, ///< Items per thread (per tile of input) }; static const BlockLoadAlgorithm LOAD_ALGORITHM = _LOAD_ALGORITHM; ///< The BlockLoad algorithm to use static const CacheLoadModifier LOAD_MODIFIER = _LOAD_MODIFIER; ///< Cache load modifier for reading input elements static const BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; ///< The BlockScan algorithm to use }; /****************************************************************************** * Thread block abstractions ******************************************************************************/ /** * \brief AgentSelectIf implements a stateful abstraction of CUDA thread blocks for participating in device-wide selection * * Performs functor-based selection if SelectOpT functor type != NullType * Otherwise performs flag-based selection if FlagsInputIterator's value type != NullType * Otherwise performs discontinuity selection (keep unique) */ template < typename AgentSelectIfPolicyT, ///< Parameterized AgentSelectIfPolicy tuning policy type typename InputIteratorT, ///< Random-access input iterator type for selection items typename FlagsInputIteratorT, ///< Random-access input iterator type for selections (NullType* if a selection functor or discontinuity flagging is to be used for selection) typename SelectedOutputIteratorT, ///< Random-access output iterator type for selection_flags items typename SelectOpT, ///< Selection operator type (NullType if selections or discontinuity flagging is to be used for selection) typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selections is to be used for selection) typename OffsetT, ///< Signed integer type for global offsets bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output struct AgentSelectIf { //--------------------------------------------------------------------- // Types and constants //--------------------------------------------------------------------- // The input value type using InputT = cub::detail::value_t; // The output value type using OutputT = cub::detail::non_void_value_t; // The flag value type using FlagT = cub::detail::value_t; // Tile status descriptor interface type using ScanTileStateT = ScanTileState; // Constants enum { USE_SELECT_OP, USE_SELECT_FLAGS, USE_DISCONTINUITY, BLOCK_THREADS = AgentSelectIfPolicyT::BLOCK_THREADS, ITEMS_PER_THREAD = AgentSelectIfPolicyT::ITEMS_PER_THREAD, TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, TWO_PHASE_SCATTER = (ITEMS_PER_THREAD > 1), SELECT_METHOD = (!std::is_same::value) ? USE_SELECT_OP : (!std::is_same::value) ? USE_SELECT_FLAGS : USE_DISCONTINUITY }; // Cache-modified Input iterator wrapper type (for applying cache modifier) for items // Wrap the native input pointer with CacheModifiedValuesInputIterator // or directly use the supplied input iterator type using WrappedInputIteratorT = cub::detail::conditional_t< std::is_pointer::value, CacheModifiedInputIterator, InputIteratorT>; // Cache-modified Input iterator wrapper type (for applying cache modifier) for values // Wrap the native input pointer with CacheModifiedValuesInputIterator // or directly use the supplied input iterator type using WrappedFlagsInputIteratorT = cub::detail::conditional_t< std::is_pointer::value, CacheModifiedInputIterator, FlagsInputIteratorT>; // Parameterized BlockLoad type for input data using BlockLoadT = BlockLoad; // Parameterized BlockLoad type for flags using BlockLoadFlags = BlockLoad; // Parameterized BlockDiscontinuity type for items using BlockDiscontinuityT = BlockDiscontinuity; // Parameterized BlockScan type using BlockScanT = BlockScan; // Callback type for obtaining tile prefix during block scan using TilePrefixCallbackOpT = TilePrefixCallbackOp; // Item exchange type typedef OutputT ItemExchangeT[TILE_ITEMS]; // Shared memory type for this thread block union _TempStorage { struct ScanStorage { typename BlockScanT::TempStorage scan; // Smem needed for tile scanning typename TilePrefixCallbackOpT::TempStorage prefix; // Smem needed for cooperative prefix callback typename BlockDiscontinuityT::TempStorage discontinuity; // Smem needed for discontinuity detection } scan_storage; // Smem needed for loading items typename BlockLoadT::TempStorage load_items; // Smem needed for loading values typename BlockLoadFlags::TempStorage load_flags; // Smem needed for compacting items (allows non POD items in this union) Uninitialized raw_exchange; }; // Alias wrapper allowing storage to be unioned struct TempStorage : Uninitialized<_TempStorage> {}; //--------------------------------------------------------------------- // Per-thread fields //--------------------------------------------------------------------- _TempStorage& temp_storage; ///< Reference to temp_storage WrappedInputIteratorT d_in; ///< Input items SelectedOutputIteratorT d_selected_out; ///< Unique output items WrappedFlagsInputIteratorT d_flags_in; ///< Input selection flags (if applicable) InequalityWrapper inequality_op; ///< T inequality operator SelectOpT select_op; ///< Selection operator OffsetT num_items; ///< Total number of input items //--------------------------------------------------------------------- // Constructor //--------------------------------------------------------------------- // Constructor __device__ __forceinline__ AgentSelectIf( TempStorage &temp_storage, ///< Reference to temp_storage InputIteratorT d_in, ///< Input data FlagsInputIteratorT d_flags_in, ///< Input selection flags (if applicable) SelectedOutputIteratorT d_selected_out, ///< Output data SelectOpT select_op, ///< Selection operator EqualityOpT equality_op, ///< Equality operator OffsetT num_items) ///< Total number of input items : temp_storage(temp_storage.Alias()), d_in(d_in), d_flags_in(d_flags_in), d_selected_out(d_selected_out), select_op(select_op), inequality_op(equality_op), num_items(num_items) {} //--------------------------------------------------------------------- // Utility methods for initializing the selections //--------------------------------------------------------------------- /** * Initialize selections (specialized for selection operator) */ template __device__ __forceinline__ void InitializeSelections( OffsetT /*tile_offset*/, OffsetT num_tile_items, OutputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { // Out-of-bounds items are selection_flags selection_flags[ITEM] = 1; if (!IS_LAST_TILE || (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM < num_tile_items)) selection_flags[ITEM] = select_op(items[ITEM]); } } /** * Initialize selections (specialized for valid flags) */ template __device__ __forceinline__ void InitializeSelections( OffsetT tile_offset, OffsetT num_tile_items, OutputT (&/*items*/)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { CTA_SYNC(); FlagT flags[ITEMS_PER_THREAD]; if (IS_LAST_TILE) { // Out-of-bounds items are selection_flags BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags, num_tile_items, 1); } else { BlockLoadFlags(temp_storage.load_flags).Load(d_flags_in + tile_offset, flags); } // Convert flag type to selection_flags type #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { selection_flags[ITEM] = flags[ITEM]; } } /** * Initialize selections (specialized for discontinuity detection) */ template __device__ __forceinline__ void InitializeSelections( OffsetT tile_offset, OffsetT num_tile_items, OutputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], Int2Type /*select_method*/) { if (IS_FIRST_TILE) { CTA_SYNC(); // Set head selection_flags. First tile sets the first flag for the first item BlockDiscontinuityT(temp_storage.scan_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op); } else { OutputT tile_predecessor; if (threadIdx.x == 0) tile_predecessor = d_in[tile_offset - 1]; CTA_SYNC(); BlockDiscontinuityT(temp_storage.scan_storage.discontinuity).FlagHeads(selection_flags, items, inequality_op, tile_predecessor); } // Set selection flags for out-of-bounds items #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { // Set selection_flags for out-of-bounds items if ((IS_LAST_TILE) && (OffsetT(threadIdx.x * ITEMS_PER_THREAD) + ITEM >= num_tile_items)) selection_flags[ITEM] = 1; } } //--------------------------------------------------------------------- // Scatter utility methods //--------------------------------------------------------------------- /** * Scatter flagged items to output offsets (specialized for direct scattering) */ template __device__ __forceinline__ void ScatterDirect( OutputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], OffsetT num_selections) { // Scatter flagged items #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { if (selection_flags[ITEM]) { if ((!IS_LAST_TILE) || selection_indices[ITEM] < num_selections) { d_selected_out[selection_indices[ITEM]] = items[ITEM]; } } } } /** * Scatter flagged items to output offsets (specialized for two-phase scattering) */ template __device__ __forceinline__ void ScatterTwoPhase( OutputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int /*num_tile_items*/, ///< Number of valid items in this tile int num_tile_selections, ///< Number of selections in this tile OffsetT num_selections_prefix, ///< Total number of selections prior to this tile OffsetT /*num_rejected_prefix*/, ///< Total number of rejections prior to this tile Int2Type /*is_keep_rejects*/) ///< Marker type indicating whether to keep rejected items in the second partition { CTA_SYNC(); // Compact and scatter items #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { int local_scatter_offset = selection_indices[ITEM] - num_selections_prefix; if (selection_flags[ITEM]) { temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; } } CTA_SYNC(); for (int item = threadIdx.x; item < num_tile_selections; item += BLOCK_THREADS) { d_selected_out[num_selections_prefix + item] = temp_storage.raw_exchange.Alias()[item]; } } /** * Scatter flagged items to output offsets (specialized for two-phase scattering) */ template __device__ __forceinline__ void ScatterTwoPhase( OutputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int num_tile_items, ///< Number of valid items in this tile int num_tile_selections, ///< Number of selections in this tile OffsetT num_selections_prefix, ///< Total number of selections prior to this tile OffsetT num_rejected_prefix, ///< Total number of rejections prior to this tile Int2Type /*is_keep_rejects*/) ///< Marker type indicating whether to keep rejected items in the second partition { CTA_SYNC(); int tile_num_rejections = num_tile_items - num_tile_selections; // Scatter items to shared memory (rejections first) #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { int item_idx = (threadIdx.x * ITEMS_PER_THREAD) + ITEM; int local_selection_idx = selection_indices[ITEM] - num_selections_prefix; int local_rejection_idx = item_idx - local_selection_idx; int local_scatter_offset = (selection_flags[ITEM]) ? tile_num_rejections + local_selection_idx : local_rejection_idx; temp_storage.raw_exchange.Alias()[local_scatter_offset] = items[ITEM]; } CTA_SYNC(); // Gather items from shared memory and scatter to global #pragma unroll for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM) { int item_idx = (ITEM * BLOCK_THREADS) + threadIdx.x; int rejection_idx = item_idx; int selection_idx = item_idx - tile_num_rejections; OffsetT scatter_offset = (item_idx < tile_num_rejections) ? num_items - num_rejected_prefix - rejection_idx - 1 : num_selections_prefix + selection_idx; OutputT item = temp_storage.raw_exchange.Alias()[item_idx]; if (!IS_LAST_TILE || (item_idx < num_tile_items)) { d_selected_out[scatter_offset] = item; } } } /** * Scatter flagged items */ template __device__ __forceinline__ void Scatter( OutputT (&items)[ITEMS_PER_THREAD], OffsetT (&selection_flags)[ITEMS_PER_THREAD], OffsetT (&selection_indices)[ITEMS_PER_THREAD], int num_tile_items, ///< Number of valid items in this tile int num_tile_selections, ///< Number of selections in this tile OffsetT num_selections_prefix, ///< Total number of selections prior to this tile OffsetT num_rejected_prefix, ///< Total number of rejections prior to this tile OffsetT num_selections) ///< Total number of selections including this tile { // Do a two-phase scatter if (a) keeping both partitions or (b) two-phase is enabled and the average number of selection_flags items per thread is greater than one if (KEEP_REJECTS || (TWO_PHASE_SCATTER && (num_tile_selections > BLOCK_THREADS))) { ScatterTwoPhase( items, selection_flags, selection_indices, num_tile_items, num_tile_selections, num_selections_prefix, num_rejected_prefix, Int2Type()); } else { ScatterDirect( items, selection_flags, selection_indices, num_selections); } } //--------------------------------------------------------------------- // Cooperatively scan a device-wide sequence of tiles with other CTAs //--------------------------------------------------------------------- /** * Process first tile of input (dynamic chained scan). Returns the running count of selections (including this tile) */ template __device__ __forceinline__ OffsetT ConsumeFirstTile( int num_tile_items, ///< Number of input items comprising this tile OffsetT tile_offset, ///< Tile offset ScanTileStateT& tile_state) ///< Global tile state descriptor { OutputT items[ITEMS_PER_THREAD]; OffsetT selection_flags[ITEMS_PER_THREAD]; OffsetT selection_indices[ITEMS_PER_THREAD]; // Load items if (IS_LAST_TILE) BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); else BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); // Initialize selection_flags InitializeSelections( tile_offset, num_tile_items, items, selection_flags, Int2Type()); CTA_SYNC(); // Exclusive scan of selection_flags OffsetT num_tile_selections; BlockScanT(temp_storage.scan_storage.scan).ExclusiveSum(selection_flags, selection_indices, num_tile_selections); if (threadIdx.x == 0) { // Update tile status if this is not the last tile if (!IS_LAST_TILE) tile_state.SetInclusive(0, num_tile_selections); } // Discount any out-of-bounds selections if (IS_LAST_TILE) num_tile_selections -= (TILE_ITEMS - num_tile_items); // Scatter flagged items Scatter( items, selection_flags, selection_indices, num_tile_items, num_tile_selections, 0, 0, num_tile_selections); return num_tile_selections; } /** * Process subsequent tile of input (dynamic chained scan). Returns the running count of selections (including this tile) */ template __device__ __forceinline__ OffsetT ConsumeSubsequentTile( int num_tile_items, ///< Number of input items comprising this tile int tile_idx, ///< Tile index OffsetT tile_offset, ///< Tile offset ScanTileStateT& tile_state) ///< Global tile state descriptor { OutputT items[ITEMS_PER_THREAD]; OffsetT selection_flags[ITEMS_PER_THREAD]; OffsetT selection_indices[ITEMS_PER_THREAD]; // Load items if (IS_LAST_TILE) BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items, num_tile_items); else BlockLoadT(temp_storage.load_items).Load(d_in + tile_offset, items); // Initialize selection_flags InitializeSelections( tile_offset, num_tile_items, items, selection_flags, Int2Type()); CTA_SYNC(); // Exclusive scan of values and selection_flags TilePrefixCallbackOpT prefix_op(tile_state, temp_storage.scan_storage.prefix, cub::Sum(), tile_idx); BlockScanT(temp_storage.scan_storage.scan).ExclusiveSum(selection_flags, selection_indices, prefix_op); OffsetT num_tile_selections = prefix_op.GetBlockAggregate(); OffsetT num_selections = prefix_op.GetInclusivePrefix(); OffsetT num_selections_prefix = prefix_op.GetExclusivePrefix(); OffsetT num_rejected_prefix = (tile_idx * TILE_ITEMS) - num_selections_prefix; // Discount any out-of-bounds selections if (IS_LAST_TILE) { int num_discount = TILE_ITEMS - num_tile_items; num_selections -= num_discount; num_tile_selections -= num_discount; } // Scatter flagged items Scatter( items, selection_flags, selection_indices, num_tile_items, num_tile_selections, num_selections_prefix, num_rejected_prefix, num_selections); return num_selections; } /** * Process a tile of input */ template __device__ __forceinline__ OffsetT ConsumeTile( int num_tile_items, ///< Number of input items comprising this tile int tile_idx, ///< Tile index OffsetT tile_offset, ///< Tile offset ScanTileStateT& tile_state) ///< Global tile state descriptor { OffsetT num_selections; if (tile_idx == 0) { num_selections = ConsumeFirstTile(num_tile_items, tile_offset, tile_state); } else { num_selections = ConsumeSubsequentTile(num_tile_items, tile_idx, tile_offset, tile_state); } return num_selections; } /** * Scan tiles of items as part of a dynamic chained scan */ template ///< Output iterator type for recording number of items selection_flags __device__ __forceinline__ void ConsumeRange( int num_tiles, ///< Total number of input tiles ScanTileStateT& tile_state, ///< Global tile state descriptor NumSelectedIteratorT d_num_selected_out) ///< Output total number selection_flags { // Blocks are launched in increasing order, so just assign one tile per block int tile_idx = (blockIdx.x * gridDim.y) + blockIdx.y; // Current tile index OffsetT tile_offset = tile_idx * TILE_ITEMS; // Global offset for the current tile if (tile_idx < num_tiles - 1) { // Not the last tile (full) ConsumeTile(TILE_ITEMS, tile_idx, tile_offset, tile_state); } else { // The last tile (possibly partially-full) OffsetT num_remaining = num_items - tile_offset; OffsetT num_selections = ConsumeTile(num_remaining, tile_idx, tile_offset, tile_state); if (threadIdx.x == 0) { // Output the total number of items selection_flags *d_num_selected_out = num_selections; } } } }; CUB_NAMESPACE_END