diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 593dcbdea8428..1741091e0958e 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -751,11 +751,15 @@ def forward(
             use_fp8 = True
             (q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales
             float8 = torch.float8_e4m3fnuz
+
             def check_and_convert(t, scale):
-                finfo = torch.finfo(float8)
-                descale = 1.0 / scale
-                return (t * descale).clamp(min=finfo.min, max=finfo.max).to(float8) \
-                    if t.dtype != float8 else t
+                if t.dtype != float8:
+                    finfo = torch.finfo(float8)
+                    descale = 1.0 / scale
+                    ts = (t * descale).clamp(min=finfo.min, max=finfo.max)
+                    return ts.to(float8)
+                else:
+                    return t

             q = check_and_convert(q, q_scale)
             k = check_and_convert(k, k_scale)
@@ -865,7 +869,7 @@ def check_and_convert(t, scale):
            BIAS_TYPE=0 if bias is None else 1,
            ENABLE_DROPOUT=False,
            RETURN_ENCODED_SOFTMAX=False,
-            USE_FP8 = use_fp8,
+            USE_FP8=use_fp8,
        )
        ctx.grid = grid
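
For context, below is a minimal standalone sketch of the refactored `check_and_convert` helper from the first hunk, assuming a PyTorch build that exposes `torch.float8_e4m3fnuz` and supports casting to it; the example input and the `scale=0.05` value are illustrative, not taken from the patch.

```python
import torch

# fp8 target dtype used by the patch (ROCm fnuz variant); requires a
# PyTorch build that exposes and can cast to torch.float8_e4m3fnuz.
float8 = torch.float8_e4m3fnuz


def check_and_convert(t: torch.Tensor, scale: float) -> torch.Tensor:
    """Descale `t` by `scale`, clamp to the fp8 representable range,
    and cast to float8; tensors already in float8 pass through unchanged."""
    if t.dtype != float8:
        finfo = torch.finfo(float8)
        descale = 1.0 / scale
        ts = (t * descale).clamp(min=finfo.min, max=finfo.max)
        return ts.to(float8)
    else:
        return t


# Illustrative usage: quantize a half-precision tensor with a per-tensor scale.
q = torch.randn(4, 8, dtype=torch.float16)
q_fp8 = check_and_convert(q, scale=0.05)
assert q_fp8.dtype == float8
```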