Restore FlexAttention and FlashV3 backward (#2473)
Summary: Pull Request resolved: #2473

Reviewed By: xuzhao9

Differential Revision: D63543625

Pulled By: bertmaher

fbshipit-source-id: 1693e15875544bda0f5f6c69daa5597fffd80509
bertmaher authored and facebook-github-bot committed Sep 28, 2024
1 parent 0f05015 commit 611bf70
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions torchbenchmark/operators/flash_attention/operator.py
@@ -225,9 +225,7 @@ def flash_v3(
         q = q.transpose(1, 2).contiguous()
         k = k.transpose(1, 2).contiguous()
         v = v.transpose(1, 2).contiguous()
-        fn = lambda: flashattn_hopper_cuda.fwd(
-            q, k, v, None, self.sm_scale, self.causal
-        )
+        fn = lambda: flash_attn_v3(q, k, v, self.sm_scale, self.causal)
         return fn

     @register_benchmark()
@@ -360,6 +358,25 @@ def sdpa_flash_attention(q, k, v):
             v,
         )

+    @register_benchmark()
+    def flex_attention(self, q, k, v):
+        from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        flex_attention = torch.compile(flex_attention, dynamic=False)
+
+        if self.causal:
+            B, H, S, D = q.shape
+            block_mask = create_block_mask(
+                causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S
+            )
+        else:
+            block_mask = None
+
+        return lambda: flex_attention(q, k, v, block_mask=block_mask)
+
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
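
For context, here is a minimal standalone sketch of the flex_attention path added above, written against the public torch.nn.attention.flex_attention API. It is an illustration, not part of the commit; it assumes PyTorch 2.5 or newer with FlexAttention available, a CUDA device, and made-up tensor shapes and dtype.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # Each query position may attend only to key positions at or before it.
    return q_idx >= kv_idx

# Illustrative sizes: batch, heads, sequence length, head dimension.
B, H, S, D = 4, 8, 1024, 64
q, k, v = (
    torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3)
)

# Compile once up front, as the benchmark does, so compilation stays out of the timed call.
compiled_flex = torch.compile(flex_attention, dynamic=False)

# B=None and H=None broadcast the causal mask over batch and heads.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)

out = compiled_flex(q, k, v, block_mask=block_mask)  # shape (B, H, S, D)

Building the block mask and compiling outside the returned lambda mirrors the benchmark's structure: the zero-argument callable it returns only runs the compiled attention kernel, so setup cost is excluded from the measured region.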
