/****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ /** * \file * cub::WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp. 
*/

#pragma once

#include "../../config.cuh"
#include "../../thread/thread_operators.cuh"
#include "../../thread/thread_load.cuh"
#include "../../thread/thread_store.cuh"
#include "../../util_type.cuh"

CUB_NAMESPACE_BEGIN

/**
 * \brief WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp.
 *
 * \par
 * Each logical warp uses a buffer of WARP_SMEM_ELEMENTS cells laid out as
 *   [ HALF_WARP_THREADS padding cells | LOGICAL_WARP_THREADS lane cells ].
 * Lane i publishes values at index HALF_WARP_THREADS + i. A lane may read the
 * cell up to HALF_WARP_THREADS positions to its left without leaving the buffer.
 * For the leftmost lanes, that read lands in the padding.
 *
 * \par
 * Broadcast() and the Update() helpers do not issue a warp barrier after their
 * final shared-memory read. Callers that reuse or repurpose temp_storage
 * afterwards must synchronize the warp themselves.
 */
template <
    typename    T,                      ///< Data type being scanned
    int         LOGICAL_WARP_THREADS,   ///< Number of threads per logical warp
    int         PTX_ARCH>               ///< The PTX compute capability for which to specialize this collective
struct WarpScanSmem
{
    /******************************************************************************
     * Constants and type definitions
     ******************************************************************************/

    enum
    {
        /// Whether the logical warp size and the PTX warp size coincide
        IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        /// The number of warp scan steps (Log2 of the logical warp size)
        STEPS = Log2<LOGICAL_WARP_THREADS>::VALUE,

        /// The number of threads in half a warp (also the number of padding cells)
        HALF_WARP_THREADS = 1 << (STEPS - 1),

        /// The number of shared memory elements per warp (padding + one cell per lane)
        WARP_SMEM_ELEMENTS = LOGICAL_WARP_THREADS + HALF_WARP_THREADS,
    };

    /// Storage cell type (historically a workaround for SM1x compiler bugs with
    /// custom-ops like Max() on signed chars; now simply T)
    using CellT = T;

    /// Shared memory storage layout type (1.5 warps-worth of elements for each warp)
    typedef CellT _TempStorage[WARP_SMEM_ELEMENTS];

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    _TempStorage    &temp_storage;  ///< This logical warp's scan buffer
    unsigned int    lane_id;        ///< Lane index within the logical warp
    unsigned int    member_mask;    ///< Mask of the lanes participating in WARP_SYNC


    /******************************************************************************
     * Construction
     ******************************************************************************/

    /// Constructor
    explicit __device__ __forceinline__ WarpScanSmem(
        TempStorage     &temp_storage)
    :
        temp_storage(temp_storage.Alias()),

        lane_id(IS_ARCH_WARP ?
            LaneId() :
            LaneId() % LOGICAL_WARP_THREADS),

        member_mask(
            WarpMask<LOGICAL_WARP_THREADS, PTX_ARCH>(
                LaneId() / LOGICAL_WARP_THREADS))
    {}


    /******************************************************************************
     * Utility methods
     ******************************************************************************/

