/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

namespace flash {

using namespace cute;

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                           const int col_idx_offset_ = 0) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                }
            }
        }
    }
}

template <bool HasWSLeft=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                 const int max_seqlen_k, const int row_idx_offset,
                                                 const int max_seqlen_q, const int warp_row_stride,
                                                 const int window_size_left, const int window_size_right) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
            // if (cute::thread0()) {
            //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
            //     print(tensor(make_coord(i, mi), _));
            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
            // }
        }
    }
}

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                                  const int max_seqlen_k, const int row_idx_offset,
                                                  const int max_seqlen_q, const int warp_row_stride) {
    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                          max_seqlen_q, warp_row_stride, -1, 0);
}
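// Worked example of the window arithmetic in apply_mask_local above
// (illustrative only). When max_seqlen_q == max_seqlen_k the
// (max_seqlen_k - max_seqlen_q) term vanishes and row `r` may attend to
// columns in [r - window_size_left, r + window_size_right], clamped to
// [0, max_seqlen_k):
//
//   row_idx = 5, window_size_left = 2, window_size_right = 0
//   col_idx_limit_left  = max(0, 5 - 2)            = 3
//   col_idx_limit_right = min(max_seqlen_k, 5 + 1) = 6
//   => columns 3, 4, 5 survive; every other column is set to -INFINITY.
//
// When max_seqlen_q != max_seqlen_k, the offset aligns the mask to the
// bottom-right corner of the attention matrix, so the last query row can see
// up to the last key.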
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) {
    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
                tensor(mi, ni) = -INFINITY;
            }
        }
        // if (cute::thread0()) {
        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
        //     print(tensor(_, make_coord(j, ni)));
        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
        // }
    }
}

template <bool Is_causal, bool Is_local, bool Has_alibi>
struct Mask {

    const int max_seqlen_k, max_seqlen_q;
    const int window_size_left, window_size_right;
    const float alibi_slope;

    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
                                    const int window_size_left, const int window_size_right,
                                    const float alibi_slope = 0.f)
        : max_seqlen_k(max_seqlen_k)
        , max_seqlen_q(max_seqlen_q)
        , window_size_left(window_size_left)
        , window_size_right(window_size_right)
        , alibi_slope(!Has_alibi ? 0.0f : alibi_slope) {
    };
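    // Note on the index arithmetic below (and in the free functions above),
    // assuming the 16x8 MMA accumulator layout this file is written against:
    // each thread of a warp holds a 2x2 fragment of every 16x8 accumulator
    // tile, with its two columns at (lane_id % 4) * 2 and (lane_id % 4) * 2 + 1
    // and its two rows 8 apart. That is why the column offset is
    // col_idx_offset_ + (lane_id % 4) * 2, successive column tiles step by 8
    // (nj * 8), the two rows per tile step by 8 (i * 8), and row tiles are
    // warp_row_stride apart. The exact mapping is a property of the MMA atom
    // chosen by the calling kernel, not something this header asserts.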
    // Causal_mask: whether this particular iteration needs causal masking
    template <bool Causal_mask = false, bool Is_even_MN = true, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
                                               const int col_idx_offset_,
                                               const int row_idx_offset,
                                               const int warp_row_stride) {
        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
        static_assert(Layout::rank == 3, "Only support 3D Tensor");
        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
        // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
        if constexpr (Need_masking) {
            // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
            Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
            // Do we need both row and column indices, or just column indices?
            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
            const int lane_id = threadIdx.x % 32;
            const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
            if constexpr (Col_idx_only) {
                #pragma unroll
                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                    const int col_idx_base = col_idx_offset + nj * 8;
                    #pragma unroll
                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
                        const int col_idx = col_idx_base + j;
                        #pragma unroll
                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
                            // No causal, no local
                            if constexpr (Has_alibi) {
                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
                            }
                            if constexpr (!Is_even_MN) {
                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
                            }
                        }
                    }
                }
            } else {
                #pragma unroll
                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                    #pragma unroll
                    for (int i = 0; i < size<0, 0>(tensor); ++i) {
                        const int row_idx = row_idx_base + i * 8;
                        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
                        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
                        #pragma unroll
                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                            const int col_idx_base = col_idx_offset + nj * 8;
                            #pragma unroll
                            for (int j = 0; j < size<1, 0>(tensor); ++j) {
                                const int col_idx = col_idx_base + j;
                                if constexpr (Has_alibi) {
                                    if constexpr (Is_causal) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
                                    } else {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) -=
                                            alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                                    }
                                }
                                if constexpr (Causal_mask) {
                                    if (col_idx >= col_idx_limit_right) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (Is_local) {
                                    if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                                if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
                                    // Causal and Local already handle MN masking
                                    if (col_idx >= max_seqlen_k) {
                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    };

};

} // namespace flash
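// Usage sketch (illustrative only, not part of this header): inside an
// attention kernel, after the S = Q @ K^T GEMM, the accumulator fragment is
// masked in place before the softmax. The names `acc_s`, `params`,
// `alibi_slope`, `m_block`, `n_block`, `kBlockM`, `kBlockN`, `kNWarps`, and
// `tidx` below are assumptions about the calling kernel, not defined here.
//
//   flash::Mask<Is_causal, Is_local, Has_alibi> mask(
//       params.seqlen_k, params.seqlen_q,
//       params.window_size_left, params.window_size_right, alibi_slope);
//   mask.template apply_mask</*Causal_mask=*/Is_causal, /*Is_even_MN=*/true>(
//       acc_s,                                                    // (MMA=4, MMA_M, MMA_N) accumulator
//       n_block * kBlockN,                                        // first key column covered by this tile
//       m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,   // first query row owned by this thread
//       kNWarps * 16);                                            // row stride between MMA_M tiles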