// NOTE(review): this copy of the file looks damaged and should be restored
// from a pristine copy before any code change is attempted.
//
//  * Line structure is collapsed: `//` comments, preprocessor directives
//    (#ifndef/#define/#if/#pragma) and code share the same physical lines.
//    As stored, everything following the first `//` on a physical line is
//    comment text to the compiler, and directives are not at line start.
//  * Several spans that begin with `<` appear to be missing, e.g.
//    "N_COUNTERS = 1<= 12", "(1<> grid_div", "1< 0) sum +=",
//    "blockIdx.x< inouts", and `vec2d_t` is used without any visible
//    template arguments. Presumably a text-extraction step removed them --
//    TODO confirm against the original.
//  * Consequently the following names are used below but have no visible
//    definition/declaration in this copy: N_SUMS, counters, pack(),
//    sum_up(), the header of zero_counters(), sort_row(), prefix_sums,
//    lane_off. Parts of upper_sort() are missing as well.
//
// What can still be read in full:
//
//  count_digits(src, base, len, lshift, rshift, mask)
//    Calls zero_counters(), then each thread visits src[base + i] for
//    i = threadIdx.x, threadIdx.x + SORT_BLOCKDIM, ... while i < len.
//    For every non-zero value it forms
//      pck = pack(base+i, 0xffffffffU << lshift, (val-1) << lshift)
//    and atomically increments counters[(pck >> rshift) & mask].
//    Ends with __syncthreads(). (The exact semantics of pack() are not
//    visible in this copy.)
//
//  scatter(dst, src, base, len, lshift, rshift, mask, pidx = nullptr)
//    Same traversal and same `pck` computation as count_digits(). For every
//    non-zero value it atomically decrements the matching counter and uses
//    (returned value - 1) as the index into dst[], storing
//      uint2{pck, pack(pid, 0x80000000, val)}
//    where pid = pidx ? pidx[base+i] : base+i. There is no barrier inside
//    scatter() itself.
// Copyright Supranational LLC // Licensed under the Apache License, Version 2.0, see LICENSE for details. // SPDX-License-Identifier: Apache-2.0 #ifndef __SPPARK_MSM_SORT_CUH__ #define __SPPARK_MSM_SORT_CUH__ /* * Custom sorting, we take in digits and return their indices. */ #define SORT_BLOCKDIM 1024 #ifndef DIGIT_BITS # define DIGIT_BITS 13 #endif #if DIGIT_BITS < 10 || DIGIT_BITS > 14 # error "impossible DIGIT_BITS" #endif __launch_bounds__(SORT_BLOCKDIM) __global__ void sort(vec2d_t inouts, size_t len, uint32_t win, vec2d_t temps, vec2d_t histograms, uint32_t wbits, uint32_t lsbits0, uint32_t lsbits1); #ifndef __MSM_SORT_DONT_IMPLEMENT__ #ifndef WARP_SZ # define WARP_SZ 32 #endif #ifdef __GNUC__ # define asm __asm__ __volatile__ #else # define asm asm volatile #endif static const uint32_t N_COUNTERS = 1<= 12 #pragma unroll for (uint32_t i = 0; i < N_SUMS/4; i++) ((uint4*)counters)[threadIdx.x + i*SORT_BLOCKDIM] = uint4{0, 0, 0, 0}; #else #pragma unroll for (uint32_t i = 0; i < N_SUMS; i++) counters[threadIdx.x + i*SORT_BLOCKDIM] = 0; #endif __syncthreads(); } __device__ __forceinline__ void count_digits(const uint32_t src[], uint32_t base, uint32_t len, uint32_t lshift, uint32_t rshift, uint32_t mask) { zero_counters(); const uint32_t pack_mask = 0xffffffffU << lshift; src += base; // count occurrences of each non-zero digit for (uint32_t i = threadIdx.x; i < len; i += SORT_BLOCKDIM) { auto val = src[(size_t)i]; auto pck = pack(base+i, pack_mask, (val-1) << lshift); if (val) (void)atomicAdd(&counters[(pck >> rshift) & mask], 1); } __syncthreads(); } __device__ __forceinline__ void scatter(uint2 dst[], const uint32_t src[], uint32_t base, uint32_t len, uint32_t lshift, uint32_t rshift, uint32_t mask, uint32_t pidx[] = nullptr) { const uint32_t pack_mask = 0xffffffffU << lshift; src += base; #pragma unroll 1 // the subroutine is memory-io-bound, unrolling makes no difference for (uint32_t i = threadIdx.x; i < len; i += SORT_BLOCKDIM) { auto val = 
src[(size_t)i]; auto pck = pack(base+i, pack_mask, (val-1) << lshift); if (val) { uint32_t idx = atomicSub(&counters[(pck >> rshift) & mask], 1) - 1; uint32_t pid = pidx ? pidx[base+i] : base+i; dst[idx] = uint2{pck, pack(pid, 0x80000000, val)}; } } } __device__ static void upper_sort(uint2 dst[], const uint32_t src[], uint32_t len, uint32_t lsbits, uint32_t bits, uint32_t digit, uint32_t histogram[]) { uint32_t grid_div = 31 - __clz(gridDim.x); uint32_t grid_rem = (1<> grid_div; // / gridDim.x; uint32_t rem = len & grid_rem; // % gridDim.x; uint32_t base; if (blockIdx.x < rem) base = ++slice * blockIdx.x; else base = slice * blockIdx.x + rem; const uint32_t mask = (1<> grid_div; // / gridDim.x; const uint32_t sub_laneid = laneid & grid_rem; // % gridDim.x; const uint32_t stride = WARP_SZ >> grid_div; // / gridDim.x; uint2 h = uint2{0, 0}; uint32_t sum, warp_off = warpid*WARP_SZ*N_SUMS + sub_warpid; #pragma unroll 1 for (uint32_t i = 0; i < WARP_SZ*N_SUMS; i += stride, warp_off += stride) { auto* hptr = &histogram[warp_off << digit]; sum = (warp_off < 1< 0) sum += __shfl_sync(0xffffffff, prefix_sums[i-1], WARP_SZ-1); prefix_sums[i] = sum; } // carry over most significant prefix sums from each warp if (laneid == WARP_SZ-1) counters[warpid*(WARP_SZ*N_SUMS+1)] = prefix_sums[N_SUMS-1]; __syncthreads(); uint32_t carry_sum = laneid ? 
counters[(laneid-1)*(WARP_SZ*N_SUMS+1)] : 0; __syncthreads(); carry_sum = sum_up(carry_sum, SORT_BLOCKDIM/WARP_SZ); carry_sum = __shfl_sync(0xffffffff, carry_sum, warpid); carry_sum += base; #pragma unroll for (uint32_t i = 0; i < N_SUMS; i++) counters[lane_off + i*WARP_SZ] = prefix_sums[i] += carry_sum; // store the prefix sums to histogram[] #pragma unroll for (uint32_t i = 0; i < N_SUMS; i++, lane_off += WARP_SZ) { if (lane_off < 1< DIGIT_BITS || (lg_gridDim && wbits > lg_gridDim+1)) { uint32_t top_bits = wbits / 2; uint32_t low_bits = wbits - top_bits; if (low_bits < lg_gridDim+1) { low_bits = lg_gridDim+1; top_bits = wbits - low_bits; } upper_sort(temp, inout, len, lsbits, top_bits, low_bits, histogram); histogram += blockIdx.x< inouts, size_t len, uint32_t win, vec2d_t temps, vec2d_t histograms, uint32_t wbits, uint32_t lsbits0, uint32_t lsbits1) { win += blockIdx.y; sort_row(inouts[win], len, temps[blockIdx.y], histograms[win], wbits, blockIdx.y==0 ? lsbits0 : lsbits1); } # undef asm #endif #endif