
Illustrated Tensor Parallelism

As deep learning models continue to grow, a single GPU is often not enough to handle them. Even high-end hardware quickly runs out of memory when training models with billions of parameters. Tensor parallelism is a technique designed to overcome this limitation. There are other strategies, like data, pipeline, and context parallelism; we will get to them in future writings.

How It Works

Instead of having each GPU hold the entire model, we split the weights across GPUs. Each GPU does its part of the computation, and they talk to each other to combine results.

There are a few different ways to do this split:

Column Parallelism

Imagine splitting the weight matrix vertically.

[Figure: Column Parallelism]

I am using three GPUs for the diagrams. Here’s how it works: the input gets copied to all GPUs (broadcast). Each GPU multiplies it by its columns of the weights and produces part of the output. The GPUs then all-gather their results to materialize the full output matrix.

Interestingly, if you pad each local result with zeros in the columns it does not own, an all-reduce would also work: since the local results occupy disjoint columns, summing them gives the full output.
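Here is a minimal single-process sketch of the idea, simulating three GPUs with plain NumPy arrays (the shapes are made up, and the all-gather is faked with a concatenation):

import numpy as np

g, d_in, d_out = 3, 8, 6
X = np.random.randn(4, d_in)              # input, replicated on every "GPU"
W = np.random.randn(d_in, d_out)          # full weight matrix

# Split the weights vertically: each GPU owns d_out / g columns.
W_shards = np.split(W, g, axis=1)

# Each GPU computes its own slice of the output.
Y_shards = [X @ W_i for W_i in W_shards]  # each (4, d_out / g)

# "All-gather": concatenate the column slices into the full output.
Y = np.concatenate(Y_shards, axis=1)
assert np.allclose(Y, X @ W)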

Row Parallelism

Now split the weight matrix horizontally.

[Figure: Row Parallelism]

The input gets sharded column-wise across GPUs. Each GPU multiplies its columns of the input by its rows of the weights and gets a partial output with the same shape as the expected output. Since each GPU only computed part of the result, the partial outputs need to be summed with an all-reduce operation to get the final answer.
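The same kind of single-process NumPy sketch (again with made-up shapes and a faked all-reduce) for the row-parallel case:

import numpy as np

g, d_in, d_out = 3, 9, 6
X = np.random.randn(4, d_in)
W = np.random.randn(d_in, d_out)

# Shard the input column-wise and the weights row-wise.
X_shards = np.split(X, g, axis=1)         # each (4, d_in / g)
W_shards = np.split(W, g, axis=0)         # each (d_in / g, d_out)

# Each GPU produces a partial output with the full output shape.
partials = [X_i @ W_i for X_i, W_i in zip(X_shards, W_shards)]

# "All-reduce": sum the partial outputs to get the final result.
Y = sum(partials)
assert np.allclose(Y, X @ W)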

Column + Row Parallelism

Combine column and row parallelism across sequential layers so outputs stay distributed between GPUs, avoiding expensive data transfers.

[Figure: Column-Row Parallelism]

After a column-parallel layer, the output is already sharded by columns across GPUs, which is exactly what the next row-parallel layer expects as input. This means we can skip the all-gather and feed the intermediate results directly to the next layer.
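Sketching the combination the same way (single process, simulated GPUs, arbitrary shapes), the intermediate result never gets gathered and the only communication is one all-reduce at the very end:

import numpy as np

g, d, d_hidden = 3, 8, 12
X = np.random.randn(4, d)
A = np.random.randn(d, d_hidden)
B = np.random.randn(d_hidden, d)

A_shards = np.split(A, g, axis=1)         # first layer: column-parallel
B_shards = np.split(B, g, axis=0)         # second layer: row-parallel

# Per GPU: the column-sharded intermediate feeds the row-parallel layer
# directly, with no all-gather in between.
partials = [(X @ A_i) @ B_i for A_i, B_i in zip(A_shards, B_shards)]

# One all-reduce at the end.
Z = sum(partials)
assert np.allclose(Z, X @ A @ B)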

TP in Transformer Blocks

For simplicity, let’s focus on decoder-only transformer blocks. The two main parts of such a decoder block are: (1) the Multi-headed Attention (MHA) and (2) the Feed-Forward Network (FFN).

[Figure: Transformer Block]

Borrowing the figures from the Megatron-LM paper, we will first take a look at how TP plays out in the FFN and MHA.

FFN

The FFN sub-block is two matrix multiplications (GEMMs) with an activation (e.g. GeLU) and a dropout. Activation and dropout are element-wise operations, so no worries there.

Y = X @ A  # GEMM 1
Y = gelu(Y)
Z = Y @ B  # GEMM 2
Z = dropout(Z)

This is the same as the “Column + Row” parallel case above: A is split column-wise and B row-wise. In the figure below, the f function acts as a broadcast and g as an all-reduce during the forward pass.
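To make the mapping concrete, here is the same kind of NumPy sketch for the FFN (the tanh-approximate GeLU and the shapes are just illustrative; dropout is omitted since it is stochastic). Because GeLU is element-wise and each GPU holds complete columns of Y, it can be applied locally with no communication:

import numpy as np

def gelu(x):
    # tanh approximation of GeLU (the exact form doesn't matter for this sketch)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

g, d, d_ff = 3, 8, 12
X = np.random.randn(4, d)
A = np.random.randn(d, d_ff)
B = np.random.randn(d_ff, d)

A_shards = np.split(A, g, axis=1)         # A: column-parallel
B_shards = np.split(B, g, axis=0)         # B: row-parallel

# GEMM 1, GeLU and GEMM 2 all run locally; only the final sum (all-reduce) communicates.
partials = [gelu(X @ A_i) @ B_i for A_i, B_i in zip(A_shards, B_shards)]
Z = sum(partials)

assert np.allclose(Z, gelu(X @ A) @ B)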

MHA

Let’s assume the Q, K and V projection matrices (W_Q, W_K and W_V) are each of shape (d, d). For n heads, these can be seen as concatenations of n smaller per-head projections of size (d, d/n).

Now, a simplified version of the attention mechanism looks like:

# b = batch size
# s = sequence length
# d = model dimension
# n = number of heads

Q = X @ W_Q                         # (b, s, d)
K = X @ W_K
V = X @ W_V

# Split into heads
Q = split(Q, n)                     # (b, n, s, d/n)
K = split(K, n)
V = split(V, n)

# Scaled dot-product attention per head
A = Q @ K.transpose(-2, -1)         # (b, n, s, s)
A = softmax(A / sqrt(d/n))
A = dropout(A)
Y = A @ V                           # (b, n, s, d/n)

# Merge heads back together and project
Y = Y.transpose(1, 2).reshape(b, s, d)  # (b, s, d)
Z = Y @ B                           # (b, s, d)
Z = dropout(Z)

For TP, we shard the n heads across g GPUs, each GPU holding n/g heads. Each GPU’s local projections then have width (n/g) * (d/n) = d/g. We can think of this as a column-wise partition of the Q, K and V projection matrices, as in the figure below.

As in the “Column + Row” case, the local outputs Y are already column-partitioned (by heads), and we can keep the output projection B row-parallel.
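A single-process sketch of the head sharding (the softmax and mha helpers, the shapes, and the dropped batch dimension are all simplifications for illustration):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mha(X, W_Q, W_K, W_V, d_head):
    # Attention over however many heads the given projections contain,
    # returning the concatenated head outputs (no output projection).
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    heads = []
    for i in range(Q.shape[-1] // d_head):
        q = Q[:, i*d_head:(i+1)*d_head]
        k = K[:, i*d_head:(i+1)*d_head]
        v = V[:, i*d_head:(i+1)*d_head]
        heads.append(softmax(q @ k.T / np.sqrt(d_head)) @ v)
    return np.concatenate(heads, axis=-1)

d, n, g = 12, 6, 3                        # model dim, heads, GPUs -> n/g = 2 heads per GPU
d_head = d // n
X = np.random.randn(5, d)                 # (s, d), batch dimension omitted
W_Q, W_K, W_V = (np.random.randn(d, d) for _ in range(3))
B = np.random.randn(d, d)                 # output projection

# Column-partition the QKV projections by heads, row-partition B.
Wq_s, Wk_s, Wv_s = (np.split(W, g, axis=1) for W in (W_Q, W_K, W_V))
B_s = np.split(B, g, axis=0)

# Each GPU runs attention for its n/g heads, then applies its slice of B;
# the partial outputs get summed by an all-reduce.
partials = [mha(X, Wq_s[i], Wk_s[i], Wv_s[i], d_head) @ B_s[i] for i in range(g)]
Z = sum(partials)

assert np.allclose(Z, mha(X, W_Q, W_K, W_V, d_head) @ B)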

Backpropagation

The explanations above cover the forward pass, where f acts as an identity (broadcast) and g as an all-reduce. During the backward pass, f and g switch roles: f becomes an all-reduce of the gradients and g an identity. More on that later.
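As a rough sketch of what f and g could look like in PyTorch (assuming torch.distributed is already initialized; this shows the idea, not the actual Megatron-LM implementation):

import torch
import torch.distributed as dist

class F(torch.autograd.Function):
    # f: identity in the forward pass (the input is already replicated),
    # all-reduce of the gradients in the backward pass.
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad

class G(torch.autograd.Function):
    # g: all-reduce of the partial outputs in the forward pass,
    # identity in the backward pass (the gradient is already replicated).
    @staticmethod
    def forward(ctx, x):
        x = x.clone()
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

In the FFN above, F.apply would wrap the input before the column-parallel GEMM and G.apply would wrap the partial result after the row-parallel one.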

Thanks