Short vs long sequences performance #89
Ah, okay, I think this has to do with your specific `Timer` class. It is missing a few things: it measures all of the host-side overhead of launching the CUDA kernel, and it does not explicitly synchronize the GPU. For Triton kernels there is definitely some extra overhead in the CUDA launch, and there is additional overhead from validating the guards of a compiled region. That said, in typical training or inference (with CUDA graphs) there is enough work for the host to outpace the device. This is generally a safe assumption, but there are some scenarios where it isn't true. If we change the timer to measure only GPU time, with something like:

```python
import torch
from tabulate import tabulate
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import (
    create_block_mask,
    create_mask,
    flex_attention,
)
from torch.nn.functional import scaled_dot_product_attention

# Compile flex_attention with static shapes and raise the Dynamo cache limit so
# that every (batch size, sequence length) combination gets its own kernel.
torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.suppress_errors = True
flex_attention = torch.compile(flex_attention, dynamic=False)
print(f"Torch version: {torch.__version__}")

# GPU-time benchmarking helper (returns microseconds).
# Alternative without transformer_nuggets:
# from typing import Callable
# from torch._inductor.utils import do_bench_using_profiling
# def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
#     """Thin wrapper around do_bench_using_profiling"""
#     no_args = lambda: func(*args, **kwargs)
#     time = do_bench_using_profiling(no_args)
#     return time * 1e3
from transformer_nuggets.utils.benchmark import benchmark_cuda_function_in_microseconds

benchmark_fn = benchmark_cuda_function_in_microseconds


# Causal mask: a query may attend only to keys at or before its own position
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Benchmarking function
def run_benchmark(batch_sizes, sequence_lengths, num_heads=16, hidden_dim=64, n_runs=3):
    results = []
    for batch_size in batch_sizes:
        for seq_len in sequence_lengths:
            # Setup inputs directly on the GPU
            q = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.float16, device="cuda"
            )
            k = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.float16, device="cuda"
            )
            v = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.float16, device="cuda"
            )
            # Masks are built once per shape, outside the timed region
            block_mask = create_block_mask(
                causal_mask,
                B=None,
                H=None,
                Q_LEN=seq_len,
                KV_LEN=seq_len,
                device="cuda",
            )
            mask = create_mask(causal_mask, None, None, seq_len, seq_len, device="cuda")

            # Benchmark flex_attention
            flex_times = []
            for _ in range(n_runs):
                flex_time = benchmark_fn(
                    flex_attention, q, k, v, score_mod=None, block_mask=block_mask
                )
                flex_times.append(flex_time)
            flex_avg_time = (sum(flex_times) / n_runs) / 1000  # Convert µs to ms

            # Benchmark scaled_dot_product_attention with an explicit mask
            sdpa_times = []
            with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                for _ in range(n_runs):
                    sdpa_time = benchmark_fn(
                        scaled_dot_product_attention, q, k, v, attn_mask=mask
                    )
                    sdpa_times.append(sdpa_time)
            sdpa_avg_time = (sum(sdpa_times) / n_runs) / 1000  # Convert µs to ms

            # Benchmark scaled_dot_product_attention with causal FA2
            causal_fa2_times = []
            with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                for _ in range(n_runs):
                    causal_fa2_time = benchmark_fn(
                        scaled_dot_product_attention, q, k, v, is_causal=True
                    )
                    causal_fa2_times.append(causal_fa2_time)
            causal_fa2_avg_time = (sum(causal_fa2_times) / n_runs) / 1000  # Convert µs to ms

            # Log results
            results.append(
                {
                    "Batch Size": batch_size,
                    "Seq Length": seq_len,
                    "FLEX Avg Time (ms)": f"{flex_avg_time:.3f}",
                    "SDPA Avg Time (ms)": f"{sdpa_avg_time:.3f}",
                    "FA2 Avg Time (ms)": f"{causal_fa2_avg_time:.3f}",
                }
            )
    return results


# Main script
if __name__ == "__main__":
    batch_sizes = [4, 8]
    sequence_lengths = [128, 256, 512, 1024, 2048, 4096, 8192, 16384]
    n_runs = 5  # Number of runs for averaging
    results = run_benchmark(batch_sizes, sequence_lengths, n_runs=n_runs)

    # Generate table
    print("\n=== Benchmark Results ===")
    print(tabulate(results, headers="keys", tablefmt="grid"))
```

then you will get a better measure of the kernel time. Another subtlety is that you will be generating kernels with dynamic shapes, which are slightly slower than kernels specialized on static shapes. That is why I set `dynamic=False` and increased `torch._dynamo.config.cache_size_limit`, so that every input shape in the sweep gets its own specialized kernel.
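To make the synchronization point concrete without needing a GPU, here is a toy sketch using only the Python standard library. `FakeStream`, `time_without_sync` and `time_with_sync` are hypothetical names, and the background thread merely stands in for a GPU executing queued kernels asynchronously; it does not model real launch or guard-checking costs. On a real device the equivalent fix is to call `torch.cuda.synchronize()` before stopping the clock, or to time with `torch.cuda.Event(enable_timing=True)` and `elapsed_time()`.

```python
import queue
import threading
import time


class FakeStream:
    """Hypothetical stand-in for a CUDA stream: launch() only enqueues work and
    returns, while a background thread 'executes' the queued items in order."""

    def __init__(self):
        self._work = queue.Queue()
        threading.Thread(target=self._run, daemon=True).start()

    def _run(self):
        while True:
            duration_s = self._work.get()
            time.sleep(duration_s)  # pretend to run a kernel
            self._work.task_done()

    def launch(self, duration_s):
        self._work.put(duration_s)  # asynchronous: returns before the work is done

    def synchronize(self):
        self._work.join()  # block until every launched item has finished


def time_without_sync(stream, n_launches, kernel_s):
    """Like an unsynchronized timer: the clock stops while work may still be queued."""
    start = time.perf_counter()
    for _ in range(n_launches):
        stream.launch(kernel_s)
    elapsed = time.perf_counter() - start
    stream.synchronize()  # drain afterwards so later measurements start clean
    return elapsed


def time_with_sync(stream, n_launches, kernel_s):
    """Wait for the 'device' to finish before stopping the clock."""
    stream.synchronize()  # make sure no earlier work is still running
    start = time.perf_counter()
    for _ in range(n_launches):
        stream.launch(kernel_s)
    stream.synchronize()
    return time.perf_counter() - start


stream = FakeStream()
no_sync = time_without_sync(stream, n_launches=5, kernel_s=0.02)
with_sync = time_with_sync(stream, n_launches=5, kernel_s=0.02)
print(f"without sync: {no_sync * 1e3:.2f} ms, with sync: {with_sync * 1e3:.2f} ms")
```

The unsynchronized number only reflects how long it took to enqueue the work, which is why host-side launch overhead can dominate such a measurement when each kernel is short.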
Hi there,
I've recently started dabbling with `flex_attention` and its modularity, which looks really interesting. Thanks for the nice work!
But before I even had the time to dive into the `score_mod` implementations, I observed some unexpected slowness. Specifically, there seems to be some kind of overhead which makes `flex_attention` way slower than standard SDPA or causal FA2 for (relatively) short sequences.
(For context, I'm pondering whether it could replace/simplify our (rather complex) MHA implementation in eole.)
Reproduction
Here is a small benchmarking script
Which yields the following results:
--> We can see that, around the ~1024/~2048-token mark, the gap becomes rather significant, regardless of the batch size.
(It might not seem like much, but it adds up to a quite significant slowdown over a complete run.)
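To put the short-sequence behaviour in perspective: a fixed per-call cost (kernel launch, guard checks) is amortized over an amount of work that grows quadratically with sequence length. The sketch below is back-of-the-envelope arithmetic, not a measurement. It counts forward-pass matmul FLOPs for the benchmark's shapes (B=4, H=16, D=64) and the fraction of tiles a causal block mask keeps, assuming `create_block_mask` uses its default block size of 128. The helper names are hypothetical.

```python
# Hypothetical helpers; shapes follow the benchmark (B=4, H=16, D=64).
BLOCK_SIZE = 128  # assumed create_block_mask default block size


def causal_attention_flops(batch, heads, seq_len, head_dim):
    """Forward-pass matmul FLOPs for causal attention (softmax ignored).

    QK^T and P @ V each cost 2 * S * S * D FLOPs per (batch, head); a causal
    mask lets a kernel skip roughly half of that work.
    """
    dense = 4 * batch * heads * seq_len * seq_len * head_dim
    return dense // 2


def causal_block_fraction(seq_len, block_size=BLOCK_SIZE):
    """Fraction of (query block, key block) tiles that a causal block mask keeps."""
    n_blocks = -(-seq_len // block_size)  # ceiling division
    kept = n_blocks * (n_blocks + 1) // 2  # tiles on or below the diagonal
    return kept / (n_blocks * n_blocks)


for seq_len in [128, 256, 512, 1024, 2048, 4096, 8192, 16384]:
    gflop = causal_attention_flops(4, 16, seq_len, 64) / 1e9
    kept = causal_block_fraction(seq_len)
    print(f"{seq_len:>6} tokens: {gflop:10.2f} GFLOP, {kept:6.1%} of tiles computed")
```

Between 128 and 16,384 tokens the work per call grows by a factor of 16,384 (2^27 vs 2^41 FLOPs), and at 128 tokens the whole mask is a single tile, so there is no block sparsity to exploit. Any fixed overhead is therefore a much larger share of the total for short sequences.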
Notes
`block_mask` computation is also not included in the measured times.
Questions
Thanks!