# # \file generator.py # # \brief Generates the CUTLASS Library's instances # import re ################################################################################################### import enum # The following block implements enum.auto() for Python 3.5 variants that don't include it such # as the default 3.5.2 on Ubuntu 16.04. # # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility try: from enum import auto as enum_auto except ImportError: __cutlass_library_auto_enum = 0 def enum_auto() -> int: global __cutlass_library_auto_enum i = __cutlass_library_auto_enum __cutlass_library_auto_enum += 1 return i ################################################################################################### # class GeneratorTarget(enum.Enum): Library = enum_auto() # GeneratorTargetNames = {GeneratorTarget.Library: "library"} # ################################################################################################### # class DataType(enum.Enum): b1 = enum_auto() u4 = enum_auto() u8 = enum_auto() u16 = enum_auto() u32 = enum_auto() u64 = enum_auto() s4 = enum_auto() s8 = enum_auto() s16 = enum_auto() s32 = enum_auto() s64 = enum_auto() f16 = enum_auto() bf16 = enum_auto() f32 = enum_auto() tf32 = enum_auto() f64 = enum_auto() cf16 = enum_auto() cbf16 = enum_auto() cf32 = enum_auto() ctf32 = enum_auto() cf64 = enum_auto() cs4 = enum_auto() cs8 = enum_auto() cs16 = enum_auto() cs32 = enum_auto() cs64 = enum_auto() cu4 = enum_auto() cu8 = enum_auto() cu16 = enum_auto() cu32 = enum_auto() cu64 = enum_auto() invalid = enum_auto() # ShortDataTypeNames = { DataType.s32: "i", DataType.f16: "h", DataType.f32: "s", DataType.f64: "d", DataType.cf32: "c", DataType.cf64: "z", } # DataTypeNames = { DataType.b1: "b1", DataType.u4: "u4", DataType.u8: "u8", DataType.u16: "u16", DataType.u32: "u32", DataType.u64: "u64", DataType.s4: "s4", DataType.s8: "s8", DataType.s16: "s16", DataType.s32: "s32", DataType.s64: "s64", DataType.f16: "f16", DataType.bf16: "bf16", DataType.f32: "f32", DataType.tf32: "tf32", DataType.f64: "f64", DataType.cf16: "cf16", DataType.cbf16: "cbf16", DataType.cf32: "cf32", DataType.ctf32: "ctf32", DataType.cf64: "cf64", DataType.cu4: "cu4", DataType.cu8: "cu8", DataType.cu16: "cu16", DataType.cu32: "cu32", DataType.cu64: "cu64", DataType.cs4: "cs4", DataType.cs8: "cs8", DataType.cs16: "cs16", DataType.cs32: "cs32", DataType.cs64: "cs64", } DataTypeTag = { DataType.b1: "cutlass::uint1b_t", DataType.u4: "cutlass::uint4b_t", DataType.u8: "uint8_t", DataType.u16: "uint16_t", DataType.u32: "uint32_t", DataType.u64: "uint64_t", DataType.s4: "cutlass::int4b_t", DataType.s8: "int8_t", DataType.s16: "int16_t", DataType.s32: "int32_t", DataType.s64: "int64_t", DataType.f16: "cutlass::half_t", DataType.bf16: "cutlass::bfloat16_t", DataType.f32: "float", DataType.tf32: "cutlass::tfloat32_t", DataType.f64: "double", DataType.cf16: "cutlass::complex", DataType.cbf16: "cutlass::complex", DataType.cf32: "cutlass::complex", DataType.ctf32: "cutlass::complex", DataType.cf64: "cutlass::complex", DataType.cu4: "cutlass::complex", DataType.cu8: "cutlass::complex", DataType.cu16: "cutlass::complex", DataType.cu32: "cutlass::complex", DataType.cu64: "cutlass::complex", DataType.cs4: "cutlass::complex", DataType.cs8: "cutlass::complex", DataType.cs16: "cutlass::complex", DataType.cs32: "cutlass::complex", DataType.cs64: "cutlass::complex", } DataTypeSize = { DataType.b1: 1, DataType.u4: 4, DataType.u8: 4, DataType.u16: 16, DataType.u32: 32, DataType.u64: 64, DataType.s4: 4, DataType.s8: 8, DataType.s16: 16, DataType.s32: 32, DataType.s64: 64, DataType.f16: 16, DataType.bf16: 16, DataType.f32: 32, DataType.tf32: 32, DataType.f64: 64, DataType.cf16: 32, DataType.cbf16: 32, DataType.cf32: 64, DataType.ctf32: 32, DataType.cf64: 128, DataType.cu4: 8, DataType.cu8: 16, DataType.cu16: 32, DataType.cu32: 64, DataType.cu64: 128, DataType.cs4: 8, DataType.cs8: 16, DataType.cs16: 32, DataType.cs32: 64, DataType.cs64: 128, } ################################################################################################### # class ComplexTransform(enum.Enum): none = enum_auto() conj = enum_auto() # ComplexTransformTag = { ComplexTransform.none: "cutlass::ComplexTransform::kNone", ComplexTransform.conj: "cutlass::ComplexTransform::kConjugate", } # RealComplexBijection = [ (DataType.f16, DataType.cf16), (DataType.f32, DataType.cf32), (DataType.f64, DataType.cf64), ] # def is_complex(data_type): for r, c in RealComplexBijection: if data_type == c: return True return False # def get_complex_from_real(real_type): for r, c in RealComplexBijection: if real_type == r: return c return DataType.invalid # def get_real_from_complex(complex_type): for r, c in RealComplexBijection: if complex_type == c: return r return DataType.invalid # class ComplexMultiplyOp(enum.Enum): multiply_add = enum_auto() gaussian = enum_auto() ################################################################################################### # class MathOperation(enum.Enum): multiply_add = enum_auto() multiply_add_saturate = enum_auto() xor_popc = enum_auto() multiply_add_fast_bf16 = enum_auto() multiply_add_fast_f16 = enum_auto() multiply_add_complex = enum_auto() multiply_add_complex_gaussian = enum_auto() # MathOperationTag = { MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd", MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate", MathOperation.xor_popc: "cutlass::arch::OpXorPopc", MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16", MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16", MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex", MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex", } ################################################################################################### # class LayoutType(enum.Enum): ColumnMajor = enum_auto() RowMajor = enum_auto() ColumnMajorInterleaved2 = enum_auto() RowMajorInterleaved2 = enum_auto() ColumnMajorInterleaved32 = enum_auto() RowMajorInterleaved32 = enum_auto() ColumnMajorInterleaved64 = enum_auto() RowMajorInterleaved64 = enum_auto() TensorNHWC = enum_auto() TensorNDHWC = enum_auto() TensorNCHW = enum_auto() TensorNGHWC = enum_auto() TensorNC4HW4 = enum_auto() TensorC4RSK4 = enum_auto() TensorNC8HW8 = enum_auto() TensorNC16HW16 = enum_auto() TensorNC32HW32 = enum_auto() TensorNC64HW64 = enum_auto() TensorC32RSK32 = enum_auto() TensorC64RSK64 = enum_auto() TensorK4RSC4 = enum_auto() TensorCK4RS4 = enum_auto() TensorCK8RS8 = enum_auto() TensorCK16RS16 = enum_auto() # LayoutTag = { LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor", LayoutType.RowMajor: "cutlass::layout::RowMajor", LayoutType.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>", LayoutType.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>", LayoutType.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>", LayoutType.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>", LayoutType.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>", LayoutType.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>", LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC", LayoutType.TensorNDHWC: "cutlass::layout::TensorNDHWC", LayoutType.TensorNCHW: "cutlass::layout::TensorNCHW", LayoutType.TensorNGHWC: "cutlass::layout::TensorNGHWC", LayoutType.TensorNC4HW4: "cutlass::layout::TensorNCxHWx<4>", LayoutType.TensorC4RSK4: "cutlass::layout::TensorCxRSKx<4>", LayoutType.TensorNC8HW8: "cutlass::layout::TensorNCxHWx<8>", LayoutType.TensorNC16HW16: "cutlass::layout::TensorNCxHWx<16>", LayoutType.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>", LayoutType.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>", LayoutType.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>", LayoutType.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>", LayoutType.TensorK4RSC4: "cutlass::layout::TensorKxRSCx<4>", LayoutType.TensorCK4RS4: "cutlass::layout::TensorCKxRSx<4>", LayoutType.TensorCK8RS8: "cutlass::layout::TensorCKxRSx<8>", LayoutType.TensorCK16RS16: "cutlass::layout::TensorCKxRSx<16>", } # TransposedLayout = { LayoutType.ColumnMajor: LayoutType.RowMajor, LayoutType.RowMajor: LayoutType.ColumnMajor, LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, LayoutType.TensorNHWC: LayoutType.TensorNHWC, } # ShortLayoutTypeNames = { LayoutType.ColumnMajor: "n", LayoutType.ColumnMajorInterleaved32: "n2", LayoutType.ColumnMajorInterleaved32: "n32", LayoutType.ColumnMajorInterleaved64: "n64", LayoutType.RowMajor: "t", LayoutType.RowMajorInterleaved2: "t2", LayoutType.RowMajorInterleaved32: "t32", LayoutType.RowMajorInterleaved64: "t64", LayoutType.TensorNHWC: "nhwc", LayoutType.TensorNDHWC: "ndhwc", LayoutType.TensorNCHW: "nchw", LayoutType.TensorNGHWC: "nghwc", LayoutType.TensorNC4HW4: "nc4hw4", LayoutType.TensorC4RSK4: "c4rsk4", LayoutType.TensorNC8HW8: "nc8hw8", LayoutType.TensorNC16HW16: "nc16hw16", LayoutType.TensorNC32HW32: "nc32hw32", LayoutType.TensorNC64HW64: "nc64hw64", LayoutType.TensorC32RSK32: "c32rsk32", LayoutType.TensorC64RSK64: "c64rsk64", LayoutType.TensorK4RSC4: "k4rsc4", LayoutType.TensorCK4RS4: "ck4rs4", LayoutType.TensorCK8RS8: "ck8rs8", LayoutType.TensorCK16RS16: "ck16rs16", } # ShortComplexLayoutNames = { (LayoutType.ColumnMajor, ComplexTransform.none): "n", (LayoutType.ColumnMajor, ComplexTransform.conj): "c", (LayoutType.RowMajor, ComplexTransform.none): "t", (LayoutType.RowMajor, ComplexTransform.conj): "h", } ################################################################################################### # class OpcodeClass(enum.Enum): Simt = enum_auto() TensorOp = enum_auto() WmmaTensorOp = enum_auto() OpcodeClassNames = { OpcodeClass.Simt: "simt", OpcodeClass.TensorOp: "tensorop", OpcodeClass.WmmaTensorOp: "wmma_tensorop", } OpcodeClassTag = { OpcodeClass.Simt: "cutlass::arch::OpClassSimt", OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp", OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp", } ################################################################################################### # class OperationKind(enum.Enum): Gemm = enum_auto() Conv2d = enum_auto() # OperationKindNames = {OperationKind.Gemm: "gemm", OperationKind.Conv2d: "conv2d"} # class Target(enum.Enum): library = enum_auto() ArchitectureNames = { 50: "maxwell", 60: "pascal", 61: "pascal", 70: "volta", 75: "turing", 80: "ampere", } ################################################################################################### # def SubstituteTemplate(template, values): text = template changed = True while changed: changed = False for key, value in values.items(): regex = "\\$\\{%s\\}" % key newtext = re.sub(regex, value, text) if newtext != text: changed = True text = newtext return text ################################################################################################### # class GemmKind(enum.Enum): Gemm = enum_auto() Sparse = enum_auto() Universal = enum_auto() PlanarComplex = enum_auto() PlanarComplexArray = enum_auto() SplitKParallel = enum_auto() GemvBatchedStrided = enum_auto() # GemmKindNames = { GemmKind.Gemm: "gemm", GemmKind.Sparse: "spgemm", GemmKind.Universal: "gemm", GemmKind.PlanarComplex: "gemm_planar_complex", GemmKind.PlanarComplexArray: "gemm_planar_complex_array", GemmKind.SplitKParallel: "gemm_split_k_parallel", GemmKind.GemvBatchedStrided: "gemv_batched_strided", } # class EpilogueFunctor(enum.Enum): LinearCombination = enum_auto() LinearCombinationClamp = enum_auto() BiasAddLinearCombination = enum_auto() BiasAddLinearCombinationRelu = enum_auto() BiasAddLinearCombinationHSwish = enum_auto() BiasAddLinearCombinationClamp = enum_auto() BiasAddLinearCombinationReluClamp = enum_auto() BiasAddLinearCombinationHSwishClamp = enum_auto() # EpilogueFunctorTag = { EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination", EpilogueFunctor.LinearCombinationClamp: "cutlass::epilogue::thread::LinearCombinationClamp", EpilogueFunctor.BiasAddLinearCombination: "cutlass::epilogue::thread::BiasAddLinearCombination", EpilogueFunctor.BiasAddLinearCombinationRelu: "cutlass::epilogue::thread::BiasAddLinearCombinationRelu", EpilogueFunctor.BiasAddLinearCombinationHSwish: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwish", EpilogueFunctor.BiasAddLinearCombinationClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationClamp", EpilogueFunctor.BiasAddLinearCombinationReluClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp", EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp", } # ShortEpilogueNames = { EpilogueFunctor.LinearCombination: "id", EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "hswish", EpilogueFunctor.BiasAddLinearCombinationReluClamp: "relu", EpilogueFunctor.BiasAddLinearCombinationClamp: "id", EpilogueFunctor.BiasAddLinearCombinationHSwish: "hswish", EpilogueFunctor.BiasAddLinearCombinationRelu: "relu", EpilogueFunctor.BiasAddLinearCombination: "id", } # class SwizzlingFunctor(enum.Enum): Identity1 = enum_auto() Identity2 = enum_auto() Identity4 = enum_auto() Identity8 = enum_auto() ConvFpropNCxHWx = enum_auto() ConvFpropTrans = enum_auto() ConvDgradNCxHWx = enum_auto() ConvDgradTrans = enum_auto() DepthwiseConvolutionFprop = enum_auto() DepthwiseConvolutionDgrad = enum_auto() DepthwiseConvolutionWgrad = enum_auto() # SwizzlingFunctorTag = { SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>", SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>", SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>", SwizzlingFunctor.ConvFpropNCxHWx: "cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle", SwizzlingFunctor.ConvFpropTrans: "cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle", SwizzlingFunctor.ConvDgradNCxHWx: "cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle", SwizzlingFunctor.ConvDgradTrans: "cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle", SwizzlingFunctor.DepthwiseConvolutionFprop: "cutlass::conv::threadblock::DepthwiseConvolutionFpropThreadblockSwizzle", SwizzlingFunctor.DepthwiseConvolutionDgrad: "cutlass::conv::threadblock::DepthwiseConvolutionDgradThreadblockSwizzle", SwizzlingFunctor.DepthwiseConvolutionWgrad: "cutlass::conv::threadblock::DepthwiseConvolutionWgradThreadblockSwizzle", } ################################################################################################### class ConvType(enum.Enum): Convolution = enum_auto() BatchConvolution = enum_auto() Local = enum_auto() LocalShare = enum_auto() DepthwiseConvolution = enum_auto() ConvTypeTag = { ConvType.Convolution: "cutlass::conv::ConvType::kConvolution", ConvType.BatchConvolution: "cutlass::conv::ConvType::kBatchConvolution", ConvType.Local: "cutlass::conv::ConvType::kLocal", ConvType.LocalShare: "cutlass::conv::ConvType::kLocalShare", ConvType.DepthwiseConvolution: "cutlass::conv::ConvType::kDepthwiseConvolution", } # class ConvKind(enum.Enum): Fprop = enum_auto() Dgrad = enum_auto() Wgrad = enum_auto() # ConvKindTag = { ConvKind.Fprop: "cutlass::conv::Operator::kFprop", ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad", ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad", } ConvKindNames = { ConvKind.Fprop: "fprop", ConvKind.Dgrad: "dgrad", ConvKind.Wgrad: "wgrad", } # class IteratorAlgorithm(enum.Enum): Analytic = enum_auto() Optimized = enum_auto() # IteratorAlgorithmTag = { IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic", IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized", } IteratorAlgorithmNames = { IteratorAlgorithm.Analytic: "analytic", IteratorAlgorithm.Optimized: "optimized", } # class StrideSupport(enum.Enum): Strided = enum_auto() Unity = enum_auto() # StrideSupportTag = { StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided", StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity", } StrideSupportNames = {StrideSupport.Strided: "", StrideSupport.Unity: "unity_stride"} class SpecialOptimizeDesc(enum.Enum): NoneSpecialOpt = enum_auto() ConvFilterUnity = enum_auto() DeconvDoubleUpsampling = enum_auto() SpecialOptimizeDescNames = { SpecialOptimizeDesc.NoneSpecialOpt: "none", SpecialOptimizeDesc.ConvFilterUnity: "conv_filter_unity", SpecialOptimizeDesc.DeconvDoubleUpsampling: "deconv_double_upsampling", } SpecialOptimizeDescTag = { SpecialOptimizeDesc.NoneSpecialOpt: "cutlass::conv::SpecialOptimizeDesc::NONE", SpecialOptimizeDesc.ConvFilterUnity: "cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY", SpecialOptimizeDesc.DeconvDoubleUpsampling: "cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING", } class ImplicitGemmMode(enum.Enum): GemmNT = enum_auto() GemmTN = enum_auto() ImplicitGemmModeNames = { ImplicitGemmMode.GemmNT: "gemm_nt", ImplicitGemmMode.GemmTN: "gemm_tn", } ImplicitGemmModeTag = { ImplicitGemmMode.GemmNT: "cutlass::conv::ImplicitGemmMode::GEMM_NT", ImplicitGemmMode.GemmTN: "cutlass::conv::ImplicitGemmMode::GEMM_TN", } ################################################################################################### # class MathInstruction: def __init__( self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation=MathOperation.multiply_add, ): self.instruction_shape = instruction_shape self.element_a = element_a self.element_b = element_b self.element_accumulator = element_accumulator self.opcode_class = opcode_class self.math_operation = math_operation # class TileDescription: def __init__( self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, ): self.threadblock_shape = threadblock_shape self.stages = stages self.warp_count = warp_count self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute def procedural_name(self): return "%dx%d_%dx%d" % ( self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages, ) # class TensorDescription: def __init__( self, element, layout, alignment=1, complex_transform=ComplexTransform.none ): self.element = element self.layout = layout self.alignment = alignment self.complex_transform = complex_transform ################################################################################################### class GlobalCnt: cnt = 0