
v2.6.3's flash_attn_varlen_func runs faster than v2.7.0.post2's flash_attn_varlen_func on H100 #1338

Open
complexfilter opened this issue Nov 16, 2024 · 3 comments


complexfilter commented Nov 16, 2024

I found that v2.6.3's flash_attn_varlen_func runs faster than v2.7.0.post2's flash_attn_varlen_func on an H100.

Code:

import torch
import triton

from hopper.flash_attn_interface import flash_attn_func, flash_attn_varlen_func


def get_tensors(batch_size, seq_len, head_size, dim):
    torch.manual_seed(42)
    q = torch.randn((batch_size, seq_len, head_size, dim), dtype=torch.bfloat16, device="cuda", requires_grad=True)
    k = torch.randn((batch_size, seq_len, head_size, dim), dtype=torch.bfloat16, device="cuda", requires_grad=True)
    v = torch.randn((batch_size, seq_len, head_size, dim), dtype=torch.bfloat16, device="cuda", requires_grad=True)
    return q, k, v


if __name__ == "__main__":
    batch_size = 1
    seq_len = 33 + 40 * 34 * 60  # 81633 tokens
    head_size = 28
    dim = 128
    q, k, v = get_tensors(batch_size, seq_len, head_size, dim)
    warmup = 20
    rep = 100

    # Benchmark the fixed-length kernel.
    fn = lambda: flash_attn_func(q, k, v, softmax_scale=None, causal=False)
    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
    print(f"FA3 fwd with full attention: {ms[1]:.5f} ({ms[0]:.5f}, {ms[2]:.5f}) | ")

    # Benchmark the varlen kernel on the same data: one packed sequence,
    # so the cumulative-length tensors are just [0, seq_len].
    cu_seqlens_q = torch.tensor([0, seq_len], device=q.device, dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0, seq_len], device=q.device, dtype=torch.int32)
    fn = lambda: flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q, cu_seqlens_k, seq_len, seq_len)
    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
    print(f"FA3 varlen fwd with full attention: {ms[1]:.5f} ({ms[0]:.5f}, {ms[2]:.5f}) | ")

    # Check that the two kernels agree numerically (relative Frobenius error).
    x = flash_attn_func(q, k, v, softmax_scale=None, causal=False, deterministic=False)
    x_ = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q, cu_seqlens_k, seq_len, seq_len, softmax_scale=None, causal=False, deterministic=False)
    print(torch.norm(x_[0].unsqueeze(0) - x[0], p='fro') / torch.norm(x[0], p='fro'))

Result from using v2.6.3 on H100:

FA3 fwd with full attention: 147.76518 (147.76518, 147.76518) | 
FA3 varlen fwd with full attention: 150.84227 (150.84227, 150.84227) | 
tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<DivBackward0>)

Result from using v2.7.0.post2 on H100:

FA3 fwd with full attention: 155.47057 (155.47057, 155.47057) | 
FA3 varlen fwd with full attention: 221.30321 (221.30321, 221.30321) | 
tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<DivBackward0>)

The varlen forward pass regresses from about 150 ms to 221 ms (~47% slower), while the fixed-length path moves only from about 148 ms to 155 ms.
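
For context, the cu_seqlens arguments are cumulative token offsets for sequences packed along the first axis; the benchmark above is the degenerate single-sequence case [0, seq_len]. A minimal sketch of the multi-sequence case (made-up lengths of 3 and 5, assuming the same hopper.flash_attn_interface signature used above):

import torch

from hopper.flash_attn_interface import flash_attn_varlen_func

# Two hypothetical sequences (lengths 3 and 5) packed into one token axis.
lens = torch.tensor([3, 5], device="cuda")
total, heads, dim = int(lens.sum()), 4, 64
q = torch.randn(total, heads, dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative offsets [0, 3, 8]; sequence i occupies tokens cu[i]:cu[i+1].
cu_seqlens = torch.zeros(len(lens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(lens, dim=0)

res = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, int(lens.max()), int(lens.max()))
out = res[0]  # attention output, shape (total, heads, dim), as with x_[0] above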

tridao (Contributor) commented Nov 16, 2024

Please try compiling with CUDA 12.3

complexfilter (Author) commented Nov 16, 2024

> Please try compiling with CUDA 12.3

I believe my CUDA version is 12.4.

Fri Nov 15 16:50:13 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:86:00.0 Off |                    0 |
| N/A   26C    P0             80W /  700W |       1MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Not sure if CUDA 12.4 was the issue.

tridao (Contributor) commented Nov 16, 2024

What matters is the version of nvcc, not the CUDA version reported by the driver. You can install the CUDA toolkit (including nvcc) independently of whichever driver version you have.
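
A quick way to check which toolkit is actually in play, as opposed to the driver ceiling that nvidia-smi reports (a sketch; it assumes nvcc is on PATH and PyTorch is installed):

import subprocess

import torch

# CUDA version PyTorch itself was built against.
print("PyTorch built with CUDA:", torch.version.cuda)
# Toolkit version of the nvcc that compiles flash-attn from source;
# this is what the suggestion above refers to, not the driver version.
print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout)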
