{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "1ef96b3f", "metadata": {}, "source": [ "# Basic example of using the CUTLASS Python interface\n", "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.\n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/00_basic_gemm.ipynb)\n" ] }, { "cell_type": "markdown", "id": "df94d7e6", "metadata": {}, "source": [ "## Prerequisites for running on Colab\n", "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected." ] }, { "cell_type": "code", "execution_count": null, "id": "71c7a069", "metadata": {}, "outputs": [], "source": [ "!#nvidia-smi" ] }, { "cell_type": "markdown", "id": "cf16785d", "metadata": {}, "source": [ "If running on Colab, you will need to install the CUTLASS Python interface. To do so, uncomment the following line and run the cell:" ] }, { "cell_type": "code", "execution_count": null, "id": "c819bb68", "metadata": {}, "outputs": [], "source": [ "!#pip install nvidia-cutlass" ] }, { "attachments": {}, "cell_type": "markdown", "id": "962324fd", "metadata": {}, "source": [ "## General setup\n", "We first import various packages needed for the example and construct the input and output tensors that will be used in our example." ] }, { "cell_type": "code", "execution_count": null, "id": "0e324219", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import random\n", "\n", "import cutlass\n", "\n", "# This controls whether the C++ GEMM declaration will be printed at each step. \n", "# Set to `False` to omit this information.\n", "print_module = True\n", "\n", "m = 128\n", "n = m\n", "k = m\n", "\n", "dtype = np.float16\n", "type_A = np.float16\n", "type_B = np.float16\n", "type_C = np.float16\n", "type_D = np.float16\n", "\n", "np.random.seed(1234)\n", "random.seed(1234)\n", "scope_min = -4\n", "scope_max = 4\n", "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n", "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n", "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n", "\n", "alpha = np.float16(1.)\n", "beta = np.float16(0.)\n", "\n", "tensor_D = np.zeros(tensor_C.shape).astype(type_D)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f2c7bf48", "metadata": {}, "source": [ "## Declaring and running a GEMM\n", "To get started, one only needs to provide the tensors declared above to the `cutlass.op.Gemm` call.\n", "This sets up a default GEMM operation for the given device on which you are running.\n", "\n", "Assuming that we are running on SM80, this default to using a GEMM that leverages FP16 Tensor Core operations.\n", "\n", "Calling `plan.run()` will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting `print_module` to `true`, the C++ code that is emitted is printed." ] }, { "cell_type": "code", "execution_count": null, "id": "0dfd8975", "metadata": {}, "outputs": [], "source": [ "# We specify `element_accumulator` here so as to match the kernel run by NumPy below. However,\n", "# specifying `element_accumulator` is not required if it is the same as `element`\n", "plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor, element_accumulator=np.float32)\n", "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4a5856de", "metadata": {}, "source": [ "There are many other ways to construct a plan from `cutlass.op.Gemm` (e.g., by specifiying they types and layouts of each operand, by providing representative tensors as inputs). For more details on these, see the documentation in the `cutlass.op.Gemm` constructor." ] }, { "attachments": {}, "cell_type": "markdown", "id": "945478ef", "metadata": {}, "source": [ "We then compare the output to running the GEMM using NumPy." ] }, { "cell_type": "code", "execution_count": null, "id": "6b669de6", "metadata": {}, "outputs": [], "source": [ "tensor_D_numpy = (alpha * (tensor_A @ tensor_B)) + (beta * tensor_C)\n", "np.testing.assert_array_equal(tensor_D, tensor_D_numpy)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ee5cbbbe", "metadata": {}, "source": [ "Note that one could use the same kernel just declared for tensors provided by other frameworks beyond NumPy, such as PyTorch or CuPy." ] }, { "attachments": {}, "cell_type": "markdown", "id": "b6c86493", "metadata": {}, "source": [ "## Changing operation modes\n", "By default, the CUTLASS Python interface will try to use Tensor Core operations whenever possible. If the configuration provided to `cutlass.op.Gemm` is not supported on Tensor Cores, the interface will fall back to using a SIMT kernel.\n", "\n", "The operation mode currently in use can be returned via the `plan.opclass` property. In this case Tensor Core operations." ] }, { "cell_type": "code", "execution_count": null, "id": "529fda93", "metadata": {}, "outputs": [], "source": [ "print(plan.opclass)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6d27c575", "metadata": {}, "source": [ "Suppose that we don't want to use Tensor Cores for this GEMM. One can change to using CUTLASS's SIMT GEMMs by setting the plan's `opclass` field.\n", "\n", "As is shown in the printed output, the emitted kernel uses template parameters that fit CUTLASS's SIMT GEMMs.\n", "\n", "Also notice that, this time around, we provided tensor parameters to `plan.run()`. One is free to provide different parameters to `plan.run()` than were passed in at the initial call to `cutlass.op.Gemm`, provided that the passed-in tensors have the same data type and layout as those passed in on intialization." ] }, { "cell_type": "code", "execution_count": null, "id": "6a44d35b", "metadata": {}, "outputs": [], "source": [ "tensor_D_simt = np.zeros(tensor_C.shape).astype(type_D)\n", "plan.opclass = cutlass.OpcodeClass.Simt\n", "plan.run(tensor_A, tensor_B, tensor_C, tensor_D_simt, alpha, beta, print_module=print_module)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "639dcb59", "metadata": {}, "source": [ "If we compare the output of the Tensor Core and SIMT GEMMs we just ran we see that they are equal." ] }, { "cell_type": "code", "execution_count": null, "id": "9b480853", "metadata": {}, "outputs": [], "source": [ "np.testing.assert_array_equal(tensor_D, tensor_D_simt)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "0cce1eae", "metadata": {}, "source": [ "## Running cached kernels\n", "You may have noticed that the `plan.run()` calls for the previous two kernels took some time to execute. This is because the kernel being emitted had not yet been compiled.\n", "\n", "CUTLASS caches compiled binaries so that recompilation isn't necessary every time a kernel is run. For example, if we change modes back to using Tensor Cores and call `plan.run()` again (with a different set of tensor parameters), you'll find the call to return much faster." ] }, { "cell_type": "code", "execution_count": null, "id": "f8051e5e", "metadata": {}, "outputs": [], "source": [ "m = 2400\n", "n = 3232\n", "k = 4096\n", "\n", "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n", "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n", "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n", "tensor_D = np.zeros(tensor_C.shape).astype(type_D)\n", "\n", "alpha = np.float16(1.)\n", "beta = np.float16(2.)\n", "\n", "plan.opclass = cutlass.OpcodeClass.TensorOp\n", "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "52a4e318", "metadata": {}, "source": [ "## Running non-default GEMMs\n", "The previous examples showed how it is simple to get started running a default GEMM kernel in CUTLASS. But, what do you do if you want a bit more control over the parameters to the GEMM?\n", "\n", "Under the hood, CUTLASS enumerates the different GEMM configuration parameters possible for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernels (e.g., cluster, threadblock, and warp shape)." ] }, { "cell_type": "code", "execution_count": null, "id": "1c593be1", "metadata": {}, "outputs": [], "source": [ "tiles = plan.tile_descriptions()\n", "print('{} tile descriptions returned'.format(len(tiles)))\n", "num_print = 10\n", "print('First {} tile descriptions are:'.format(num_print))\n", "for td in tiles[:num_print]:\n", " print(td)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "dc3ad875", "metadata": {}, "source": [ "Next, we'll pick one of these configurations at random and compile and run it." ] }, { "cell_type": "code", "execution_count": null, "id": "a8dc5287", "metadata": {}, "outputs": [], "source": [ "tiles = [td for td in tiles if td.threadblock_shape[0] >= 128]\n", "idx = random.randint(0, len(tiles)-1)\n", "td = tiles[idx]\n", "print('Tile description {} is: {}'.format(idx, td))\n", "plan.compile(td)\n", "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c5a8b534", "metadata": {}, "source": [ "One can also change the swizzling function used by the kernel. For example, one can modify the kernel to use the stream K feature of CUTLASS via:" ] }, { "cell_type": "code", "execution_count": null, "id": "e5e88d17", "metadata": {}, "outputs": [], "source": [ "# Stream K is exposed through the threadblock swizzle method for pre-SM90 kernels,\n", "# and via the tile_scheduler attribute of the TileDescription for post-SM90 kernels\n", "if plan.cc < 90:\n", " plan.swizzling_functor = cutlass.swizzle.ThreadblockSwizzleStreamK\n", " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)\n", "else:\n", " # Stream-K is currently only supported for warp-specialized cooperative kernels\n", " td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedCooperative\n", " td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative\n", " td.tile_scheduler = cutlass.TileSchedulerType.StreamK\n", "\n", " plan.compile(td)\n", " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "5a8ba2ba", "metadata": {}, "source": [ "## Handling errors\n", "The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n", "\n", "Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user. Uncomment and run the code below to see this error." ] }, { "cell_type": "code", "execution_count": null, "id": "fe7d0e42", "metadata": {}, "outputs": [], "source": [ "# td = tiles[0]\n", "# td.stages = 8\n", "# plan.compile(td)" ] }, { "cell_type": "markdown", "id": "0fff34a4", "metadata": {}, "source": [ "## Specializations for other data types\n", "\n", "Various CUTLASS kernels specialized for specific data types can also be run via the Python interface.\n", "\n", "For example, the code below shows how to declare and run a GEMM using the 3xTF32 feature (see corresponding C++ example [here](https://github.com/NVIDIA/cutlass/blob/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu))." ] }, { "cell_type": "code", "execution_count": null, "id": "338ad890", "metadata": {}, "outputs": [], "source": [ "from cutlass.backend.utils.device import device_cc\n", "\n", "# 3xTF32 requires SM80 or higher\n", "if device_cc() >= 80:\n", " plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)\n", " plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32\n", "\n", " # Create input/output tensors in FP32\n", " A, B = [np.ones((128, 128)).astype(np.float32) for _ in range(2)]\n", " C, D = [np.zeros((128, 128)).astype(np.float32) for _ in range(2)]\n", "\n", " # Run the GEMM\n", " plan.run(A, B, C, D, print_module=print_module)" ] }, { "cell_type": "markdown", "id": "65531df1", "metadata": {}, "source": [ "Additionally, one can run CUTLASS's FP8 GEMMs if using a frontend library capable of allocating and initializing FP8 tensors (e.g., PyTorch)" ] }, { "cell_type": "code", "execution_count": null, "id": "776f1d8d", "metadata": {}, "outputs": [], "source": [ "try:\n", " import torch\n", "except ImportError:\n", " print(\"PyTorch is not available. Skipping FP8 example\")\n", " import sys; sys.exit(0)\n", "\n", "if not hasattr(torch, \"float8_e4m3fn\"):\n", " print(\"Version of PyTorch does not have the float8_e4m3fn data type. Skipping FP8 example\")\n", " import sys; sys.exit(0)\n", "\n", "# FP8 is supported through the CUTLASS Python interface on SM90 and higher\n", "if device_cc() >= 90:\n", " plan = cutlass.op.Gemm(element=torch.float8_e4m3fn, element_C=torch.float32, element_accumulator=torch.float32,\n", " layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,\n", " layout_C=cutlass.LayoutType.ColumnMajor)\n", "\n", " # Create input/output tensors in FP8\n", " A, B = [torch.ones((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n", " C, D = [torch.zeros((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n", "\n", " # Run the GEMM\n", " plan.run(A, B, C, D, print_module=print_module)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "vscode": { "interpreter": { "hash": "0466d96796c9cd8f7a1cad264ff326ececc950ba2420e0256d5105fc1a3c6e70" } } }, "nbformat": 4, "nbformat_minor": 5 }