# Example of using elementwise activation functions in the CUTLASS Python interface
This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/01_epilogue.ipynb)


## Prerequisites for running on Colab
This notebook requires an NVIDIA GPU. To check for one, uncomment the command in the following cell and run it. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected.

In [None]:
!#nvidia-smi

If running on Colab, you will need to install the CUTLASS Python interface. To do so, uncomment the following line and run the cell:

In [None]:
!#pip install nvidia-cutlass

## General setup
We first import the packages needed for the example and construct the input and output tensors that it will use.

In [None]:
import numpy as np

import cutlass

# This controls whether the C++ GEMM declaration will be printed at each step. Set to `False` to
# omit this information.
print_module = True

m = 256
n = m
k = m

type_A = np.float16
type_B = np.float16
type_C = np.float16
type_D = np.float16

np.random.seed(1234)
scope_min = -4
scope_max = 4
tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))
tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))
tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))

alpha = np.float16(1.)
beta = np.float16(0.)

tensor_D = np.zeros(tensor_C.shape).astype(type_D)

## Run a GEMM with an identity activation function
To begin, we run a default GEMM with an identity activation function, which performs the well-known operation `D = alpha * (A @ B) + beta * C`. The identity activation is the default, so it does not need to be specified.

In [None]:
plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)
plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)
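
For intuition, the linear combination computed by the default epilogue can be reproduced on the host with NumPy. The sketch below is not part of the CUTLASS interface: `gemm_reference` is a hypothetical helper name, and accumulating in `float32` before casting to the output type is an assumption made for illustration, not a statement about the compiled kernel.

```python
import numpy as np

def gemm_reference(A, B, C, alpha=1.0, beta=0.0, out_dtype=np.float16):
    """Host-side reference for D = alpha * (A @ B) + beta * C (hypothetical helper)."""
    # Accumulate in float32 (an assumption for this sketch), then cast to the output type.
    acc = A.astype(np.float32) @ B.astype(np.float32)
    D = np.float32(alpha) * acc + np.float32(beta) * C.astype(np.float32)
    return D.astype(out_dtype)

# Small inputs that can be checked by hand.
A = np.array([[1, 2], [3, 4]], dtype=np.float16)
B = np.array([[1, 0], [0, 1]], dtype=np.float16)
C = np.array([[1, 1], [1, 1]], dtype=np.float16)

D_identity = gemm_reference(A, B, C)                     # alpha=1, beta=0: D equals A @ B
D_scaled = gemm_reference(A, B, C, alpha=2.0, beta=1.0)  # 2 * (A @ B) + C
```

Applied to the notebook's inputs, `gemm_reference(tensor_A, tensor_B, tensor_C, alpha, beta)` gives a host-side result to compare against `tensor_D`. A tolerance-based comparison such as `np.testing.assert_allclose` may be more appropriate than exact equality, since rounding can differ between the host and the GPU kernel.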

## Run a GEMM with a ReLU element-wise activation function
CUTLASS makes it easy to support other element-wise activation functions. Such a function is applied element-wise to the result of the generic linear combination performed in a GEMM. If we call such an activation function `act`, the resulting formulation is:
```
D = alpha * (A @ B) + beta * C
D = act(D)
```

Here, we will add a ReLU activation function. Given an input `x`, ReLU returns `max(x, 0)`.
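
The two-step formulation above can be sketched in plain NumPy. This is only an illustration of the math, not how CUTLASS implements its epilogue; `relu` and `apply_epilogue` are hypothetical helper names.

```python
import numpy as np

def relu(x):
    # ReLU: max(x, 0), applied element by element.
    return np.maximum(x, 0)

def apply_epilogue(accum, C, alpha, beta, act=lambda x: x):
    """Linear combination followed by an element-wise activation (hypothetical helper)."""
    D = alpha * accum + beta * C  # D = alpha * (A @ B) + beta * C
    return act(D)                 # D = act(D)

A = np.array([[1.0, -2.0], [-3.0, 4.0]])
B = np.array([[1.0, 0.0], [0.0, 1.0]])
C = np.zeros((2, 2))

# The default `act` is the identity, so negative entries are kept.
D_identity = apply_epilogue(A @ B, C, alpha=1.0, beta=0.0)
# With ReLU, negative entries of the linear combination are clamped to zero.
D_relu = apply_epilogue(A @ B, C, alpha=1.0, beta=0.0, act=relu)
```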

This is easy to do in CUTLASS. One only needs to set the plan's `activation` field.

In [None]:
tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)
plan.activation = "relu"
plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)

We can now verify the result of the GEMM that used a ReLU activation function by comparing it against a reference computed from the identity-activation output `tensor_D`:

In [None]:
relu_ref = (tensor_D >= 0).astype(type_D) * tensor_D
np.testing.assert_array_equal(relu_ref, tensor_D_relu)

## Other element-wise activation functions
CUTLASS supports a variety of widely used element-wise activation functions. We can obtain a list of these functions via the plan's `activations()` method.

In [None]:
activations = plan.activations()
for activation in activations:
    print(activation)

We can then run each of them:

In [None]:
for activation in activations:
    print('=============================================================================================')
    print(f'Compiling and running activation {activation}')
    print('=============================================================================================')
    plan.activation = activation
    plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)

To use an activation that takes a parameter, such as `leaky_relu`, set `activation` to a tuple containing the activation function name and the parameter (or a list of parameters).

In [None]:
negative_slope = 0.5
plan.activation = ("leaky_relu", negative_slope)
plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)
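
To sanity-check a parameterized activation, a host-side reference can be written with NumPy. The sketch assumes the common definition of leaky ReLU, which returns `x` for `x >= 0` and `negative_slope * x` otherwise. Whether CUTLASS's `leaky_relu` matches this definition exactly is an assumption here, and `leaky_relu_ref` is a hypothetical helper name.

```python
import numpy as np

def leaky_relu_ref(x, negative_slope):
    # Assumed definition: pass non-negative values through, scale negative values.
    return np.where(x >= 0, x, negative_slope * x)

x = np.array([-4.0, -1.0, 0.0, 2.0], dtype=np.float16)
y = leaky_relu_ref(x, np.float16(0.5))
```

Note that `tensor_D` is overwritten by the runs above, so to compare against such a reference in the notebook, the identity-activation result would first need to be saved in a separate array.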