Eliminate ternary if statement (#506)
groenenboomj authored Feb 13, 2024
1 parent 328b4dd commit d6f14d3
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions python/perf-kernels/flash-attention.py
@@ -48,7 +48,7 @@ class MetaData():

     def __init__(self, sm_scale=1.0):
         self.sm_scale = sm_scale
-    
+
     def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
         self.varlen = True
         self.cu_seqlens_q = cu_seqlens_q
@@ -75,7 +75,7 @@ def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k):

         self.bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
         self.bias_type = bias_type
-    
+
     def need_causal(self):
         self.causal = True

@@ -327,7 +327,11 @@ def attn_fwd(
         cu_seqlens_k_start = 0
         seqlen_q = max_seqlens_q
         seqlen_k = max_seqlens_k
-    off_h_k = off_h_q % hk if is_mqa else off_h_q
+    #off_h_k = off_h_q % hk if is_mqa else off_h_q
+    if is_mqa:
+        off_h_k = off_h_q % hk
+    else:
+        off_h_k = off_h_q
     need_padding = False
     extra_tokens_n = 0
     if seqlen_k < BLOCK_N:
@@ -804,20 +808,20 @@ def forward(ctx, q, k, v, o, metadata):
         philox_offset = 0x1D4B42

         if metadata.bias_type != 0:
-            bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), 
+            bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1),
                             metadata.bias.stride(2), metadata.bias.stride(3))
         else:
             bias_strides = (0,0,0,0)

         attn_fwd[grid](
             q, k, v, metadata.bias, metadata.sm_scale, M, o,
             *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides,
-            metadata.cu_seqlens_q, metadata.cu_seqlens_k, 
+            metadata.cu_seqlens_q, metadata.cu_seqlens_k,
             dropout_p=metadata.dropout_p,
             philox_seed=philox_seed,
             philox_offset_base=philox_offset,
             encoded_softmax=encoded_softmax,
-            max_seqlens_q=metadata.max_seqlens_q, 
+            max_seqlens_q=metadata.max_seqlens_q,
             max_seqlens_k=metadata.max_seqlens_k,
             VARLEN=metadata.varlen,
             hq=nheads_q, hk=nheads_k,
@@ -982,7 +986,7 @@ def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, use_bias, bias_type, qseqlen_not_eq

 def varlen_input_helper(Z, HQ, HK, N_CTX, D_HEAD, dtype):
     torch.manual_seed(20)
-    
+
     # Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs
     max_seqlens = N_CTX // Z
     seqlens_q = torch.randint(1, max_seqlens + 1, (Z,), dtype=torch.int32)
@@ -1190,7 +1194,7 @@ def bench_flash_attention(
         input_metadata.need_bias(bias, BATCH, H, N_CTX, N_CTX)
     else:
         bias = None
-    
+
     # Bwd pass only supports causal=True right now
     if mode == 'bwd':
         causal = True
@@ -1290,4 +1294,3 @@ def bench_varlen_flash_attention(
     return total_flops / ms * 1e-9

 bench_varlen_flash_attention.run(save_path=".", print_data=True)
-
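For context on the one substantive hunk (at line 327): off_h_k is the K/V head index that a given query head reads. When is_mqa is set (multi-query / grouped-query attention), several query heads share one K/V head, so the kernel folds the query-head index onto the hk K/V heads with a modulo; otherwise heads map one-to-one. The sketch below restates that mapping in plain Python, outside Triton, and is behaviorally identical to both the removed ternary and the added if/else. The helper name kv_head_index is hypothetical, introduced only for illustration.

# Plain-Python sketch of the head-index mapping in the attn_fwd hunk above.
# The helper name kv_head_index is made up for this illustration; the
# parameter names off_h_q, hk, and is_mqa follow the kernel's identifiers.
def kv_head_index(off_h_q: int, hk: int, is_mqa: bool) -> int:
    # Equivalent to the removed ternary:
    #   off_h_k = off_h_q % hk if is_mqa else off_h_q
    if is_mqa:
        return off_h_q % hk  # several query heads share one K/V head
    return off_h_q           # standard MHA: one K/V head per query head

# Example: 8 query heads sharing 2 K/V heads alternate between heads 0 and 1.
assert [kv_head_index(q, hk=2, is_mqa=True) for q in range(8)] == [0, 1, 0, 1, 0, 1, 0, 1]

Either form computes the same value; the commit only spells the branch out explicitly inside the Triton kernel. The remaining hunks are whitespace-only cleanups (trailing spaces removed), which accounts for the 12 additions and 9 deletions.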
