/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>

#include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
__forceinline__ __device__ uint32_t relu2(const uint32_t x);

template<>
__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
    // Pre-SM80 fallback: build a per-half mask of (x > 0) and AND it with x.
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n"
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
    asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
    return res;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<typename T>
__forceinline__ __device__ uint32_t convert_relu2(const float2 x);

template<>
__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    // cvt.rn.relu packs both converted values into one register; x.x lands in the low half.
    asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}

template<>
__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
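
// Usage sketch (illustrative only; the helper below is not part of the original file and is not
// used by the kernels): relu2 operates on two fp16 values packed into a single 32-bit register,
// so one call applies ReLU to both lanes at once.
__forceinline__ __device__ __half2 relu2_half2_example(__half2 v) {
    uint32_t packed = reinterpret_cast<uint32_t &>(v);  // view the fp16 pair as raw bits
    packed = relu2<cutlass::half_t>(packed);            // elementwise max(x, 0) on both halves
    return reinterpret_cast<__half2 &>(packed);         // reinterpret back to __half2
}

////////////////////////////////////////////////////////////////////////////////////////////////////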

template<typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ __forceinline__ T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
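
// Usage sketch (illustrative only; this helper is not part of the original file): MaxOp and
// Allreduce compose into a butterfly (shfl_xor) reduction, e.g. a warp-wide max in
// log2(32) = 5 shuffle steps, after which every lane of the warp holds the same value.
__forceinline__ __device__ float warp_allreduce_max_example(float x) {
    MaxOp<float> max_op;
    return Allreduce<32>::run(x, max_op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////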

template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
         typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
         typename ThrCopyA, typename ThrCopyB>
__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                                     Tensor4 const& tCsB, TiledMma tiled_mma,
                                     TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                                     ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));          // M
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            // Prefetch the next K-block from smem to registers while the current one is being consumed.
            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
         typename TiledMma, typename TiledCopy, typename ThrCopy>
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                        TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                        ThrCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    if constexpr (mma_shape_K == 8) {
        return acc_layout;
    } else {
        auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
template<typename Layout>
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
    using X = Underscore;
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
};

////////////////////////////////////////////////////////////////////////////////////////////////////
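
// Usage sketch (illustrative only; this helper is not part of the original file and assumes an
// fp32 accumulator): the row/column view produced by convert_layout_acc_rowcol is what
// softmax-style code iterates over — a per-row reduction simply loops over the second mode.
template <typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rowwise_max_example(Tensor0 const &acc, Tensor1 &row_max) {
    // View the (MMA=4, MMA_M, MMA_N) accumulator as (nrow=(2, MMA_M), ncol=(2, MMA_N)).
    Tensor scores = make_tensor(acc.data(), convert_layout_acc_rowcol(acc.layout()));
    MaxOp<float> max_op;
    #pragma unroll
    for (int mi = 0; mi < size<0>(scores); ++mi) {
        float row_m = scores(mi, 0);
        #pragma unroll
        for (int ni = 1; ni < size<1>(scores); ++ni) { row_m = max_op(row_m, scores(mi, ni)); }
        row_max(mi) = row_m;
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////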

template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Engine, typename Layout>
__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
    using value_t = typename Engine::value_type;
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_uint32 = recast<uint32_t>(tensor);
    #pragma unroll
    for (int i = 0; i < size(tensor_uint32); ++i) {
        tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
    static_assert(std::is_same_v<float, From_type>);
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_float2 = recast<float2>(tensor);
    Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
    #pragma unroll
    for (int i = 0; i < size(out_uint32); ++i) {
        out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
    }
    Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
#else
    Tensor out = flash::convert_type<To_type>(tensor);
    flash::relu_(out);
#endif
    return out;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////
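
// Usage sketch (illustrative only; this helper is not part of the original file): the usual
// pairing when draining all outstanding asynchronous copies before reading shared memory.
__forceinline__ __device__ void cp_async_drain_example() {
    cute::cp_async_fence();     // commit all previously issued cp.async copies as one group
    cp_async_wait<0>();         // block until every committed group has completed
    __syncthreads();            // make the smem writes visible to the whole thread block
}

////////////////////////////////////////////////////////////////////////////////////////////////////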

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                                     Tensor<Engine1, Layout1> &D,
                                     Tensor<Engine2, Layout2> const &identity_MN,
                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
    // TD [2023-04-13]: Strange that the code below can cause race condition.
    // I think it's because the copies are under an if statement.
    // if (Is_even_K) {
    //     #pragma unroll
    //     for (int m = 0; m < size<1>(S); ++m) {
    //         if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //             copy(tiled_copy, S(_, m, _), D(_, m, _));
    //         } else if (Clear_OOB_MN) {
    //             clear(D(_, m, _));
    //         }
    //     }
    // } else {  // It's slightly faster in this case if we iterate over K first
    //     #pragma unroll
    //     for (int k = 0; k < size<2>(S); ++k) {
    //         if (predicate_K(k)) {
    //             #pragma unroll
    //             for (int m = 0; m < size<1>(S); ++m) {
    //                 if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
    //                     copy(tiled_copy, S(_, m, k), D(_, m, k));
    //                 } else if (Clear_OOB_MN) {
    //                     clear(D(_, m, k));
    //                 }
    //             }
    //         } else if (Clear_OOB_K) {  // There's no case where !Clear_OOB_K && Clear_OOB_MN
    //             if (Clear_OOB_MN || Is_even_MN) {
    //                 clear(D(_, _, k));
    //             } else {
    //                 #pragma unroll
    //                 for (int m = 0; m < size<1>(S); ++m) {
    //                     if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
    //                         clear(D(_, m, k));
    //                     }
    //                 }
    //             }
    //         }
    //     }
    // }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
                                               Tensor<Engine1, Layout1> &D,
                                               Tensor<Engine2, Layout2> const &identity_MN,
                                               Tensor<Engine3, Layout3> const &predicate_K,
                                               const int max_MN=0, const int min_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(S(_, m, k), D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
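
////////////////////////////////////////////////////////////////////////////////////////////////////
// Usage sketch appended for illustration (not part of the original file): a thin wrapper showing
// how the predicated copy above is typically parameterized. identity_MN is assumed to come from
// cute::make_identity_tensor over the tile shape, so its first coordinate gives the row index
// that is compared against max_MN; predicate_K marks which K columns are in bounds.

namespace flash {

template <typename TiledCopy, typename TensorS, typename TensorD, typename TensorC, typename TensorP>
__forceinline__ __device__ void copy_boundary_example(TiledCopy tiled_copy, TensorS const &src,
                                                      TensorD &dst, TensorC const &identity_MN,
                                                      TensorP const &predicate_K, const int valid_rows) {
    // Is_even_MN = false: rows at or beyond valid_rows are skipped (and left untouched, since
    // Clear_OOB_MN defaults to false). Is_even_K = false: out-of-range K columns of in-range
    // rows are zeroed because Clear_OOB_K defaults to true.
    flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false>(tiled_copy, src, dst, identity_MN, predicate_K, valid_rows);
}

}  // namespace flash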