################################################################################################# # # Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # ################################################################################################# """ Data types and tags used for emitting CUTLASS C++ kernels """ import enum import re # The following block implements enum.auto() for Python 3.5 variants that don't include it such # as the default 3.5.2 on Ubuntu 16.04. # # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility try: from enum import auto as enum_auto except ImportError: __cutlass_library_auto_enum = 0 def enum_auto() -> int: global __cutlass_library_auto_enum i = __cutlass_library_auto_enum __cutlass_library_auto_enum += 1 return i ################################################################################################### # class GeneratorTarget(enum.Enum): Library = enum_auto() # GeneratorTargetNames = { GeneratorTarget.Library: 'library' } # ################################################################################################### # class DataType(enum.Enum): void = enum_auto() # primarily used to disable C tensor for epilogues b1 = enum_auto() u4 = enum_auto() u8 = enum_auto() u16 = enum_auto() u32 = enum_auto() u64 = enum_auto() s4 = enum_auto() s8 = enum_auto() s16 = enum_auto() s32 = enum_auto() s64 = enum_auto() e4m3 = enum_auto() e5m2 = enum_auto() f16 = enum_auto() bf16 = enum_auto() f32 = enum_auto() tf32 = enum_auto() f64 = enum_auto() cf16 = enum_auto() cbf16 = enum_auto() cf32 = enum_auto() ctf32 = enum_auto() cf64 = enum_auto() cs4 = enum_auto() cs8 = enum_auto() cs16 = enum_auto() cs32 = enum_auto() cs64 = enum_auto() cu4 = enum_auto() cu8 = enum_auto() cu16 = enum_auto() cu32 = enum_auto() cu64 = enum_auto() invalid = enum_auto() # ShortDataTypeNames = { DataType.s32: 'i', DataType.e4m3: 'e4m3', DataType.e5m2: 'e5m2', DataType.f16: 'h', DataType.f32: 's', DataType.f64: 'd', DataType.cf32: 'c', DataType.cf64: 'z', } # DataTypeNames = { DataType.void: "void", DataType.b1: "b1", DataType.u4: "u4", DataType.u8: "u8", DataType.u16: "u16", DataType.u32: "u32", DataType.u64: "u64", DataType.s4: "s4", DataType.s8: "s8", DataType.s16: "s16", DataType.s32: "s32", DataType.s64: "s64", DataType.e4m3: 'e4m3', DataType.e5m2: 'e5m2', DataType.f16: "f16", DataType.bf16: "bf16", DataType.f32: "f32", DataType.tf32: "tf32", DataType.f64: "f64", DataType.cf16: "cf16", DataType.cbf16: "cbf16", DataType.cf32: "cf32", DataType.ctf32: "ctf32", DataType.cf64: "cf64", DataType.cu4: "cu4", DataType.cu8: "cu8", DataType.cu16: "cu16", DataType.cu32: "cu32", DataType.cu64: "cu64", DataType.cs4: "cs4", DataType.cs8: "cs8", DataType.cs16: "cs16", DataType.cs32: "cs32", DataType.cs64: "cs64", } DataTypeTag = { DataType.void: "void", DataType.b1: "cutlass::uint1b_t", DataType.u4: "cutlass::uint4b_t", DataType.u8: "uint8_t", DataType.u16: "uint16_t", DataType.u32: "uint32_t", DataType.u64: "uint64_t", DataType.s4: "cutlass::int4b_t", DataType.s8: "int8_t", DataType.s16: "int16_t", DataType.s32: "int32_t", DataType.s64: "int64_t", DataType.e4m3: 'cutlass::float_e4m3_t', DataType.e5m2: 'cutlass::float_e5m2_t', DataType.f16: "cutlass::half_t", DataType.bf16: "cutlass::bfloat16_t", DataType.f32: "float", DataType.tf32: "cutlass::tfloat32_t", DataType.f64: "double", DataType.cf16: "cutlass::complex", DataType.cbf16: "cutlass::complex", DataType.cf32: "cutlass::complex", DataType.ctf32: "cutlass::complex", DataType.cf64: "cutlass::complex", DataType.cu4: "cutlass::complex", DataType.cu8: "cutlass::complex", DataType.cu16: "cutlass::complex", DataType.cu32: "cutlass::complex", DataType.cu64: "cutlass::complex", DataType.cs4: "cutlass::complex", DataType.cs8: "cutlass::complex", DataType.cs16: "cutlass::complex", DataType.cs32: "cutlass::complex", DataType.cs64: "cutlass::complex", } DataTypeSize = { DataType.void: 0, DataType.b1: 1, DataType.u4: 4, DataType.u8: 8, DataType.u16: 16, DataType.u32: 32, DataType.u64: 64, DataType.s4: 4, DataType.s8: 8, DataType.s16: 16, DataType.s32: 32, DataType.s64: 64, DataType.e4m3: 8, DataType.e5m2: 8, DataType.f16: 16, DataType.bf16: 16, DataType.f32: 32, DataType.tf32: 32, DataType.f64: 64, DataType.cf16: 32, DataType.cbf16: 32, DataType.cf32: 64, DataType.ctf32: 32, DataType.cf64: 128, DataType.cu4: 8, DataType.cu8: 16, DataType.cu16: 32, DataType.cu32: 64, DataType.cu64: 128, DataType.cs4: 8, DataType.cs8: 16, DataType.cs16: 32, DataType.cs32: 64, DataType.cs64: 128, } ################################################################################################### # class BlasMode(enum.Enum): symmetric = enum_auto() hermitian = enum_auto() # BlasModeTag = { BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric', BlasMode.hermitian: 'cutlass::BlasMode::kHermitian', } # class ComplexTransform(enum.Enum): none = enum_auto() conj = enum_auto() # ComplexTransformTag = { ComplexTransform.none: 'cutlass::ComplexTransform::kNone', ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', } # RealComplexBijection = [ (DataType.f16, DataType.cf16), (DataType.f32, DataType.cf32), (DataType.f64, DataType.cf64), ] # def is_complex(data_type): for r, c in RealComplexBijection: if data_type == c: return True return False # def get_complex_from_real(real_type): for r, c in RealComplexBijection: if real_type == r: return c return DataType.invalid # def get_real_from_complex(complex_type): for r, c in RealComplexBijection: if complex_type == c: return r return DataType.invalid # class ComplexMultiplyOp(enum.Enum): multiply_add = enum_auto() gaussian = enum_auto() ################################################################################################### # class MathOperation(enum.Enum): multiply_add = enum_auto() multiply_add_saturate = enum_auto() multiply_add_mixed_input_upcast = enum_auto() xor_popc = enum_auto() and_popc = enum_auto() multiply_add_fast_bf16 = enum_auto() multiply_add_fast_f16 = enum_auto() multiply_add_fast_f32 = enum_auto() multiply_add_complex_fast_f32 = enum_auto() multiply_add_complex = enum_auto() multiply_add_complex_gaussian = enum_auto() # MathOperationTag = { MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast', MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', MathOperation.and_popc: 'cutlass::arch::OpAndPopc', MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32', MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', } ################################################################################################### # class LayoutType(enum.Enum): ColumnMajor = enum_auto() RowMajor = enum_auto() ColumnMajorInterleaved2 = enum_auto() RowMajorInterleaved2 = enum_auto() ColumnMajorInterleaved32 = enum_auto() RowMajorInterleaved32 = enum_auto() ColumnMajorInterleaved64 = enum_auto() RowMajorInterleaved64 = enum_auto() TensorNHWC = enum_auto() TensorNDHWC = enum_auto() TensorNCHW = enum_auto() TensorNGHWC = enum_auto() TensorNC32HW32 = enum_auto() TensorNC64HW64 = enum_auto() TensorC32RSK32 = enum_auto() TensorC64RSK64 = enum_auto() # LayoutTag = { LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor', LayoutType.RowMajor: 'cutlass::layout::RowMajor', LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC', LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC', LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW', LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC', LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', } # TransposedLayout = { LayoutType.ColumnMajor: LayoutType.RowMajor, LayoutType.RowMajor: LayoutType.ColumnMajor, LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, LayoutType.TensorNHWC: LayoutType.TensorNHWC } # ShortLayoutTypeNames = { LayoutType.ColumnMajor: 'n', LayoutType.ColumnMajorInterleaved2: 'n2', LayoutType.ColumnMajorInterleaved32: 'n32', LayoutType.ColumnMajorInterleaved64: 'n64', LayoutType.RowMajor: 't', LayoutType.RowMajorInterleaved2: 't2', LayoutType.RowMajorInterleaved32: 't32', LayoutType.RowMajorInterleaved64: 't64', LayoutType.TensorNHWC: 'nhwc', LayoutType.TensorNDHWC: 'ndhwc', LayoutType.TensorNCHW: 'nchw', LayoutType.TensorNGHWC: 'nghwc', LayoutType.TensorNC32HW32: 'nc32hw32', LayoutType.TensorNC64HW64: 'nc64hw64', LayoutType.TensorC32RSK32: 'c32rsk32', LayoutType.TensorC64RSK64: 'c64rsk64' } # ShortComplexLayoutNames = { (LayoutType.ColumnMajor, ComplexTransform.none): 'n', (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', (LayoutType.RowMajor, ComplexTransform.none): 't', (LayoutType.RowMajor, ComplexTransform.conj): 'h' } ################################################################################################### class KernelScheduleType(enum.Enum): ScheduleAuto = enum_auto() Multistage = enum_auto() CpAsyncWarpSpecialized = enum_auto() CpAsyncWarpSpecializedPingpong = enum_auto() CpAsyncWarpSpecializedCooperative = enum_auto() Tma = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedPingpong = enum_auto() TmaWarpSpecializedCooperative = enum_auto() TmaWarpSpecializedFP8FastAccum = enum_auto() TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() # KernelScheduleTag = { KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage', KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized', KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong', KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative', KernelScheduleType.Tma: 'cutlass::gemm::KernelTma', KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized', KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong', KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative', KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum', KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum', KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', } # KernelScheduleSuffixes = { KernelScheduleType.ScheduleAuto: '', KernelScheduleType.Multistage: '_cpasync', KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized', KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong', KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative', KernelScheduleType.Tma: '_unspecialized', KernelScheduleType.TmaWarpSpecialized: '_warpspecialized', KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong', KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative', KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum', KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', } class EpilogueScheduleType(enum.Enum): ScheduleAuto = enum_auto() EpilogueTransposed = enum_auto() NoSmemWarpSpecialized = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedCooperative = enum_auto() # EpilogueScheduleTag = { EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto', EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed', EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized', EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', } # EpilogueScheduleSuffixes = { EpilogueScheduleType.ScheduleAuto: '', EpilogueScheduleType.EpilogueTransposed: '', EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem', EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', } class EpilogueFunctor3x(enum.Enum): LinearCombination = enum_auto() # EpilogueFunctor3xTag = { EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination', } class TileSchedulerType(enum.Enum): Default = enum_auto() Persistent = enum_auto() StreamK = enum_auto() # TileSchedulerTag = { TileSchedulerType.Default: 'void', TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler', TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler', } # TileSchedulerSuffixes = { TileSchedulerType.Default: '', TileSchedulerType.Persistent: '', TileSchedulerType.StreamK: '_stream_k', } ################################################################################################### # class SideMode(enum.Enum): Left = enum_auto() Right = enum_auto() # SideModeTag = { SideMode.Left: 'cutlass::SideMode::kLeft', SideMode.Right: 'cutlass::SideMode::kRight' } # ShortSideModeNames = { SideMode.Left: 'ls', SideMode.Right: 'rs' } ################################################################################################### # class FillMode(enum.Enum): Lower = enum_auto() Upper = enum_auto() # FillModeTag = { FillMode.Lower: 'cutlass::FillMode::kLower', FillMode.Upper: 'cutlass::FillMode::kUpper' } # ShortFillModeNames = { FillMode.Lower: 'l', FillMode.Upper: 'u' } ################################################################################################### # class DiagType(enum.Enum): NonUnit = enum_auto() Unit = enum_auto() # DiagTypeTag = { DiagType.NonUnit: 'cutlass::DiagType::kNonUnit', DiagType.Unit: 'cutlass::DiagType::kUnit' } # ShortDiagTypeNames = { DiagType.NonUnit: 'nu', DiagType.Unit: 'un' } ################################################################################################### # class OpcodeClass(enum.Enum): Simt = enum_auto() TensorOp = enum_auto() WmmaTensorOp = enum_auto() SparseTensorOp = enum_auto() OpcodeClassNames = { OpcodeClass.Simt: 'simt', OpcodeClass.TensorOp: 'tensorop', OpcodeClass.WmmaTensorOp: 'wmma_tensorop', } OpcodeClassTag = { OpcodeClass.Simt: 'cutlass::arch::OpClassSimt', OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp', OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', } ################################################################################################### # class OperationKind(enum.Enum): Gemm = enum_auto() RankK = enum_auto() Rank2K = enum_auto() Trmm = enum_auto() Symm = enum_auto() Conv2d = enum_auto() Conv3d = enum_auto() # OperationKindNames = { OperationKind.Gemm: 'gemm' , OperationKind.RankK: 'rank_k' , OperationKind.Rank2K: 'rank_2k' , OperationKind.Trmm: 'trmm' , OperationKind.Symm: 'symm' , OperationKind.Conv2d: 'conv2d' , OperationKind.Conv3d: 'conv3d' } # class Target(enum.Enum): library = enum_auto() # ArchitectureNames = { 50: 'maxwell', 60: 'pascal', 61: 'pascal', 70: 'volta', 75: 'turing', 80: 'ampere', 89: 'ada', 90: 'hopper' } # SharedMemPerCC = { 70: 96, # 96KB of SMEM 72: 96, # 96KB of SMEM 75: 64, # 64KB of SMEM 80: 163, # 163KB of SMEM - 1KB reserved for the driver 86: 99, # 99KB of SMEM - 1KB reserved for the driver 87: 163, # 163KB of SMEM - 1KB reserved for the driver 89: 99, # 99KB of SMEM - 1KB reserved for the driver 90: 227, # 227KB of SMEM - 1KB reserved for the driver } ################################################################################################### # def SubstituteTemplate(template, values): text = template changed = True while changed: changed = False for key, value in values.items(): regex = "\\$\\{%s\\}" % key newtext = re.sub(regex, value, text) if newtext != text: changed = True text = newtext return text ################################################################################################### # class GemmKind(enum.Enum): Gemm = enum_auto() Sparse = enum_auto() Universal = enum_auto() Universal3x = enum_auto() SparseUniversal3x = enum_auto() PlanarComplex = enum_auto() PlanarComplexArray = enum_auto() Grouped = enum_auto() # GemmKindNames = { GemmKind.Gemm: "gemm", GemmKind.Sparse: "spgemm", GemmKind.Universal: "gemm", GemmKind.Universal3x: "gemm", GemmKind.SparseUniversal3x: "spgemm", GemmKind.PlanarComplex: "gemm_planar_complex", GemmKind.PlanarComplexArray: "gemm_planar_complex_array", GemmKind.Grouped: "gemm_grouped", } # class RankKKind(enum.Enum): Universal = enum_auto() # RankKKindNames = { RankKKind.Universal: "rank_k" } # class TrmmKind(enum.Enum): Universal = enum_auto() # TrmmKindNames = { TrmmKind.Universal: "trmm" } # class SymmKind(enum.Enum): Universal = enum_auto() # SymmKindNames = { SymmKind.Universal: "symm" } # class EpilogueFunctor(enum.Enum): LinearCombination = enum_auto() LinearCombinationClamp = enum_auto() # EpilogueFunctorTag = { EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', } # class SwizzlingFunctor(enum.Enum): Identity1 = enum_auto() Identity2 = enum_auto() Identity4 = enum_auto() Identity8 = enum_auto() Horizontal = enum_auto() StridedDgradIdentity1 = enum_auto() StridedDgradIdentity4 = enum_auto() StridedDgradHorizontal = enum_auto() StreamK = enum_auto() # SwizzlingFunctorTag = { SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle', SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>', SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>', SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK', } # class GroupScheduleMode(enum.Enum): Device = enum_auto(), Host = enum_auto() # GroupScheduleModeTag = { GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly', GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute' } # ShortGroupScheduleModeNames = { GroupScheduleMode.Device: 'Device', GroupScheduleMode.Host: 'Host' } ################################################################################################### # class ConvKind(enum.IntEnum): Fprop = 0 Dgrad = 1 Wgrad = 2 # ConvKindTag = { ConvKind.Fprop: 'cutlass::conv::Operator::kFprop', ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad', ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad' } ConvKindNames = { ConvKind.Fprop: 'fprop', ConvKind.Dgrad: 'dgrad', ConvKind.Wgrad: 'wgrad', } class ConvMode(enum.IntEnum): CrossCorrelation = 0 Convolution = 1 # class IteratorAlgorithm(enum.Enum): Analytic = 0 Optimized = 1 FixedChannels = 2 FewChannels = 3 FixedStrideDilation = 4 # IteratorAlgorithmTag = { IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels', IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels', IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation' } IteratorAlgorithmNames = { IteratorAlgorithm.Analytic: 'analytic', IteratorAlgorithm.Optimized: 'optimized', IteratorAlgorithm.FixedChannels: 'fixed_channels', IteratorAlgorithm.FewChannels: 'few_channels', IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation' } # class StrideSupport(enum.Enum): Strided = 0 Unity = 1 Fixed = 2 # StrideSupportTag = { StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed' } StrideSupportNames = { StrideSupport.Strided: '', StrideSupport.Unity: 'unity_stride', StrideSupport.Fixed: 'fixed_stride' } # class GroupMode(enum.Enum): NoneGroup = enum_auto() # dense conv (G=1) SingleGroup = enum_auto() # grouped convolution (single group per CTA) MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA) Depthwise = enum_auto() # Depthwise convolution ( C=K=G ) # GroupModeTag = { GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone', GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup', GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup', GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise', } GroupModeNames = { GroupMode.NoneGroup: '', GroupMode.SingleGroup: 'single_group', GroupMode.MultipleGroup: 'multiple_group', GroupMode.Depthwise: 'depthwise', } ################################################################################################### # class MathInstruction: def __init__(self, instruction_shape, \ element_a, element_b, element_accumulator, \ opcode_class, math_operation = MathOperation.multiply_add \ ): self.instruction_shape = instruction_shape self.element_a = element_a self.element_b = element_b self.element_accumulator = element_accumulator self.opcode_class = opcode_class self.math_operation = math_operation # class TileDescription: def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1]): self.threadblock_shape = threadblock_shape self.tile_shape = threadblock_shape self.stages = stages self.warp_count = warp_count self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute self.cluster_shape = cluster_shape def procedural_name(self): if self.minimum_compute_capability >= 90: return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format( tbm = self.threadblock_shape[0], tbn = self.threadblock_shape[1], tbk = self.threadblock_shape[2], cm = self.cluster_shape[0], cn = self.cluster_shape[1], ck = self.cluster_shape[2], s = self.stages) else: return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) # class Direct2dConvFixedStrideDilationTileDescription: def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] self.threadblock_output_shape = threadblock_output_shape self.filter_shape = filter_shape self.stages = stages self.warp_count = warp_count self.stride = stride self.dilation = dilation self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute def procedural_name(self): str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.threadblock_output_shape[0], self.threadblock_output_shape[1], self.threadblock_output_shape[2], self.threadblock_output_shape[3], self.stages, self.filter_shape[0], self.filter_shape[1]) # Fixed Strided and dilation if self.stride != [-1, -1] and self.dilation != [-1, -1]: str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], self.stride[1], self.dilation[0], self.dilation[1]) return str_name # class Direct2dConvFixedStrideDilationTileDescription: def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] self.threadblock_output_shape = threadblock_output_shape self.filter_shape = filter_shape self.stages = stages self.warp_count = warp_count self.stride = stride self.dilation = dilation self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute def procedural_name(self): str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.threadblock_output_shape[0], self.threadblock_output_shape[1], self.threadblock_output_shape[2], self.threadblock_output_shape[3], self.stages, self.filter_shape[0], self.filter_shape[1]) # Fixed Strided and dilation if self.stride != [-1, -1] and self.dilation != [-1, -1]: str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], self.stride[1], self.dilation[0], self.dilation[1]) return str_name # class TensorDescription: def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): self.element = element self.layout = layout self.alignment = alignment self.complex_transform = complex_transform # class SymmetricTensorDescription: def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left): self.element = element self.layout = layout self.fill_mode = fill_mode self.alignment = alignment self.complex_transform = complex_transform self.side_mode = side_mode # class TriangularTensorDescription: def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none): self.element = element self.layout = layout self.side_mode = side_mode self.fill_mode = fill_mode self.diag_type = diag_type self.alignment = alignment self.complex_transform = complex_transform # def CalculateSmemUsage(operation): cta_shape = operation.tile_description.threadblock_shape stages = operation.tile_description.stages if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse: # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) if DataTypeSize[operation.A.element] == 32: elements_per_8b_md = 2 elif DataTypeSize[operation.A.element] == 4: elements_per_8b_md = 8 else: elements_per_8b_md = 4 smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \ DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \ cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md else: # Few BLAS3 operations only have A tensor data_type_size_a = DataTypeSize[operation.A.element] data_type_size_b = DataTypeSize[operation.A.element] if operation.is_mixed_input(): data_type_size_b = DataTypeSize[operation.B.element] smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \ data_type_size_b * cta_shape[1] * cta_shape[2] // 8 smem_usage = smem_per_stage * stages return (smem_usage >> 10) class GemmUniversalMode(enum.IntEnum): """ Types corresponding to GemmUniversalMode """ Gemm = 0 GemmSplitKParallel = 1 Batched = 2 Array = 3 class SplitKMode(enum.IntEnum): """ Types corresponding to SplitKMode """ NoneSplitK = 0 Serial = 1 Parallel = 2