finished ring attn, ulysses and moved files
brunomaga committed Sep 19, 2024
1 parent 26fe284 commit 1f0c23e
Showing 3 changed files with 79 additions and 63 deletions.
110 changes: 59 additions & 51 deletions assets/GPT-lite-distributed/ring_attention_sequence_parallelism.py
@@ -1,4 +1,5 @@
import os
import sys

import torch
import torch.distributed as dist
@@ -15,6 +16,7 @@


class MultiHeadAttention(nn.Module):
""" A Ring Attention multi-head attention. Variable names follow GPT-lite's post """

def __init__( self, n_embd, d_head, n_heads = 8, dropout_p = 0.1, group = None ):
super().__init__()
@@ -26,17 +28,17 @@ def __init__( self, n_embd, d_head, n_heads = 8, dropout_p = 0.1, group = None )
self.values = nn.ModuleList([nn.Linear(n_embd, d_head, bias=False) for _ in range(n_heads)])
self.proj = nn.Linear(n_heads * d_head, n_embd)
self.dropout = nn.Dropout(dropout_p)
self.group = group # Ring Attention group group
self.group = group # Ring Attention group
if self.group is None:
self.group = dist.new_group(range(dist.get_world_size()))

class RingAttention(torch.autograd.Function):

@staticmethod
def acc_out_and_lse(out, lse, block_out, block_lse):
# https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
def accumulate_out_and_lse(out, lse, block_out, block_lse):
# source: https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795

if out is None: # first block, allocate results tensors
if out is None:
out = block_out
lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
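
For reference, the accumulation performed by the collapsed body of accumulate_out_and_lse can be expressed with the numerically stable log-sum-exp merge described in the linked PR comment. The snippet below is a sketch of that formulation (an assumption for illustration, not code taken from this commit):

import torch
import torch.nn.functional as F

def merge_block(out, lse, block_out, block_lse):
    # out, block_out: (B, T/P, H, E); lse: (B, T/P, H, 1); block_lse: (B, H, T/P)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)       # -> (B, T/P, H, 1)
    out = out - torch.sigmoid(block_lse - lse) * (out - block_out)  # re-weight the running output
    lse = lse - F.logsigmoid(lse - block_lse)                       # log(exp(lse) + exp(block_lse))
    return out, lse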

@@ -53,25 +55,26 @@ def forward( ctx, q, k, v, group ):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous() # (B, T/P, H, E)

out = lse = None # accumulators
recv_k, recv_v = torch.empty_like(k), torch.empty(v) # recv buffers
recv_k, recv_v = torch.empty_like(k), torch.empty_like(v) # recv buffers

for _ in range(P): # do P ring steps
# "overlapping the communication of key-value blocks with the computation of blockwise attention."
all_reqs = MultiHeadAttention.isend_k_and_v(k, v, recv_k, recv_v, group)
for step in range(P): # do P ring steps
# asynchronously send K and V for the next step
reqs_k_v = MultiHeadAttention.isend_k_and_v(k, v, recv_k, recv_v, group)

# forward pass of attention function for the K, V, and Q for this block
block_out, _, _, _, _, block_lse, _, _ = fa._flash_attn_forward(q,k,v)
# compute attention output and softmax lse for current block
dropout_p, softmax_scale = 0, q.shape[-1] ** (-0.5)
kwargs = dict(causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False)
block_out, _, _, _, _, block_lse, _, _ = fa._flash_attn_forward(q,k,v, dropout_p, softmax_scale, **kwargs)

# update out and lse
out, lse = MultiHeadAttention.RingAttention.acc_out_and_lse(out, lse, block_out, block_lse)
out, lse = MultiHeadAttention.RingAttention.accumulate_out_and_lse(out, lse, block_out, block_lse)

# wait for K and V for the next iteration (final iteration will revert K and V to original proc)
for req in all_reqs:
req.wait()

lse = lse.squeeze(dim=-1).transpose(1, 2)
# wait for new K and V before starting the next iteration
[ req.wait() for req in reqs_k_v]
k, v = recv_k, recv_v

ctx.group = group # save for backward
out = out.to(dtype=q.dtype)
ctx.save_for_backward(q, k, v, out, lse)
return out

@@ -80,51 +83,54 @@ def backward(ctx, dout, *args):

P = ctx.group.size()
q, k, v, out, softmax_lse = ctx.saved_tensors
softmax_lse = softmax_lse.squeeze(dim=-1).transpose(1, 2)

block_dq, block_dk, block_dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)

dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
block_dq, block_dk, block_dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) # output buffers
dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) # accumulators of gradients

recv_k, recv_v = torch.empty_like(k), torch.empty_like(v) # recv buffers for K and V
recv_dk, recv_dv = torch.empty_like(dk), torch.empty_like(dv) # recv buffers for dK and dV

for step in range(P):
all_k_v_reqs = MultiHeadAttention.isend_k_and_v(k, v, recv_k, recv_v, group)

fa._flash_attn_backward(dout=dout, q=q, k=k, v=k, out=out, softmax_lse=softmax_lse,
dq=block_dq, dk=block_dk, dv=block_dv)
# asynchronously send K and V for the next step
reqs_k_v = MultiHeadAttention.isend_k_and_v(k, v, recv_k, recv_v, group)

# compute the gradients for the current block K, V and Q
dropout_p, softmax_scale = 0, q.shape[-1] ** (-0.5)
kwargs = dict(causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, deterministic=False, rng_state=None)
fa._flash_attn_backward(dout, q, k, v, out, softmax_lse, block_dq, block_dk, block_dv, dropout_p, softmax_scale, **kwargs)

# K and V are rotated, so dK and dV must also be rotated and accumulated in a ring fashion
if step > 0:
for req in all_dk_dv_reqs:
req.wait()
# wait for the dK and dV accumulators sent at the previous step
[ req.wait() for req in reqs_dk_dv]
dk, dv = recv_dk, recv_dv

dq += block_dq
dk += block_dk
dv += block_dv

all_dk_dv_reqs = MultiHeadAttention.isend_k_and_v(dk, dv, recv_dk, recv_dv, group)
reqs_dk_dv = MultiHeadAttention.isend_k_and_v(dk, dv, recv_dk, recv_dv, group)

# wait for K and V for the next iteration
for req in all_k_v_reqs:
req.wait()
# wait for new K and V before starting the next iteration
[ req.wait() for req in reqs_k_v]
k, v = recv_k, recv_v

for req in all_dk_dv_reqs:
req.wait()
# before returning, wait for the last dK and dV, which correspond to this process's own block
[ req.wait() for req in reqs_dk_dv]
dk, dv = recv_dk, recv_dv
return dq, dk, dv, None
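
The extra rotation of dK and dV works because each accumulator travels with the K/V block it belongs to: at step s, rank r holds the block that originally lived on rank (r - s) mod P, and the accumulator it receives carries the partial gradients of exactly that block, so the final send after the loop returns every fully accumulated gradient to its owner. A small standalone sketch of the schedule (hypothetical, not part of this file):

P = 4  # example world size
contribs = {b: [] for b in range(P)}   # block -> ranks that added a gradient contribution
for step in range(P):
    for rank in range(P):
        block = (rank - step) % P      # K/V block held by this rank at this step
        contribs[block].append(rank)
# every block receives contributions from all P ranks; the final rotation
# then places block b's accumulator back on rank b
assert all(sorted(contribs[b]) == list(range(P)) for b in range(P))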

@staticmethod
def isend_k_and_v( k, v, recv_k, recv_v, group):
P, rank = group.size(), group.rank()
dst = (rank + 1) % P
src = (rank - 1) % P
req_k_send = dist.P2POp(dist.isend, k, dst, group, 1)
req_k_recv = dist.P2POp(dist.irecv, recv_k, src, group, 1)
req_v_send = dist.P2POp(dist.isend, v, dst, group, 2)
req_v_recv = dist.P2POp(dist.irecv, recv_v, src, group, 2)
all_reqs = [req_k_send, req_k_recv, req_v_send, req_v_recv]
dist.batch_isend_irecv(all_reqs)
return all_reqs
dst = (rank + 1) % P # the rank of the next process
src = (rank - 1) % P # the rank of the previous process
req_k_send = dist.P2POp(dist.isend, k, dst, group)
req_v_send = dist.P2POp(dist.isend, v, dst, group)
req_k_recv = dist.P2POp(dist.irecv, recv_k, src, group)
req_v_recv = dist.P2POp(dist.irecv, recv_v, src, group)
return dist.batch_isend_irecv([req_k_send, req_v_send, req_k_recv, req_v_recv])

def forward(self, x):
P, B, T, = self.group.size(), x.shape[0], x.shape[1] * self.group.size()
@@ -135,10 +141,11 @@ def forward(self, x):
v = torch.stack([v(x) for v in self.values], dim=0)

if P == 1:
out = self.flash_attn_func(q, k, v)
out = fa.flash_attn_func(q, k, v)
else:
out = MultiHeadAttention.RingAttention.apply( q, k, v, self.group)

out = out.permute(1, 2, 0, 3) # (H, B, T/P, E) -> (B, T/P, H, E)
out = out.reshape(B, T // P, -1) # (B, T/P, H, E) -> (B, T/P, H*E)
out = self.proj(out) # (B, T/P, H*E) -> (B, T/P, E)
out = self.dropout(out)
@@ -160,6 +167,7 @@ def forward(self, x):
return x



if __name__ == "__main__":

# set up network variables
@@ -168,34 +176,34 @@ def forward(self, x):
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = f"cuda:{local_rank}"
group = dist.new_group(range(dist.get_world_size()))
dtype = torch.bfloat16

# model constants (use these or import from GPT-lite post). Naming matches post
P = dist.get_world_size()
B = 4
B = 4 # batch size
T = 2048 # the length of the sequence
H = 8 # the number of heads
E = 128 # the size of the head
n_embd = 256 # the hidden size of the model
n_blocks = 12 # number of transformer blocks

# sanity checks
assert T % P == 0, "seqlen must be divisible by number of processes"
assert H % P == 0, "n_heads must be divisible by number of processes"

x = torch.randint(0, 5, (B, T, n_embd)).to(device=device).float() # dummy input
y = torch.ones_like(x).float() # dummy label
x = torch.randint(0, 5, (B, T, n_embd)).to(device=device, dtype=dtype) # dummy input
y = torch.ones_like(x) # dummy label

# build model as sequence of blocks
blocks = nn.Sequential(*[Block(n_embd, E, H, group=group) for _ in range(n_blocks)]).to(device=device)
blocks = nn.Sequential(*[Block(n_embd, E, H, group=group) for _ in range(n_blocks)]).to(device=device, dtype=dtype)
blocks = DistributedDataParallel(blocks, device_ids=[local_rank], static_graph=True, process_group=group)
optimizer = torch.optim.Adam(blocks.parameters())

# run 10 random iterations
for i in range(10):
# run a few iterations
for i in range(15):
out = blocks(x)
loss = nn.functional.mse_loss(out, y)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
if dist.get_rank() == 0:
print(f"Iteration {i} loss: {loss}")

dist.barrier(group=group)
dist.destroy_process_group()
@@ -1,11 +1,12 @@
import os
import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

import flash_attn.flash_attn_interface as fa
from flash_attn.flash_attn_interface import flash_attn_func

# use FeedForward from the base GPTlite model from the GPT-lite post
current_dir = os.path.dirname(os.path.realpath(__file__))
@@ -25,7 +26,7 @@ def __init__(self, n_embd=256, d_head=128, n_heads=8, dropout_p=0.1, group=None)
self.values = nn.ModuleList([nn.Linear(n_embd, d_head, bias=False) for _ in range(n_heads)])
self.proj = nn.Linear(n_heads * d_head, n_embd)
self.dropout = nn.Dropout(dropout_p)
self.group = group # Ulysses sequence parallelism group
self.group = group # Ulysses group
if self.group is None:
self.group = dist.new_group(range(dist.get_world_size()))

@@ -55,7 +56,7 @@ def backward(ctx, dout):

@staticmethod
def dist_view_swap(tensor: torch.Tensor, old_split_dim: int, new_split_dim: int, group: dist.ProcessGroup):
"""converts the distributed splie dimension of a tensor with shape (H, B, T, E) across P processes"""
"""swaps the distributed split dimension of a tensor of shape (H, B, T, E) across P processes"""
full_shape, P = list(tensor.shape), group.size()
full_shape[old_split_dim]*=P # full distributed shape
H, B, T, E = full_shape
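
dist_view_swap is the core of both all-to-all steps: it re-partitions a tensor that is split along one dimension across the P processes so that it becomes split along another dimension instead. Its body is collapsed in this view; a hypothetical sketch of one way to perform such a swap with torch.distributed.all_to_all (illustrative only, the actual implementation may differ) is:

import torch
import torch.distributed as dist

def dist_view_swap_sketch(tensor, old_split_dim, new_split_dim, group):
    # input: locally split along old_split_dim, locally full along new_split_dim
    P = group.size()
    send = [c.contiguous() for c in tensor.chunk(P, dim=new_split_dim)]  # one chunk per destination rank
    recv = [torch.empty_like(send[0]) for _ in range(P)]
    dist.all_to_all(recv, send, group=group)       # exchange chunks across the group
    return torch.cat(recv, dim=old_split_dim)      # output: locally full along old_split_dim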
@@ -78,12 +79,14 @@ def forward(self, x):
k = MultiHeadAttention.first_alltoall.apply(k, self.group)
v = MultiHeadAttention.first_alltoall.apply(v, self.group)

out = fa.flash_attn_func(q, k, v)[0]
dropout_p, softmax_scale = 0, q.shape[-1] ** (-0.5)
out = flash_attn_func(q, k, v, dropout_p, softmax_scale)

if P> None: # (H/P, B, T, E) -> (H, B, T/P, E)
if P > 1: # (H/P, B, T, E) -> (H, B, T/P, E)
out = MultiHeadAttention.second_alltoall.apply(out, self.group)

out = out.permute(1, 2, 0, 3).reshape(B, T // P, -1) # (H, B, T/P, E) -> (B, T/P, H, E) -> (B, T/P, H*E)
out = out.permute(1, 2, 0, 3) # (H, B, T/P, E) -> (B, T/P, H, E)
out = out.reshape(B, T // P, -1) # (B, T/P, H, E) -> (B, T/P, H*E)
out = self.proj(out) # (B, T/P, H*E) -> (B, T/P, E)
out = self.dropout(out)
return out
@@ -112,10 +115,11 @@ def forward(self, x):
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = f"cuda:{local_rank}"
group = dist.new_group(range(dist.get_world_size()))
dtype = torch.bfloat16

# model constants (use these or import from GPT-lite post). Naming matches post
P = dist.get_world_size()
B = 4
B = 4 # batch size
T = 2048 # the length of the sequence
H = 8 # the number of heads
E = 128 # the size of the head
@@ -126,20 +130,24 @@
assert T % P == 0, "seqlen must be divisible by number of processes"
assert H % P == 0, "n_heads must be divisible by number of processes"

x = torch.randint(0, 5, (B, T, n_embd)).to(device=device).float() # dummy input
y = torch.ones_like(x).float() # dummy label
x = torch.randint(0, 5, (B, T, n_embd)).to(device=device, dtype=dtype) # dummy input
y = torch.ones_like(x) # dummy label

# build model as sequence of blocks
blocks = nn.Sequential(*[Block(n_embd, E, H, group=group) for _ in range(n_blocks)]).to(device=device)
blocks = nn.Sequential(*[Block(n_embd, E, H, group=group) for _ in range(n_blocks)]).to(device=device, dtype=dtype)
blocks = DistributedDataParallel(blocks, device_ids=[local_rank], static_graph=True, process_group=group)
optimizer = torch.optim.Adam(blocks.parameters())

# run 10 random iterations
for i in range(10):
# run a few iterations
for i in range(15):
out = blocks(x)
loss = nn.functional.mse_loss(out, y)
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
dist.barrier()
if dist.get_rank() == 0:
print(f"Iteration {i} loss: {loss}")

dist.barrier(group=group)
dist.destroy_process_group()
