# Basic example of using the CUTLASS Python interface for Conv2d

This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run Conv2d. 

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/03_basic_conv2d.ipynb)


## Prerequisites for running on Colab
This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected.

In [None]:
!#nvidia-smi

If running on Colab, you will need to install the CUTLASS Python interface. To do so, uncomment the following line and run the cell:

In [None]:
!#pip install nvidia-cutlass

## General setup
We first import various packages needed for the example and construct the input and output tensors that will be used in our example.

In [None]:
import torch
import random

import cutlass

# This controls whether the emitted C++ kernel declaration will be printed at each step.
# Set to `False` to omit this information.
print_module = True

# Input tensor: [N, H, W, C] under the channel-last layout
N, H, W, C = [32, 28, 28, 64]

# Weight tensor: [K, R, S, C] under the channel-last layout
K, R, S = [128, 3, 3]

# Stride, padding, and dilation
stride = (2, 2)
padding = (1, 1)
dilation = (1, 1)

# Compute the output size [N, P, Q, K]
N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)

dtype = torch.float16
type_A = torch.float16
type_B = torch.float16
type_C = torch.float16
type_D = torch.float16

torch.manual_seed(1234)

input = torch.ceil(
    torch.empty(size=(N, C, H, W), dtype=type_A, device="cuda").uniform_(-4.5, 3.5)
).to(memory_format=torch.channels_last)
weight = torch.ceil(
    torch.empty(size=(K, C, R, S), dtype=type_B, device="cuda").uniform_(-4.5, 3.5)
).to(memory_format=torch.channels_last)
tensor_C = torch.ceil(
    torch.empty(size=(N, K, P, Q), dtype=type_C, device="cuda").uniform_(-4.5, 3.5)
).to(memory_format=torch.channels_last)
output = torch.zeros_like(tensor_C)

alpha = 1.0
beta = 0.0
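
The output extent returned by `cutlass.Conv2d.output_size` above can be cross-checked by hand. The following is a minimal sketch, assuming the standard convolution output-size formula (the one PyTorch documents for `torch.nn.Conv2d`). `conv2d_output_size` is a hypothetical helper written for illustration and is not part of the CUTLASS API.

```python
def conv2d_output_size(input_size, weight_size, padding, stride, dilation):
    """Return (N, P, Q, K) for an NHWC input and a KRSC weight.

    Hypothetical helper: assumes the standard formula
    out = (in + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1.
    """
    n, h, w, c = input_size
    k, r, s, c_weight = weight_size
    if c != c_weight:
        raise ValueError("input and weight channel counts differ")
    p = (h + 2 * padding[0] - dilation[0] * (r - 1) - 1) // stride[0] + 1
    q = (w + 2 * padding[1] - dilation[1] * (s - 1) - 1) // stride[1] + 1
    return n, p, q, k


# Problem size used in this notebook.
print(conv2d_output_size((32, 28, 28, 64), (128, 3, 3, 64), (1, 1), (2, 2), (1, 1)))
# (32, 14, 14, 128)
```

With a stride of 2 and a padding of 1, the 28x28 input is reduced to a 14x14 output, so `P = Q = 14` in the cells that follow.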

## Declaring and running a Conv2d Fprop

We first show how to run a Conv2d forward propagation (fprop). To get started, one only needs to provide the tensors declared above to the `cutlass.op.Conv2dFprop` call. This sets up a default Conv2d fprop operation for the device on which you are running.

Assuming that we are running on SM80, the default is a Conv2d that leverages FP16 Tensor Core operations.

Calling `plan.run()` will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting `print_module` to `True`, the C++ code that is emitted is printed.

In [None]:
# Specifying `element_accumulator` is not required if it is the same as `element`
plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)
plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)

There are many other ways to construct a plan from `cutlass.op.Conv2dFprop` (e.g., by specifying the types of each operand, by providing representative tensors as input). For more details on these, see the documentation in the `cutlass.op.Conv2dFprop` constructor.
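
For reference, the `plan.run()` call above computes `output = alpha * conv2d(input, weight) + beta * tensor_C`. The sketch below spells that out as a naive loop nest over nested Python lists, assuming cross-correlation (as `torch.conv2d` computes) and zero padding. `conv2d_fprop_reference` and its argument names are hypothetical and only meant to make the indexing explicit; it is far too slow for real problem sizes.

```python
def conv2d_fprop_reference(activation, weight, source, stride, padding, dilation,
                           alpha=1.0, beta=0.0):
    """Naive Conv2d fprop: alpha * conv(activation, weight) + beta * source.

    activation is indexed [n][h][w][c] (NHWC), weight is [k][r][s][c] (KRSC),
    and source / the result are [n][p][q][k] (NPQK).
    """
    batch, height, width = len(activation), len(activation[0]), len(activation[0][0])
    channels = len(activation[0][0][0])
    filters, kernel_h, kernel_w = len(weight), len(weight[0]), len(weight[0][0])
    out_h = (height + 2 * padding[0] - dilation[0] * (kernel_h - 1) - 1) // stride[0] + 1
    out_w = (width + 2 * padding[1] - dilation[1] * (kernel_w - 1) - 1) // stride[1] + 1

    out = [[[[0.0] * filters for _ in range(out_w)] for _ in range(out_h)]
           for _ in range(batch)]
    for n in range(batch):
        for p in range(out_h):
            for q in range(out_w):
                for k in range(filters):
                    acc = 0.0
                    for r in range(kernel_h):
                        for s in range(kernel_w):
                            h = p * stride[0] - padding[0] + r * dilation[0]
                            w = q * stride[1] - padding[1] + s * dilation[1]
                            # Positions outside the input contribute zero (padding).
                            if 0 <= h < height and 0 <= w < width:
                                for c in range(channels):
                                    acc += activation[n][h][w][c] * weight[k][r][s][c]
                    out[n][p][q][k] = alpha * acc + beta * source[n][p][q][k]
    return out


# A 1x3x3x1 input of ones convolved with a 3x3 filter of ones, padding 1, stride 1.
ones = [[[[1.0] for _ in range(3)] for _ in range(3)]]
zeros = [[[[0.0] for _ in range(3)] for _ in range(3)]]
d = conv2d_fprop_reference(ones, ones, zeros, (1, 1), (1, 1), (1, 1))
print(d[0][1][1][0], d[0][0][0][0])  # 9.0 4.0
```

The center output sees all nine filter taps, while a corner output sees only four because the rest fall in the zero padding.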

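We then compare the output to running the Conv2d using PyTorch. PyTorch uses the NCHW layout by default, so the tensors above were created with NCHW shapes and converted to the `channels_last` memory format, which stores them as NHWC.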

In [None]:
output_torch = alpha * torch.ops.aten.conv2d(
    input, weight, stride=stride, padding=padding, dilation=dilation
) + beta * tensor_C

assert torch.equal(output_torch, output)
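
The tensors in this notebook keep their logical NCHW shape; `.to(memory_format=torch.channels_last)` only changes how the elements are laid out in memory. As a rough illustration, the sketch below computes the element strides of both layouts in plain Python. `contiguous_strides` and `channels_last_strides` are hypothetical helpers, not PyTorch or CUTLASS functions.

```python
def contiguous_strides(shape):
    """Element strides of a densely packed row-major tensor."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def channels_last_strides(n, c, h, w):
    """Strides, reported in NCHW index order, of a tensor packed as NHWC."""
    s_n, s_h, s_w, s_c = contiguous_strides((n, h, w, c))
    return (s_n, s_c, s_h, s_w)


# Input tensor of this notebook: N=32, C=64, H=28, W=28.
print(contiguous_strides((32, 64, 28, 28)))   # (50176, 784, 28, 1)
print(channels_last_strides(32, 64, 28, 28))  # (50176, 1, 1792, 64)
```

In the channels-last case the channel stride is 1, so the `C` values of one pixel are adjacent in memory, matching the `[N, H, W, C]` layout described in the setup cell.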

Note that one could use the same kernel just declared for tensors provided by other frameworks beyond PyTorch, such as NumPy.

## Declaring and running Conv2d Dgrad and Wgrad

The Python interface also supports declaring and running the backward kernels of Conv2d. We first construct the gradient tensors for the input, output, and weight.

In [None]:
grad_output = torch.ceil(
    torch.empty(size=(N, K, P, Q), dtype=type_A, device="cuda").uniform_(-4.5, 3.5)
).to(memory_format=torch.channels_last)
grad_input = torch.zeros_like(input)
grad_weight = torch.zeros_like(weight)

tensor_C_dgrad = torch.ceil(
    torch.empty(size=(N, C, H, W), dtype=type_A, device="cuda").uniform_(-4.5, 3.5)
).to(memory_format=torch.channels_last)
tensor_C_wgrad = torch.ceil(
    torch.empty(size=(K, C, R, S), dtype=type_B, device="cuda").uniform_(-4.5, 3.5)
).to(memory_format=torch.channels_last)

The script below gives a simple example of computing a data gradient via the CUTLASS Python interface and via PyTorch.

In [None]:
plan_dgrad = cutlass.Conv2dDgrad(element=dtype, element_accumulator=torch.float32)
plan_dgrad.run(grad_output, weight, tensor_C_dgrad, grad_input, stride, padding, dilation, alpha, beta, print_module=print_module)

grad_input_torch = alpha * torch.nn.grad.conv2d_input(
    (N, C, H, W),
    weight, grad_output,
    stride=stride, padding=padding
) + beta * tensor_C_dgrad

assert torch.equal(grad_input_torch, grad_input)

The script below gives a simple example of computing a weight gradient via the CUTLASS Python interface and via PyTorch.

In [None]:
plan_wgrad = cutlass.Conv2dWgrad(element=dtype, element_accumulator=torch.float32)
plan_wgrad.run(grad_output, input, tensor_C_wgrad, grad_weight, stride, padding, dilation, alpha, beta, print_module=print_module)

grad_weight_torch = alpha * torch.nn.grad.conv2d_weight(
    input, (K, C, R, S), grad_output,
    stride=stride, padding=padding
) + beta * tensor_C_wgrad

assert torch.equal(grad_weight_torch, grad_weight)

## Running non-default Conv2ds

The previous examples showed how simple it is to get started running a default Conv2d kernel in CUTLASS. But what if you want more control over the Conv2d parameters? The CUTLASS Python interface exposes mutable parameters that can be set after the `plan` is initialized. We summarize these in the table below.

|Parameter|Description|
| -- | -- |
|`tile_description`|The threadblock tile size, warp count, software pipeline stages, and instruction shape|
|`iterator_algorithm`|The iterator algorithm used to access the source operands|
|`swizzling_stride`|The stride of the threadblock swizzling functor|
|`split_k`|Partitions the reduction dimension across different threadblocks (passed as an argument to `plan.run()`)|

### Tile Description

The `tile_description` defines the tiling size of each threadblock, the warp count along each dimension of the tile, the software pipeline stages, and the instruction size. Under the hood, CUTLASS enumerates the different Conv2d configuration parameters for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernel (e.g., threadblock and warp shape).

In [None]:
plan.opclass = "tensor_op"
tiles = plan.tile_descriptions()
print(f'{len(tiles)} tile descriptions returned')
num_print = 10
print(f'First {num_print} tile descriptions are:')
for td in tiles[:num_print]:
    print(td)

Next, we'll pick one of these configurations at random and compile and run it.

In [None]:
random.seed(42)
idx = random.randint(0, len(tiles)-1)
td = tiles[idx]
print(f'Tile description {idx} is: {td}')
plan.tile_description = td
plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)
assert torch.equal(output_torch, output)

Besides the tile descriptions enumerated by CUTLASS, users can also explicitly set the `threadblock_shape`, `warp_count`, `stages`, `instruction_shape`, and `cluster_shape`. If the configuration is invalid, an exception will be raised at `plan.run()` and the detailed compilation error will be stored in `./cutlass_python_compilation_error.txt` for debugging.

In [None]:
if plan.cc == 70:
    plan.tile_description = {
        "threadblock_shape": [64, 256, 32],
        "warp_count": [1, 4, 1],
        "stages": 2,
        "instruction_shape": [8, 8, 4],  # optional
        "cluster_shape": [1, 1, 1]  # optional, only [1, 1, 1] is supported currently
    }
elif plan.cc == 75:
    plan.tile_description = {
        "threadblock_shape": [128, 64, 32],
        "warp_count": [2, 1, 1],
        "stages": 2,
        "instruction_shape": [16, 8, 8],  # optional
        "cluster_shape": [1, 1, 1]  # optional, only [1, 1, 1] is supported currently
    }
elif plan.cc == 80:
    plan.tile_description = {
        "threadblock_shape": [128, 128, 64],
        "warp_count": [2, 2, 1],
        "stages": 4,
        "instruction_shape": [16, 8, 16],  # optional
        "cluster_shape": [1, 1, 1]  # optional, only [1, 1, 1] is supported currently
    }
elif plan.cc == 86:
    plan.tile_description = {
        "threadblock_shape": [128, 64, 64],
        "warp_count": [2, 2, 1],
        "stages": 3,
        "instruction_shape": [16, 8, 16],
        "cluster_shape": [1, 1, 1]
    }

plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)
assert torch.equal(output_torch, output)
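
When writing a tile description by hand, it can help to work out what the numbers imply before compiling. The sketch below derives the per-warp tile and the thread count from `threadblock_shape` and `warp_count`, assuming that the threadblock tile is split evenly among the warps and that a warp has 32 threads. `summarize_tile` is a hypothetical helper and does not reproduce the checks CUTLASS itself performs.

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs


def summarize_tile(tile):
    """Derive the warp tile and thread count implied by a tile description dict."""
    tb_m, tb_n, tb_k = tile["threadblock_shape"]
    wc_m, wc_n, wc_k = tile["warp_count"]
    # Assumption of this sketch: each dimension divides evenly among the warps.
    if tb_m % wc_m or tb_n % wc_n or tb_k % wc_k:
        raise ValueError("threadblock_shape is not divisible by warp_count")
    warps = wc_m * wc_n * wc_k
    return {
        "warp_shape": [tb_m // wc_m, tb_n // wc_n, tb_k // wc_k],
        "warps_per_threadblock": warps,
        "threads_per_threadblock": warps * WARP_SIZE,
    }


# The SM80 configuration from the cell above.
sm80_tile = {"threadblock_shape": [128, 128, 64], "warp_count": [2, 2, 1], "stages": 4}
print(summarize_tile(sm80_tile))
# {'warp_shape': [64, 64, 64], 'warps_per_threadblock': 4, 'threads_per_threadblock': 128}
```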

### Iterator Algorithm

The iterator algorithm describes how source operands are loaded from memory. Some iterator algorithms are optimized for specific alignments and channel counts and perform better when their requirements are met. The table below lists the available iterator algorithms.

|Conv Kind | Iterator Algorithm | Description |
| -- | -- | -- |
|Fprop | "analytic" | Functionally correct in all cases but lower performance |
| | "optimized" | Optimized for and requires `R <= 32`, `S <= 32`, and `C % alignment_input == 0` |
| | "few_channels" | Optimized for small `C` and requires `C % alignment_input == 0` |
| | "fixed_channels" | Optimized for small `C` and requires `C == alignment_input` |
|Dgrad | "analytic" | Functionally correct in all cases but lower performance |
| | "optimized" | Optimized for and requires `R <= 32`, `S <= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0` |
|Wgrad | "analytic" | Functionally correct in all cases but lower performance |
| | "optimized" | Optimized for and requires `K % alignment_grad_output == 0` and `C % alignment_input == 0` |

By default, the Python interface automatically proposes a suitable iterator algorithm based on the input tensors in `plan.run()`. However, the user can also specify the desired iterator algorithm as follows:

In [None]:
plan.iterator_algorithm = "analytic"
plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)
assert torch.equal(output_torch, output)

If the iterator algorithm is invalid for the problem size in `plan.run()`, an exception will be raised.
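
To see which rows of the Fprop part of the table apply to a given problem, the requirements can be written out as plain conditions. This is a sketch that only transcribes the table above; `eligible_fprop_iterators` is a hypothetical helper, the alignment of 8 elements is an assumed value for illustration, and the interface's own selection logic may differ.

```python
def eligible_fprop_iterators(R, S, C, alignment_input):
    """List the Fprop iterator algorithms whose table requirements are met."""
    eligible = ["analytic"]  # functionally correct in all cases
    if C % alignment_input == 0:
        if R <= 32 and S <= 32:
            eligible.append("optimized")
        eligible.append("few_channels")
    if C == alignment_input:
        eligible.append("fixed_channels")
    return eligible


# Problem size of this notebook (R = S = 3, C = 64) with an assumed alignment of 8.
print(eligible_fprop_iterators(R=3, S=3, C=64, alignment_input=8))
# ['analytic', 'optimized', 'few_channels']
```

Meeting a requirement does not make an algorithm the best choice: `few_channels` and `fixed_channels` target small `C`, so for `C = 64` the "optimized" iterator is the more natural candidate.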

### Swizzling Stride
Swizzling changes how tiles are mapped to threadblocks to improve L2 locality. Given a swizzling stride `N`, threadblock `(tb_x, tb_y)` computes tile `(tb_x / N, tb_y * N + (tb_x % N))`. Currently, stride values of `1`, `2`, `4`, and `8` are supported for `fprop` and `wgrad`, and `1` and `4` for `dgrad`. The swizzling stride can be set with:

In [None]:
plan.swizzling_stride = 4
plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)
assert torch.equal(output_torch, output)
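
The mapping formula above is easier to read with concrete numbers. The sketch below evaluates it for a row of threadblocks, assuming that `/` in the formula denotes integer division. `swizzled_tile` is a hypothetical helper that mirrors the formula in the text, not the CUTLASS swizzling functor itself.

```python
def swizzled_tile(tb_x, tb_y, swizzle):
    """Tile coordinates computed by threadblock (tb_x, tb_y) for a swizzling stride."""
    return tb_x // swizzle, tb_y * swizzle + (tb_x % swizzle)


# With a stride of 4, eight consecutive threadblocks along x cover a 2x4 patch of tiles.
for tb_x in range(8):
    print((tb_x, 0), "->", swizzled_tile(tb_x, 0, 4))
# (0, 0) -> (0, 0)
# (1, 0) -> (0, 1)
# (2, 0) -> (0, 2)
# (3, 0) -> (0, 3)
# (4, 0) -> (1, 0)
# (5, 0) -> (1, 1)
# (6, 0) -> (1, 2)
# (7, 0) -> (1, 3)
```

A stride of 1 leaves the mapping unchanged, while larger strides group neighbouring threadblocks onto a compact block of tiles rather than a single long row.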

### Split-K
Split-K is usually applied when the Conv2d has small spatial dimensions and a large reduction dimension, to ensure good utilization. It further partitions the reduction dimension across different threadblocks. The CUTLASS Python interface supports two split-K strategies: `Parallel` and `Serial`.
* `Parallel`: the partial results from different threadblocks are stored in a temporary buffer in global memory. When the Conv2d is done, a separate reduction kernel is created and launched to reduce the partial results.
* `Serial`: A semaphore is used to coordinate the order in which different threadblocks add their partial results to a given output tile. A separate kernel does not need to be launched to perform the reduction.

While `fprop`, `dgrad`, and `wgrad` all support split-K, here we use `wgrad` as an example.

In [None]:
# Parallel Split-K with 5 slices
grad_weight_parallel = torch.zeros_like(grad_weight)
plan_wgrad.run(
    grad_output, input, tensor_C_wgrad, grad_weight_parallel,
    stride, padding, dilation, alpha, beta, print_module=print_module, split_k=("parallel", 5))
assert torch.equal(grad_weight_torch, grad_weight_parallel)

# Serial Split-K with 3 slices
grad_weight_serial = torch.zeros_like(grad_weight)
plan_wgrad.run(
    grad_output, input, tensor_C_wgrad, grad_weight_serial,
    stride, padding, dilation, alpha, beta, print_module=print_module, split_k=("serial", 3))
assert torch.equal(grad_weight_torch, grad_weight_serial)
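
The idea behind split-K can be illustrated on a single dot product, which is the reduction each output element performs. The sketch below is a conceptual illustration only: `split_k_dot` is a hypothetical helper, and the even chunking it uses is an assumption that does not reflect how CUTLASS assigns work to threadblocks.

```python
def split_k_dot(a, b, slices):
    """Dot product with the reduction dimension split into `slices` chunks.

    Returns the reduced total together with the per-slice partial sums.
    """
    k = len(a)
    chunk = -(-k // slices)  # ceiling division
    partials = []
    for i in range(slices):
        lo, hi = i * chunk, min((i + 1) * chunk, k)
        partials.append(sum(x * y for x, y in zip(a[lo:hi], b[lo:hi])))
    # "Parallel" keeps the partials in a buffer and reduces them in a second pass;
    # "serial" would instead add each partial to the output in turn.
    return sum(partials), partials


a = list(range(1, 11))  # reduction length 10
b = [2] * 10
total, partials = split_k_dot(a, b, slices=3)
print(partials, total)  # [20, 52, 38] 110
```

The inputs here are integers, so regrouping the sum is exact. With general floating-point data, splitting the reduction changes the order of additions and the result may differ in the last bits.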