Short vs long sequences performance #89

Open
francoishernandez opened this issue Dec 12, 2024 · 1 comment
Labels
question Further information is requested

Comments


francoishernandez commented Dec 12, 2024

Hi there,

I've recently started dabbling with flex_attention and its modularity, which looks really interesting. Thanks for the nice work!

But before I even had the time to dive into the score_mod implementations, I observed some unexpected slowness.

Specifically, there seems to be some kind of overhead which makes flex_attention way slower than standard SDPA or causal FA2 for (relatively) short sequences.

(For context, I'm pondering whether it could replace/simplify our (rather complex) MHA implementation in eole.)

Reproduction

Here is a small benchmarking script:

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
import time
from tabulate import tabulate

from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    create_mask,
)
flex_attention = torch.compile(flex_attention)
print(f"Torch version: {torch.__version__}")

torch._dynamo.config.suppress_errors = True

# Utility: Timer context
class Timer:
    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.duration = time.time() - self.start

# Utility: Benchmark a function
def benchmark_fn(func, *args, warmup=False, **kwargs):
    if warmup:
        func(*args, **kwargs)  # Warmup run
    with Timer() as t:
        result = func(*args, **kwargs)
    return result, t.duration

# Function: Causal mask
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Benchmarking function
def run_benchmark(batch_sizes, sequence_lengths, num_heads=16, hidden_dim=64, n_runs=3):
    results = []

    for batch_size in batch_sizes:
        for seq_len in sequence_lengths:
            # Setup inputs
            q = torch.randn(batch_size, num_heads, seq_len, hidden_dim, dtype=torch.float16).to("cuda")
            k = torch.randn(batch_size, num_heads, seq_len, hidden_dim, dtype=torch.float16).to("cuda")
            v = torch.randn(batch_size, num_heads, seq_len, hidden_dim, dtype=torch.float16).to("cuda")

            block_mask = create_block_mask(
                causal_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cuda"
            )
            mask = create_mask(causal_mask, None, None, seq_len, seq_len, device="cuda")

            # Benchmark flex_attention
            flex_times = []
            for _ in range(n_runs):
                _, flex_time = benchmark_fn(
                    flex_attention, q, k, v, score_mod=None, block_mask=block_mask, warmup=True
                )
                flex_times.append(flex_time)
            flex_avg_time = (sum(flex_times) / n_runs) * 1000  # Convert to ms

            # Benchmark scaled_dot_product_attention with mask
            sdpa_times = []
            with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                for _ in range(n_runs):
                    _, sdpa_time = benchmark_fn(
                        scaled_dot_product_attention, q, k, v, attn_mask=mask, warmup=True
                    )
                    sdpa_times.append(sdpa_time)
                sdpa_avg_time = (sum(sdpa_times) / n_runs) * 1000  # Convert to ms

            # Benchmark scaled_dot_product_attention with causal fa2
            causal_fa2_times = []
            with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                for _ in range(n_runs):
                    _, causal_fa2_time = benchmark_fn(
                        scaled_dot_product_attention, q, k, v, is_causal=True, warmup=True
                    )
                    causal_fa2_times.append(causal_fa2_time)
                causal_fa2_avg_time = (sum(causal_fa2_times) / n_runs) * 1000  # Convert to ms

            # Log results
            results.append(
                {
                    "Batch Size": batch_size,
                    "Seq Length": seq_len,
                    "FLEX Avg Time (ms)": f"{flex_avg_time:.2f}",
                    "SDPA Avg Time (ms)": f"{sdpa_avg_time:.2f}",
                    "FA2 Avg Time (ms)": f"{causal_fa2_avg_time:.2f}",
                }
            )

    return results

# Main script
if __name__ == "__main__":
    batch_sizes = [4, 8, 16, 20, 24, 32]
    sequence_lengths = [128, 256, 512, 1024, 2048, 4096, 8192, 16384]
    n_runs = 5  # Number of runs for averaging

    results = run_benchmark(batch_sizes, sequence_lengths, n_runs=n_runs)

    # Generate table
    print("\n=== Benchmark Results ===")
    print(tabulate(results, headers="keys", tablefmt="grid"))

Which yields the following results:

=== Benchmark Results ===
+--------------+--------------+----------------------+----------------------+---------------------+
|   Batch Size |   Seq Length |   FLEX Avg Time (ms) |   SDPA Avg Time (ms) |   FA2 Avg Time (ms) |
+==============+==============+======================+======================+=====================+
|            4 |          128 |                 0.08 |                 0.07 |                0.03 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            4 |          256 |                 0.22 |                 0.09 |                0.05 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            4 |          512 |                 0.25 |                 0.18 |                0.09 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            4 |         1024 |                 0.39 |                 0.5  |                0.22 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            4 |         2048 |                 0.87 |                 1.66 |                0.67 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            4 |         4096 |                 2.6  |                 6.37 |                2.41 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            4 |         8192 |                 9.25 |                25.97 |                8.69 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            4 |        16384 |                33.61 |               108.24 |               32.58 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            8 |          128 |                 0.15 |                 0.07 |                0.03 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            8 |          256 |                 0.24 |                 0.11 |                0.06 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            8 |          512 |                 0.3  |                 0.25 |                0.12 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            8 |         1024 |                 0.51 |                 0.77 |                0.32 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            8 |         2048 |                 1.46 |                 3.17 |                1.23 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            8 |         4096 |                 4.83 |                12.85 |                4.61 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            8 |         8192 |                17.4  |                52.21 |               16.67 |
+--------------+--------------+----------------------+----------------------+---------------------+
|            8 |        16384 |                64.67 |               220.82 |               64.08 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           16 |          128 |                 0.16 |                 0.09 |                0.05 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           16 |          256 |                 0.26 |                 0.16 |                0.09 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           16 |          512 |                 0.38 |                 0.42 |                0.2  |
+--------------+--------------+----------------------+----------------------+---------------------+
|           16 |         1024 |                 0.89 |                 1.62 |                0.67 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           16 |         2048 |                 2.63 |                 6.28 |                2.39 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           16 |         4096 |                 9.24 |                25.77 |                8.82 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           16 |         8192 |                32.87 |               106.29 |               32.64 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           16 |        16384 |               126.91 |               440.25 |              127.86 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           20 |          128 |                 0.15 |                 0.09 |                0.05 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           20 |          256 |                 0.27 |                 0.17 |                0.1  |
+--------------+--------------+----------------------+----------------------+---------------------+
|           20 |          512 |                 0.42 |                 0.51 |                0.24 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           20 |         1024 |                 1.05 |                 2    |                0.82 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           20 |         2048 |                 3.21 |                 7.88 |                2.97 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           20 |         4096 |                11.31 |                32.49 |               10.94 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           20 |         8192 |                41.85 |               132.69 |               40.59 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           20 |        16384 |               159.21 |               552.21 |              159.14 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           24 |          128 |                 0.17 |                 0.1  |                0.06 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           24 |          256 |                 0.29 |                 0.2  |                0.12 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           24 |          512 |                 0.46 |                 0.59 |                0.28 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           24 |         1024 |                 1.22 |                 2.37 |                0.96 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           24 |         2048 |                 3.79 |                 9.47 |                3.52 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           24 |         4096 |                13.65 |                38.77 |               12.91 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           24 |         8192 |                49.21 |               160.47 |               48.65 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           24 |        16384 |               190.65 |               668.23 |              191.95 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           32 |          128 |                 0.18 |                 0.11 |                0.07 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           32 |          256 |                 0.32 |                 0.24 |                0.15 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           32 |          512 |                 0.62 |                 0.85 |                0.39 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           32 |         1024 |                 1.53 |                 3.13 |                1.27 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           32 |         2048 |                 4.94 |                12.66 |                4.7  |
+--------------+--------------+----------------------+----------------------+---------------------+
|           32 |         4096 |                17.39 |                52.1  |               17.09 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           32 |         8192 |                65.65 |               214.36 |               64.52 |
+--------------+--------------+----------------------+----------------------+---------------------+
|           32 |        16384 |               253.8  |               891.13 |              255.73 |
+--------------+--------------+----------------------+----------------------+---------------------+

--> we can see that, up to the ~1024/~2048 token mark, the gap between flex_attention and causal FA2 (and SDPA) is rather significant, independent of the batch size; beyond that point, flex_attention roughly matches FA2.
(It might not seem like much per call, but it adds up to quite a significant slowdown over a complete run.)

Notes

  • This ran on an RTX 3090, with torch 2.6.0.dev20241211+cu118;
  • I added a warmup call so that potential compile time is kept out of the measurements;
  • block_mask computation is also excluded from the measured times.

Questions

  1. Am I doing something wrong?
  2. Is this behaviour expected?

Thanks!


drisspg commented Dec 15, 2024

Ahh okay, so I think this has to do with your specific Timer class. This Timer is missing a few things: it's measuring all of the host-side overhead of calling the CUDA kernel, and it is not explicitly synchronizing the GPU.

For Triton kernels there is for sure some extra overhead in the CUDA launch, and there is also some additional overhead from validating guards for a compiled region. That said, for typical training or inference (with cuda-graphs) there is enough work queued that the host can run ahead of the device, so this overhead is hidden. This is generally a safe assumption, but there are some scenarios where it isn't true.
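
For reference, a minimal sketch of what measuring only the GPU time means here, using torch.cuda.Event with an explicit synchronize (the helper name time_gpu_ms is made up for this illustration, not part of the script below):

import torch

def time_gpu_ms(fn, *args, n_runs=5, **kwargs):
    # Hypothetical helper: measure only GPU execution time via CUDA events.
    fn(*args, **kwargs)        # warmup (also triggers compilation for torch.compile'd fns)
    torch.cuda.synchronize()   # make sure nothing is still in flight

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(n_runs):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()   # wait for all queued kernels before reading the timer
    return start.elapsed_time(end) / n_runs  # average per-call time in milliseconds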

If we change the Timer to only measure GPU time, with something like:

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
import time
from tabulate import tabulate

from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    create_mask,
)

torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)
print(f"Torch version: {torch.__version__}")

torch._dynamo.config.suppress_errors = True


# Utility: Timer context
# from torch._inductor.utils import do_bench_using_profiling
# def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
#     """Thin wrapper around do_bench_using_profiling"""
#     no_args = lambda: func(*args, **kwargs)
#     time = do_bench_using_profiling(no_args)
#     return time * 1e3
from transformer_nuggets.utils.benchmark import benchmark_cuda_function_in_microseconds

benchmark_fn = benchmark_cuda_function_in_microseconds


# Function: Causal mask
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Benchmarking function
def run_benchmark(batch_sizes, sequence_lengths, num_heads=16, hidden_dim=64, n_runs=3):
    results = []

    for batch_size in batch_sizes:
        for seq_len in sequence_lengths:
            # Setup inputs
            q = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.float16
            ).to("cuda")
            k = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.float16
            ).to("cuda")
            v = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.float16
            ).to("cuda")

            block_mask = create_block_mask(
                causal_mask,
                B=None,
                H=None,
                Q_LEN=seq_len,
                KV_LEN=seq_len,
                device="cuda",
            )
            mask = create_mask(causal_mask, None, None, seq_len, seq_len, device="cuda")

            # Benchmark flex_attention
            flex_times = []
            for _ in range(n_runs):
                flex_time = benchmark_fn(
                    flex_attention,
                    q,
                    k,
                    v,
                    score_mod=None,
                    block_mask=block_mask,
                    # warmup=True,
                )
                flex_times.append(flex_time)
            flex_avg_time = (sum(flex_times) / n_runs) / 1000  # Convert µs to ms

            # Benchmark scaled_dot_product_attention with mask
            sdpa_times = []
            with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                for _ in range(n_runs):
                    sdpa_time = benchmark_fn(
                        scaled_dot_product_attention,
                        q,
                        k,
                        v,
                        attn_mask=mask,
                        # warmup=True,
                    )
                    sdpa_times.append(sdpa_time)
                sdpa_avg_time = (sum(sdpa_times) / n_runs) / 1000  # Convert µs to ms

            # Benchmark scaled_dot_product_attention with causal fa2
            causal_fa2_times = []
            with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
                for _ in range(n_runs):
                    causal_fa2_time = benchmark_fn(
                        scaled_dot_product_attention,
                        q,
                        k,
                        v,
                        is_causal=True,
                        # warmup=True,
                    )
                    causal_fa2_times.append(causal_fa2_time)
                causal_fa2_avg_time = (
                    sum(causal_fa2_times) / n_runs
                ) / 1000  # Convert µs to ms

            # Log results
            results.append(
                {
                    "Batch Size": batch_size,
                    "Seq Length": seq_len,
                    "FLEX Avg Time (ms)": f"{flex_avg_time:.2f}",
                    "SDPA Avg Time (ms)": f"{sdpa_avg_time:.2f}",
                    "FA2 Avg Time (ms)": f"{causal_fa2_avg_time:.2f}",
                }
            )

    return results


# Main script
if __name__ == "__main__":
    batch_sizes = [
        4,
        8,
    ]
    sequence_lengths = [128, 256, 512, 1024, 2048, 4096, 8192, 16384]
    n_runs = 5  # Number of runs for averaging

    results = run_benchmark(batch_sizes, sequence_lengths, n_runs=n_runs)

    # Generate table
    print("\n=== Benchmark Results ===")
    print(tabulate(results, headers="keys", tablefmt="grid"))

You will get a better measure of the kernel time. Another subtlety is that, by default, you will be generating kernels with dynamic shapes, which are slightly slower than kernels specialized on exact shapes. I set dynamic=False and then expanded the cache size so that each shape in the sweep can get its own specialized kernel.
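
For reference, a minimal sketch of that compile-mode trade-off, assuming the same flex_attention import as in the script above:

import torch
from torch.nn.attention.flex_attention import flex_attention

# Shape-specialized kernels: fastest generated code, but every new
# (batch, heads, seq_len, head_dim) combination triggers a recompile,
# so raise the cache limit to cover the whole benchmark sweep.
torch._dynamo.config.cache_size_limit = 1000
flex_specialized = torch.compile(flex_attention, dynamic=False)

# Dynamic-shape kernels: compiled once and reused across shapes,
# at the cost of slightly slower generated code.
flex_dynamic = torch.compile(flex_attention, dynamic=True)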

drisspg added the question (Further information is requested) label Dec 15, 2024