# benchmark_qkvpacked_func.py (forked from zhuzilin/ring-flash-attention)
import torch
import torch.cuda
import torch.distributed as dist

from flash_attn import flash_attn_qkvpacked_func
from ring_flash_attn import (
    ring_flash_attn_qkvpacked_func,
    zigzag_ring_flash_attn_qkvpacked_func,
    stripe_flash_attn_qkvpacked_func,
)
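
# Usage note (assumption; not part of the original file): dist.init_process_group("nccl")
# reads the standard env:// variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), so this
# benchmark is typically launched with a distributed launcher, e.g.
#   torchrun --nproc_per_node=<num_gpus> benchmark_qkvpacked_func.py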


def benchmark(f, num_iter=100, forward_only=True, log=True):
    dtype = torch.bfloat16
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    batch_size = 1
    seqlen = 1024 * 8
    nheads = 5
    d = 128
    dropout_p = 0
    causal = True
    deterministic = False

    # seqlen must split evenly into 2 * world_size chunks for the zigzag/stripe variants
    assert seqlen % (2 * world_size) == 0
    assert d % 8 == 0

    # Packed QKV input: (batch, seqlen, 3, nheads, head_dim)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
    )
    dout = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)

    begin = torch.cuda.Event(enable_timing=True)
    begin.record()

    if forward_only:
        with torch.no_grad():
            for _ in range(num_iter):
                _ = f(
                    qkv,
                    dropout_p=dropout_p,
                    causal=causal,
                    window_size=(-1, -1),
                    alibi_slopes=None,
                    deterministic=deterministic,
                    return_attn_probs=False,
                )
    else:
        for _ in range(num_iter):
            qkv.grad = None
            out = f(
                qkv,
                dropout_p=dropout_p,
                causal=causal,
                window_size=(-1, -1),
                alibi_slopes=None,
                deterministic=deterministic,
                return_attn_probs=False,
            )
            out.backward(dout)

    end = torch.cuda.Event(enable_timing=True)
    end.record()
    torch.cuda.synchronize(device=device)
    # CUDA events report elapsed time in milliseconds; convert to seconds
    time = begin.elapsed_time(end) / 1000.0

    if rank == 0 and log:
        print(f"{num_iter / time:.3f} iter/s, {time:.3f} sec")


if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    forward_only = False

    for f in [
        flash_attn_qkvpacked_func,
        ring_flash_attn_qkvpacked_func,
        zigzag_ring_flash_attn_qkvpacked_func,
        stripe_flash_attn_qkvpacked_func,
    ]:
        torch.cuda.empty_cache()
        if rank == 0:
            print(f"# {f.__name__}")
        # First call warms up kernels and communicators (log=False);
        # only the second call's timing is printed.
        benchmark(f, forward_only=forward_only, log=False)
        benchmark(f, forward_only=forward_only, log=True)