Skip to content

Commit

Permalink
more text on seq parallelism
Browse files Browse the repository at this point in the history
  • Loading branch information
brunomaga committed Sep 12, 2024
1 parent 9c2c9ce commit 0496011
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 882 deletions.
45 changes: 45 additions & 0 deletions _drafts/2024-07-11-GPT-lite-sequence-parallelism.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,53 @@ class MultiHeadAttention(nn.Module):

Note that you can add several improvements to the communication, such as sending `q`, `k` and `v` simultaneously, ou asynchronously.

## Ring Attention with Blockwise Transformers

The whole rationale was presented in the paper [Self-attention Does Not Need $$O(n^2)$$ Memory](https://arxiv.org/abs/2112.05682). The rationale is the following. Usually the attention for a given head, given a query $$q$$, key $$k$$ and value $$v$$, the attention calculation can be reduced to:

$$
\begin{align*}
Attntion(Q, K, V) & = softmax \left(QK^T \right) V \\
& = softmax \left( \sum_i q^T k_i\right) v_i & \text{(expanding dot-product)}\\
& = \sum_i \left( \frac{\exp(q^T k_i)}{ \sum_j \exp(q^T k_j)} \right) v_i & \text{(expanding definition of softmax)}\\
& = \frac{ \sum_i \exp(q^T k_i) v_i }{ \sum_j \exp(q^T k_j) }.
\end{align*}
$$

The smart bit here is that we do not need to load the full $$v$$ and $$k$$ tensors or store the full attention matrix $$QK^T$$ im memory. Instead:
1. we iterate over the $$i$$-th element of the tensors $$v$$ and $$k$$, and perform the accumulations $$v^{\star} \leftarrow v^{\star} + \exp(q^T k_i) v_i$$ (top of the fraction), and $$s^{\star} \leftarrow s^{\star} + \exp(q^T k_i)$$ (bottom of the fraction).
2. after processing all keys and values, we divide $$\frac{v^{\star}}{s^{\star}}$$ to get the final value.

Instead of having a single process iterating over the elements of the query and value tensors, we now want to perform sequence parallelism. We want to perform sequence parallelism, by splitting the tensors $$q$$, $$k$$ and $$v$$ across $$P$$ processes, in the time dimension. That's where ring attention comes into play - original paper [Ring Attention with Blockwise Transformers for Near-Infinite Context](https://arxiv.org/abs/2310.01889) based on [Blockwise Parallel Transformer for Large Context Models
](https://arxiv.org/abs/2305.19370). To start, **each process hold a non-overlapping timeframe (block) of the $$q$$, $$k$$ and $$v$$ tensors**. After that, blocks for the $$q$$ and $$v$$ tensors will be sent to all processes, iteratively, in a ring fashion: at each step, each process sends its block of keys and values to the next process, and receives it from the previous. After $$P-1$$ communication steps, all processes will have received the full $$k$$ and $$v$$ tensors, in chunks. This pattern can be illustrated as:

{: style="text-align:center; font-size: small;"}
<img width="100%" height="100%" src="/assets/GPT-lite-distributed/ring_attention.png"/>

{: style="text-align:center; font-size: small;"}
Overview of the Ring Attention algorithm. **Before Ring Attention:** the initial view of the input tensor, distributed across 4 (color-coded) gpus, split by the time (T) dimension. **1st Ring Attention Step:** the first step of the ring attention. Each process holds its part of the Query, Value and Key tensors. Each process computes the block attention for those tensors. Asynchronously, processes perform an async send/recv of the Key and Value tensors to the next/previous process in the communication ring (clockwise). **2nd, 3rd, and 4th Ring Attention steps:** Each process its original Query block, and the previous processes' Key and Value blocks. Processes compute again the block attention for its Query and the received Key and Values. **After Ring Attention**: the Multi-head attention output is time-split across processes, similarly to the initial data format.

From the standpoint of a process, holding its own block of $$q$$ and receiving the full $$k$$ and $$v$$ allows the computation of $$v^{\star}$$, $$s^{\star}$$ and the attention output for its local block.

The forward pass will simply take as input the query, key and values tensor for each process and compute the output of the attention and the [LogSumExp](https://en.wikipedia.org/wiki/LogSumExp) of the [Softmax function](https://en.wikipedia.org/wiki/Softmax_function), ie:

```python
block_out, block_lse, = attn_forward( q, k, v)
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
```

Note that the [backward propagation](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#background) takes as input the gradient of the loss with respect to the output ($$\nabla_{out} Loss$$ or `dout` in the code below), and returns the derivatives of the loss with respect to the parameters of the functions (gradients $$\nabla_{q} Loss$$, $$\nabla_{k} Loss$$, $$\nabla_{v} Loss$$, or `dq`, `dk`, `dv`). Something similar to:


```python
block_dq, block_dk, block_dv = attn_backward(dout, q, k, v, out, lse)
dq += block_dq
dv += block_qv
dk += block_qk
```

Similarly to the forward pass, we will have to *rotate* `v` and `k` and also `dv` and `dk`.

The queries tensor is always local to a process, and we can compute `dq` by summing `block_dk` for every step of the backward (with different `k` and `v`). The caveat is on computing `dk` and `dv`: because `k` and `v` *rotate* in the process ring in every iteration, then the gradients `dk` and `dv` for a given timeframe (ie process) will be computed at the process that computes that timeframe's `dv` and `dk`. Because of that, we will have an accumulator of gradients that also *rotates in the circle*. After all rotations, it will return the correct value to the process holding that timeframe.

Note that the implementation of `backward` may be confusing. According to the [documentation](https://pytorch.org/docs/stable/generated/torch.autograd.Function.backward.html), "it must accept a context `ctx` as the first argument, followed by as many outputs as the `forward()` returned (`None` will be passed in for non tensor outputs of the forward function), and it should return as many tensors, as there were inputs to `forward()`".
2 changes: 1 addition & 1 deletion _posts/2024-05-10-Training-Variable-Length.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ In the next sections we will detail the latter option.
The tricky bit in the algorithm above is the distributed sorting that performs the transformation from steps b) to d). There are [other distributed sorting algorithms]({{ site.baseurl }}{% post_url 2014-06-21-Distributed-Sort %}) that one could use. But here we will implement the Distributed Sample Sorting algorithm, as it scales well for a large number of processes. The workflow is the following:

{: style="text-align:center; font-size: small;"}
<img width="60%" height="60%" src="/assets/Distributed-Sort/sample_sort.png">
<img width="70%" height="70%" src="/assets/Distributed-Sort/sample_sort.png">

The python implementtion of this distributed sorting algorithm is provided below.

Expand Down
Binary file modified assets/GPT-lite-distributed/ring_attention.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 0496011

Please sign in to comment.