    /**
     * \brief Basic inclusive scan iteration (template unrolled, inductive-case specialization).
     *
     * Each lane publishes its partial, then combines it with the partial of the
     * lane OFFSET = 2^STEP positions to its left. When HAS_IDENTITY is true, the
     * padding cells are assumed to hold an identity value, so every lane may read
     * unconditionally. Otherwise, lanes with lane_id < OFFSET keep their partial.
     */
    template <
        bool        HAS_IDENTITY,
        int         STEP,
        typename    ScanOp>
    __device__ __forceinline__ void ScanStep(
        T                       &partial,
        ScanOp                  scan_op,
        Int2Type<STEP>          /*step*/)
    {
        const int OFFSET = 1 << STEP;

        // Share partial into buffer
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) partial);

        WARP_SYNC(member_mask);

        // Update partial if addend is in range
        if (HAS_IDENTITY || (lane_id >= OFFSET))
        {
            T addend = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - OFFSET]);
            partial = scan_op(addend, partial);
        }

        // Ensure every lane has finished reading before the next step overwrites the buffer
        WARP_SYNC(member_mask);

        ScanStep<HAS_IDENTITY>(partial, scan_op, Int2Type<STEP + 1>());
    }


    /// Basic inclusive scan iteration (template unrolled, base-case specialization: STEP == STEPS)
    template <
        bool        HAS_IDENTITY,
        typename    ScanOp>
    __device__ __forceinline__ void ScanStep(
        T                       &/*partial*/,
        ScanOp                  /*scan_op*/,
        Int2Type<STEPS>         /*step*/)
    {}


    /**
     * \brief Inclusive prefix scan (specialized for summation across primitive types).
     *
     * Every lane writes 0 to temp_storage[lane_id], which covers the padding cells.
     * The scan steps can then read the padding unconditionally (HAS_IDENTITY = true).
     */
    __device__ __forceinline__ void InclusiveScan(
        T                       input,          ///< [in] Calling thread's input item.
        T                       &output,        ///< [out] Calling thread's output item. May be aliased with \p input.
        Sum                     scan_op,        ///< [in] Binary scan operator
        Int2Type<true>          /*is_primitive*/)   ///< [in] Marker type indicating whether T is primitive type
    {
        T identity = 0;
        ThreadStore<STORE_VOLATILE>(&temp_storage[lane_id], (CellT) identity);

        WARP_SYNC(member_mask);

        // Iterate scan steps
        output = input;
        ScanStep<true>(output, scan_op, Int2Type<0>());
    }


    /// Inclusive prefix scan (generic operator: no identity assumed, out-of-range addends are skipped)
    template <typename ScanOp, int IS_PRIMITIVE>
    __device__ __forceinline__ void InclusiveScan(
        T                       input,          ///< [in] Calling thread's input item.
        T                       &output,        ///< [out] Calling thread's output item. May be aliased with \p input.
        ScanOp                  scan_op,        ///< [in] Binary scan operator
        Int2Type<IS_PRIMITIVE>  /*is_primitive*/)   ///< [in] Marker type indicating whether T is primitive type
    {
        // Iterate scan steps
        output = input;
        ScanStep<false>(output, scan_op, Int2Type<0>());
    }


    /******************************************************************************
     * Interface
     ******************************************************************************/

    //---------------------------------------------------------------------
    // Broadcast
    //---------------------------------------------------------------------

    /**
     * \brief Broadcast the value held by lane src_lane to every lane of the logical warp.
     *
     * Lane src_lane writes cell 0 of the buffer; all lanes read it after a warp
     * barrier. There is no barrier after the read.
     */
    __device__ __forceinline__ T Broadcast(
        T               input,              ///< [in] The value to broadcast
        unsigned int    src_lane)           ///< [in] Which warp lane is to do the broadcasting
    {
        if (lane_id == src_lane)
        {
            ThreadStore<STORE_VOLATILE>(temp_storage, (CellT) input);
        }

        WARP_SYNC(member_mask);

        return (T) ThreadLoad<LOAD_VOLATILE>(temp_storage);
    }


    //---------------------------------------------------------------------
    // Inclusive operations
    //---------------------------------------------------------------------

    /// Inclusive scan (dispatches on whether T is a primitive type)
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's output item. May be aliased with \p input.
        ScanOp          scan_op)            ///< [in] Binary scan operator
    {
        InclusiveScan(input, inclusive_output, scan_op, Int2Type<Traits<T>::PRIMITIVE>());
    }


    /// Inclusive scan with aggregate (the aggregate is the last lane's inclusive value)
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's output item. May be aliased with \p input.
        ScanOp          scan_op,            ///< [in] Binary scan operator
        T               &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
    {
        InclusiveScan(input, inclusive_output, scan_op);

        // Retrieve aggregate from the last lane's cell
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive_output);

        WARP_SYNC(member_mask);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);

        WARP_SYNC(member_mask);
    }


    //---------------------------------------------------------------------
    // Get exclusive from inclusive
    //---------------------------------------------------------------------

    /**
     * \brief Update inclusive and exclusive using input and inclusive.
     *
     * The initial value is unknown. Lane 0 reads padding cell HALF_WARP_THREADS - 1,
     * whose value is only meaningful if the padding was previously initialized.
     * The Sum/primitive InclusiveScan overload stores 0 there.
     */
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update(
        T                       /*input*/,      ///< [in]
        T                       &inclusive,     ///< [in, out]
        T                       &exclusive,     ///< [out]
        ScanOpT                 /*scan_op*/,    ///< [in]
        IsIntegerT              /*is_integer*/) ///< [in]
    {
        // initial value unknown
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
    }

    /// Update inclusive and exclusive using input and inclusive (specialized for summation of integer types)
    __device__ __forceinline__ void Update(
        T                       input,
        T                       &inclusive,
        T                       &exclusive,
        cub::Sum                /*scan_op*/,
        Int2Type<true>          /*is_integer*/)
    {
        // initial value presumed 0
        exclusive = inclusive - input;
    }

    /// Update inclusive and exclusive using initial value using input, inclusive, and initial value
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update (
        T                       /*input*/,
        T                       &inclusive,
        T                       &exclusive,
        ScanOpT                 scan_op,
        T                       initial_value,
        IsIntegerT              /*is_integer*/)
    {
        inclusive = scan_op(initial_value, inclusive);
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        // Lane 0 reads padding here (in bounds), but its result is replaced below
        exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
        if (lane_id == 0)
            exclusive = initial_value;
    }

    /// Update inclusive and exclusive using initial value using input and inclusive (specialized for summation of integer types)
    __device__ __forceinline__ void Update (
        T                       input,
        T                       &inclusive,
        T                       &exclusive,
        cub::Sum                scan_op,
        T                       initial_value,
        Int2Type<true>          /*is_integer*/)
    {
        inclusive = scan_op(initial_value, inclusive);
        exclusive = inclusive - input;
    }


    /// Update inclusive, exclusive, and warp aggregate using input and inclusive
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update (
        T                       /*input*/,
        T                       &inclusive,
        T                       &exclusive,
        T                       &warp_aggregate,
        ScanOpT                 /*scan_op*/,
        IsIntegerT              /*is_integer*/)
    {
        // Initial value presumed to be unknown or identity (either way our padding is correct)
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);
    }

    /// Update inclusive, exclusive, and warp aggregate using input and inclusive (specialized for summation of integer types)
    __device__ __forceinline__ void Update (
        T                       input,
        T                       &inclusive,
        T                       &exclusive,
        T                       &warp_aggregate,
        cub::Sum                /*scan_op*/,
        Int2Type<true>          /*is_integer*/)
    {
        // Initial value presumed to be unknown or identity (either way our padding is correct)
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);
        exclusive = inclusive - input;
    }

    /**
     * \brief Update inclusive, exclusive, and warp aggregate using input, inclusive, and initial value.
     *
     * warp_aggregate is taken before initial_value is folded in. Lane 0's exclusive
     * output is initial_value.
     */
    template <typename ScanOpT, typename IsIntegerT>
    __device__ __forceinline__ void Update (
        T                       /*input*/,
        T                       &inclusive,
        T                       &exclusive,
        T                       &warp_aggregate,
        ScanOpT                 scan_op,
        T                       initial_value,
        IsIntegerT              /*is_integer*/)
    {
        // Broadcast warp aggregate
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive);

        WARP_SYNC(member_mask);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);

        // Ensure the aggregate has been read before the buffer is overwritten below
        WARP_SYNC(member_mask);

        // Update inclusive with initial value
        inclusive = scan_op(initial_value, inclusive);

        // Get exclusive from the left neighbour's inclusive.
        // Lane i stores its inclusive one cell lower (HALF_WARP_THREADS + i - 1),
        // so lane i finds lane (i - 1)'s value at HALF_WARP_THREADS + i - 2.
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1], (CellT) inclusive);

        WARP_SYNC(member_mask);

        // Lane 0 has no left neighbour, so its exclusive is the initial value.
        // Skip its load entirely: lane_id is unsigned, so HALF_WARP_THREADS + 0 - 2
        // wraps to UINT_MAX when HALF_WARP_THREADS == 1 (LOGICAL_WARP_THREADS == 2),
        // which would be an out-of-bounds shared-memory read.
        if (lane_id == 0)
        {
            exclusive = initial_value;
        }
        else
        {
            exclusive = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 2]);
        }
    }
};


CUB_NAMESPACE_END