Kernels for LayerNorm forward pass

Introduction

The Layer Normalization (LayerNorm) operation normalizes the activation tensor across its last \(D\) dimensions (in this pocket reference, only the last dimension), as described in the original paper by Ba et al. (2016). The normalization equation is:

$$y = \frac{x - \mathbb{E}[x]}{\sqrt{Var[x] + \epsilon}} \times \gamma + \beta$$

where \(\mathbb{E}[x]\) and \(Var[x]\) are the mean (expectation) and variance of \(x\) computed over the normalized dimensions, respectively. The small constant \(\epsilon\) avoids division by zero, whereas \(\gamma\) and \(\beta\) are learnable scale and shift parameters, respectively.
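As a quick numerical check of the equation, take a single segment with \(C = 4\) and \(x = (1, 2, 3, 4)\). The values \(\gamma = 1\), \(\beta = 0\) are chosen purely for illustration, and \(\epsilon\) is set to zero for readability:

$$\mathbb{E}[x] = \frac{1 + 2 + 3 + 4}{4} = 2.5, \qquad Var[x] = \frac{1.5^2 + 0.5^2 + 0.5^2 + 1.5^2}{4} = 1.25$$

$$y = \frac{x - 2.5}{\sqrt{1.25}} \approx (-1.342, -0.447, 0.447, 1.342)$$

The normalized segment has zero mean and unit variance, as intended.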

This pocket reference walks through a series of CUDA kernel implementations of the LayerNorm forward pass from the llm.c GitHub repository and explains each one in detail. Please refer to the Layer Normalization pocket reference for a conceptual overview and other details of the operation. Here, the kernels implement LayerNorm as used in the Transformer architecture for language modeling. The input to LayerNorm is a tensor of shape \((B, T, C)\), where:

  • \(B\) is the batch size
  • \(T\) is the sequence length
  • \(C\) is the hidden dimension size

LayerNorm is applied along the last dimension \(C\). For benchmarking purposes, we use the following configuration:

  • \(B = 8\)
  • \(T = 1024\)
  • \(C = 768\)

The following table shows the memory bandwidth achieved by each kernel on an A40 GPU with a block size of 512. The last column shows the improvement over the first kernel:

Kernel   Bandwidth (GB/s)   Improvement
1        41.43              -
2        201.25             4.9x
3        362.10             8.7x
4        432.03             10.4x
5        538.88             13x
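To put these numbers in context, the benchmark tensor holds \(B \cdot T \cdot C = 8 \cdot 1024 \cdot 768 = 6{,}291{,}456\) FP32 values, about 25.2 MB. The C++ sketch below shows one common way to turn a measured kernel time into an effective bandwidth figure. It assumes (this is not stated above) that each kernel is credited with reading the input once and writing the output once, ignoring the small \(\gamma\), \(\beta\), mean and rstd arrays; both helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: bytes moved by one LayerNorm forward pass, ASSUMING the
// input x is read once and the output y is written once (FP32, 4 bytes each).
// Reads of gamma/beta and writes of mean/rstd are ignored because they are tiny.
std::int64_t layernorm_bytes_moved(std::int64_t B, std::int64_t T, std::int64_t C) {
    assert(B > 0 && T > 0 && C > 0);
    const std::int64_t elems = B * T * C;
    return 2 * elems * static_cast<std::int64_t>(sizeof(float));
}

// Hypothetical helper: effective bandwidth in GB/s (1 GB = 1e9 bytes)
// given a measured kernel time in milliseconds.
double effective_bandwidth_gbps(std::int64_t bytes, double elapsed_ms) {
    assert(elapsed_ms > 0.0);
    return static_cast<double>(bytes) / (elapsed_ms * 1e-3) / 1e9;
}
```

Under this assumption the benchmark shape moves 50,331,648 bytes per call, so a kernel finishing in 0.1 ms would be reported at roughly 503 GB/s.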

Kernel 1

The first kernel is a direct port of the CPU implementation. It parallelizes over the first two dimensions, \(B\) and \(T\), giving \(N = B \cdot T\) independent rows. A single thread (see Figure-1a) normalizes one segment of size \(C\), looping over every element in that segment. The kernel code is broken down into 4 steps:

  1. Mean calculation

    $$\mathbb{E}[x] = \frac{1}{C} \sum_{i=1}^{C} x_i$$

  2. Variance and reciprocal of standard deviation (rstd) calculation

    $$Var[x] = \frac{1}{C} \sum_{i=1}^{C} (x_i - \mathbb{E}[x])^2$$

    $$rstd[x] = \frac{1}{\sqrt{Var[x] + \epsilon}}$$

  3. Apply mean and variance normalization and then scale and shift with the learnable weight and bias parameters

    $$y_i = ((x_i - \mathbb{E}[x]) * rstd[x]) * \gamma_i + \beta_i$$

  4. Store mean and rstd for backward pass

The kernel uses a 1D grid and block as shown in Figure-1a. Also note that all operations are implemented in a single kernel.
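The four steps can be mirrored on the CPU to make the per-thread work concrete. The C++ sketch below is not the llm.c kernel: it simulates Kernel 1 by running the body of one "thread" per row inside an ordinary loop, and it assumes \(\epsilon = 10^{-5}\), FP32 arithmetic and a contiguous row-major layout. The function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Simulates Kernel 1 on the CPU: "thread" r (r in [0, N)) normalizes row r of a
// row-major (N, C) input, where N = B * T. One outer-loop iteration is the work
// a single GPU thread would do.
void layernorm_forward_kernel1_sim(std::vector<float>& out, std::vector<float>& mean,
                                   std::vector<float>& rstd, const std::vector<float>& inp,
                                   const std::vector<float>& weight, const std::vector<float>& bias,
                                   int N, int C) {
    assert(static_cast<int>(inp.size()) == N * C);
    assert(static_cast<int>(weight.size()) == C && static_cast<int>(bias.size()) == C);
    const float eps = 1e-5f;  // assumed value of epsilon
    out.assign(inp.size(), 0.0f);
    mean.assign(N, 0.0f);
    rstd.assign(N, 0.0f);
    for (int r = 0; r < N; ++r) {  // one "thread" per row
        const float* x = inp.data() + r * C;
        // Step 1: mean
        float m = 0.0f;
        for (int i = 0; i < C; ++i) m += x[i];
        m /= C;
        // Step 2: variance and reciprocal standard deviation
        float v = 0.0f;
        for (int i = 0; i < C; ++i) {
            float d = x[i] - m;
            v += d * d;
        }
        v /= C;
        float s = 1.0f / std::sqrt(v + eps);
        // Step 3: normalize, then scale and shift
        float* y = out.data() + r * C;
        for (int i = 0; i < C; ++i) y[i] = (x[i] - m) * s * weight[i] + bias[i];
        // Step 4: store mean and rstd for the backward pass
        mean[r] = m;
        rstd[r] = s;
    }
}

// Number of blocks in a 1D launch with one thread per row (ceiling division).
int kernel1_num_blocks(int N, int block_size) { return (N + block_size - 1) / block_size; }
```

With one thread per row, the benchmark shape (\(N = 8 \cdot 1024 = 8192\)) and a block size of 512 launch only 16 blocks, and each thread walks serially through all 768 elements of its row.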

Figure-1a: Kernel 1 Illustration.
Figure-1b: Kernel 1 Code.

Kernel 2

In Kernel 2, steps 1, 2 and 3 are implemented as separate kernels. In the mean and rstd kernels, each block, rather than each thread, is responsible for one segment of size \(C\) (see Figure-2a), which exposes more parallelism. In the normalization kernel (step 3), each thread computes one output element.
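In the normalization kernel, a thread's flat index into the \((B, T, C)\) tensor determines which row's mean and rstd it reads and which \(\gamma\), \(\beta\) entries it uses. A minimal C++ sketch of that mapping, simulated on the CPU and assuming a contiguous row-major layout (the function name is hypothetical):

```cpp
#include <cassert>
#include <vector>

// CPU stand-in for the Kernel 2 normalization step: loop iteration `idx` plays the
// role of the GPU thread with global index idx = blockIdx.x * blockDim.x + threadIdx.x.
void normalization_sim(std::vector<float>& out, const std::vector<float>& inp,
                       const std::vector<float>& mean, const std::vector<float>& rstd,
                       const std::vector<float>& weight, const std::vector<float>& bias,
                       int N, int C) {
    assert(static_cast<int>(inp.size()) == N * C);
    out.assign(inp.size(), 0.0f);
    for (int idx = 0; idx < N * C; ++idx) {  // one "thread" per output element
        int row = idx / C;                   // which (b, t) segment
        int c = idx % C;                     // position inside the segment
        out[idx] = (inp[idx] - mean[row]) * rstd[row] * weight[c] + bias[c];
    }
}
```

Because every element is independent, this step needs no reduction; it only reads the mean and rstd produced by the other two kernels.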

Since both the mean and rstd calculations involve a sum, they use thread coarsening followed by a reduction. In thread coarsening, each thread sums the elements of the segment at a stride of the block size (thread \(t\) adds elements \(t\), \(t + \text{blockDim.x}\), and so on) and stores its partial sum in a shared memory array with one entry per thread. In the reduction, the entries of this shared array are combined iteratively to obtain the final sum. For more details, see the thread coarsening and reduction pocket references.
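The C++ sketch below simulates, on the CPU, how one block computes the sum of one segment: each of the `block_size` "threads" first accumulates a strided partial sum (coarsening), then the shared array is halved repeatedly (tree reduction). It assumes `block_size` is a power of two, which the halving loop requires; the function name is hypothetical. On the GPU, consecutive halving rounds are separated by `__syncthreads()`.

```cpp
#include <cassert>
#include <vector>

// Simulated block-level sum of one segment x[0..C) using `block_size` threads.
float block_sum_sim(const float* x, int C, int block_size) {
    // The halving reduction below assumes a power-of-two block size.
    assert(block_size > 0 && (block_size & (block_size - 1)) == 0);
    std::vector<float> shared(block_size);  // stands in for the shared memory array
    // Thread coarsening: thread tid sums x[tid], x[tid + block_size], ...
    for (int tid = 0; tid < block_size; ++tid) {
        float sum = 0.0f;
        for (int i = tid; i < C; i += block_size) sum += x[i];
        shared[tid] = sum;
    }
    // Tree reduction: in each round, the first `stride` threads add in the upper half.
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        for (int tid = 0; tid < stride; ++tid) shared[tid] += shared[tid + stride];
        // (on the GPU, a __syncthreads() separates the rounds)
    }
    return shared[0];  // thread 0 holds the block's total
}
```

Dividing the result by \(C\) gives the mean; the rstd kernel applies the same pattern to \((x_i - \mathbb{E}[x])^2\).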

These optimizations lead to an improvement of ~5x over Kernel 1 (for block size 512).

Figure-2a: Kernel 2 Illustration - mean and rstd kernels.
Figure-2b: Kernel 2 Code.

Kernel 3

Kernel 3 introduces cooperative groups, which let us work with thread groups at granularities other than the whole thread block, such as tiles whose size is a power of two. The cooperative groups API provides thread group types such as thread_block_tile<N>, created with tiled_partition<N>(g), with useful methods such as thread_rank(), which returns the index of the current thread within the group (similar to threadIdx.x). It also provides collective functions such as reduce(), which performs a reduction (similar to that described in Figure-2a) over one value held by each thread in the group. These types and functions are defined in the cooperative_groups namespace (header <cooperative_groups.h>, with reduce() in <cooperative_groups/reduce.h>).
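To see what `reduce()` hands back to each thread, the C++ sketch below simulates an all-reduce over a tile of 32 lanes on the CPU. How `cg::reduce` is implemented internally is not specified here; the sketch uses the well-known XOR "butterfly" exchange pattern (the pattern one would write by hand with `__shfl_xor_sync`) purely as a stand-in, and the function name is hypothetical. After five exchange rounds every lane holds the full sum, which is the result each thread in the group receives from `reduce()`.

```cpp
#include <array>
#include <cassert>

constexpr int kTileSize = 32;  // tile (warp) size used in Kernel 3

// Simulated all-reduce across one tile: lane l starts with vals[l] and ends with the
// sum over all lanes. In each round, lane l adds the value held by lane l ^ offset.
std::array<float, kTileSize> tile_allreduce_sum_sim(std::array<float, kTileSize> vals) {
    static_assert((kTileSize & (kTileSize - 1)) == 0, "tile size must be a power of two");
    for (int offset = kTileSize / 2; offset >= 1; offset /= 2) {
        std::array<float, kTileSize> next{};
        for (int lane = 0; lane < kTileSize; ++lane) {
            assert((lane ^ offset) < kTileSize);  // partner lane stays inside the tile
            // every lane reads its partner's value from the previous round
            next[lane] = vals[lane] + vals[lane ^ offset];
        }
        vals = next;
    }
    return vals;
}
```

Because every lane ends up with the total, each thread can immediately compute the mean or rstd without a separate broadcast step.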

This kernel uses a thread group (or tile) size of 32 to match the number of threads in a warp, so we refer to this thread group as a warp. Hence, in Kernel 3 one warp is responsible for one segment of size \(C\) (see Figure-3a, which uses a warp of size 4 for simplicity). As in Kernel 1, all operations are combined in a single kernel.
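With one warp per segment, a block of `block_size` threads covers `block_size / 32` rows, so the grid needs many more blocks than in Kernel 1. The C++ sketch below computes that mapping. It assumes one natural layout (not necessarily the exact llm.c indexing): warp `w` of block `b` handles row `b * (block_size / 32) + w`, and lane `l` visits elements `l, l + 32, l + 64, ...` of that row. All names are hypothetical.

```cpp
#include <cassert>
#include <vector>

constexpr int kWarpSize = 32;

// Row handled by warp `warp_in_block` of block `block_idx` (one warp per row).
int kernel3_row(int block_idx, int warp_in_block, int block_size) {
    assert(block_size % kWarpSize == 0 && warp_in_block < block_size / kWarpSize);
    return block_idx * (block_size / kWarpSize) + warp_in_block;
}

// Blocks needed so that every one of the N rows gets a warp (ceiling division).
int kernel3_num_blocks(int N, int block_size) {
    int warps_per_block = block_size / kWarpSize;
    return (N + warps_per_block - 1) / warps_per_block;
}

// Element indices of a length-C row visited by lane `lane` (strided by the warp size).
std::vector<int> kernel3_lane_elements(int lane, int C) {
    std::vector<int> idx;
    for (int i = lane; i < C; i += kWarpSize) idx.push_back(i);
    return idx;
}
```

For the benchmark shape and a block size of 512, this gives 512 blocks of 16 warps, and each lane touches 24 of the 768 elements per pass over the row. When \(N\) is not a multiple of the warps per block, warps whose row index is \(\geq N\) must exit early.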

This kernel also includes a few additional changes:

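  1. Use of the `__restrict__` qualifier: it promises the compiler that the pointer arguments do not alias one another, which lets it avoid redundant reloads and reorder memory operations, reducing memory accesses and computation.
  2. Use of cache operators: __ldcs() and __stcs() load and store with the cache-streaming hint, marking data that is likely to be accessed only once so that it causes less cache pollution.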
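The effect of `__restrict__` is easiest to see in plain C++, since GCC and Clang accept the same `__restrict__` spelling as an extension. In the sketch below (hypothetical function names), the compiler must assume in the unqualified version that a store to `out[i]` might modify `*scale`, so it reloads `*scale` on every iteration; with `__restrict__` on all pointers it may load `*scale` once and keep it in a register. The results agree only when the pointers really do not alias; passing aliased pointers to the restricted version is undefined behavior.

```cpp
#include <cassert>
#include <cstddef>

// Without __restrict__: out may alias scale, so *scale must be re-read each iteration.
void scale_copy(float* out, const float* in, const float* scale, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) out[i] = in[i] * *scale;
}

// With __restrict__: the caller promises the three arrays do not overlap, which allows
// the compiler to hoist the load of *scale out of the loop.
void scale_copy_restrict(float* __restrict__ out, const float* __restrict__ in,
                         const float* __restrict__ scale, std::size_t n) {
    // cheap, incomplete sanity check of the no-aliasing promise
    assert(n == 0 || (out != in && out != scale));
    for (std::size_t i = 0; i < n; ++i) out[i] = in[i] * *scale;
}
```

For example, `scale_copy(buf, buf, &buf[0], 3)` is legal and must observe that `buf[0]` changes after the first iteration, whereas the same call to `scale_copy_restrict` would break its promise.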

These optimizations lead to an improvement of ~1.8x over Kernel 2 (for block size 512).

Figure-3a: Kernel 3 Illustration.
Figure-3b: Kernel 3 Code.

Kernel 4

This kernel is identical to Kernel 3 except for the formula used to calculate the variance. Instead of averaging the squared deviations from the mean, it uses the following identity, which removes the per-element subtraction of the mean:

$$Var[x] = \mathbb{E}[x^2] - (\mathbb{E}[x])^2$$
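The identity follows from expanding the square in the variance definition used in Kernel 1, writing \(\mu = \mathbb{E}[x]\):

$$Var[x] = \frac{1}{C} \sum_{i=1}^{C} (x_i - \mu)^2 = \frac{1}{C} \sum_{i=1}^{C} x_i^2 - 2\mu \cdot \frac{1}{C} \sum_{i=1}^{C} x_i + \mu^2 = \mathbb{E}[x^2] - 2\mu^2 + \mu^2 = \mathbb{E}[x^2] - \mu^2$$

Because \(\sum x_i^2\) does not depend on \(\mu\), it can be accumulated without first knowing the mean.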

This simple change yields a further ~1.2x improvement over Kernel 3 (for block size 512).
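A small C++ comparison of the two variance formulas follows (function names hypothetical). The one-pass version accumulates \(\sum x_i\) and \(\sum x_i^2\) together, as the identity permits. One caveat is a general floating-point fact rather than something measured in the benchmarks above: \(\mathbb{E}[x^2] - (\mathbb{E}[x])^2\) subtracts two nearly equal numbers when the mean is large relative to the spread, so it can lose precision compared with the two-pass form.

```cpp
#include <cassert>
#include <vector>

// Two-pass variance (as in Kernels 1-3): mean first, then the mean of squared deviations.
float variance_two_pass(const std::vector<float>& x) {
    assert(!x.empty());
    const float n = static_cast<float>(x.size());
    float sum = 0.0f;
    for (float v : x) sum += v;
    float m = sum / n;
    float sq = 0.0f;
    for (float v : x) sq += (v - m) * (v - m);
    return sq / n;
}

// One-pass variance via E[x^2] - (E[x])^2: both sums are accumulated in the same loop.
float variance_one_pass(const std::vector<float>& x) {
    assert(!x.empty());
    const float n = static_cast<float>(x.size());
    float sum = 0.0f, sum2 = 0.0f;
    for (float v : x) {
        sum += v;
        sum2 += v * v;
    }
    float m = sum / n;
    return sum2 / n - m * m;
}
```

Both functions return 1.25 for \(x = (1, 2, 3, 4)\). For \(x = (10000, 10001, 10002, 10003)\), whose variance is also 1.25, the two-pass version still returns 1.25 in FP32, while the one-pass result is badly wrong because the squared values near \(10^8\) no longer fit in the 24-bit FP32 significand. Activations in LayerNorm inputs are usually of modest magnitude, which is why the faster formula is often acceptable in practice.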

Kernel 5

The final kernel operates in two stages. As in Kernel 2, each block is responsible for one segment of size \(C\). In stage 1, thread coarsening is still done across the whole block (each thread strides through the segment by the block size), but the partial sums are first reduced within each warp rather than across the whole block. Each warp's sum is written to a shared memory array with one entry per warp. In stage 2, the threads of the first warp are reused to perform another warp reduction over this shared array to obtain the final sum; no thread coarsening is needed in this stage. See Figure-4a for the complete flow.
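The two-stage flow can be simulated on the CPU with the C++ sketch below. It is a sketch under stated assumptions, not the llm.c code: the block size is assumed to be a multiple of 32 and at most 1024 (so there are at most 32 warps and the stage-2 array fits in one warp), and each warp reduction is modeled as a plain sum over that warp's 32 lanes. Lanes of the first warp whose index is at least the number of warps contribute zero in stage 2. The function name is hypothetical.

```cpp
#include <cassert>
#include <vector>

constexpr int kWarp = 32;

// Simulated Kernel 5 style block reduction of one segment x[0..C).
float two_stage_block_sum_sim(const float* x, int C, int block_size) {
    assert(block_size % kWarp == 0 && block_size <= 1024);
    const int num_warps = block_size / kWarp;
    // Thread coarsening over the whole block: thread t sums x[t], x[t + block_size], ...
    std::vector<float> thread_sum(block_size, 0.0f);
    for (int t = 0; t < block_size; ++t)
        for (int i = t; i < C; i += block_size) thread_sum[t] += x[i];
    // Stage 1: each warp reduces its 32 partial sums; one value per warp
    // is written to the "shared memory" array.
    std::vector<float> shared(num_warps, 0.0f);
    for (int w = 0; w < num_warps; ++w)
        for (int lane = 0; lane < kWarp; ++lane) shared[w] += thread_sum[w * kWarp + lane];
    // (on the GPU, a __syncthreads() makes all warp sums visible here)
    // Stage 2: lanes of the first warp load shared[lane] (or 0 if lane >= num_warps)
    // and perform one more warp reduction, with no coarsening.
    float block_sum = 0.0f;
    for (int lane = 0; lane < kWarp; ++lane)
        block_sum += (lane < num_warps) ? shared[lane] : 0.0f;
    return block_sum;
}
```

Dividing the result by \(C\) gives the mean; combined with the Kernel 4 formula, the same reduction can be applied to the squared values to obtain \(\mathbb{E}[x^2]\).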

The final kernel improves by ~1.25x over Kernel 4 and ~13x over the first kernel (for block size 512).

Figure-4a: Kernel 5 Illustration.
Figure-4b: Kernel 5 Code.

Summary

The following figure summarizes the memory bandwidth of all kernels on an A40 GPU across different block sizes:

Figure-5: A40 Memory Bandwidth Summary.

References

  1. Code for LayerNorm forward kernels from the llm.c GitHub repository
  2. Layer Normalization paper (Ba et al., 2016)
  3. CUDA C++ Programming Guide
  4. CUDA Parallel Thread Execution (PTX) Guide

Contributors: