diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 591bedfa3a6f1..dfb214aa8a933 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -675,6 +675,7 @@ def forward( self.scale, attn_masks[0][None] if attn_masks is not None else None, + None, ) elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 2019ed184e5a1..593dcbdea8428 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -104,6 +104,9 @@ def _attn_fwd_inner( ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + USE_FP8: tl.constexpr, + qk_scale, + p_descale, ): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): @@ -145,6 +148,8 @@ def _attn_fwd_inner( qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) + if USE_FP8: + qk *= qk_scale if bias_ptr is not None: bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") @@ -196,7 +201,12 @@ def _attn_fwd_inner( l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij + + if USE_FP8: + p *= p_descale + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) if bias_ptr is not None: @@ -292,18 +302,20 @@ def _attn_fwd_inner( # TODO: This config fails with head_size not pow2 with data mismatches. # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config( - { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), + + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # "BLOCK_M": 16, + # "BLOCK_N": 16, + # "waves_per_eu": 1, + # "PRE_LOAD_V": False, + # }, + # num_stages=1, + # num_warps=4, + # ), ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'], ) @triton.jit def attn_fwd( @@ -312,6 +324,12 @@ def attn_fwd( V, bias, sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, L, Out, stride_qz: tl.int64, @@ -354,6 +372,7 @@ def attn_fwd( BIAS_TYPE: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_FP8: tl.constexpr, ): start_m = tl.program_id(0) off_h_q = tl.program_id(1) @@ -507,7 +526,12 @@ def attn_fwd( qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + if not USE_FP8: + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + acc_scale = 1.0 + else: + qk_scale *= q_scale * k_scale + acc_scale = p_scale * v_scale # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -562,6 +586,9 @@ def attn_fwd( ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head, + USE_FP8, + qk_scale, + p_descale, ) block_min = block_max block_max = n_blocks * BLOCK_N @@ -608,8 +635,14 @@ def attn_fwd( ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head, + USE_FP8, + qk_scale, + p_descale, ) # epilogue + + if USE_FP8: + acc *= acc_scale acc = acc / l_i[:, None] if ENABLE_DROPOUT: acc = acc / (1 - dropout_p) @@ -620,6 +653,8 @@ def attn_fwd( end_m_idx = (start_m + 1) * BLOCK_M start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k + if USE_FP8: + acc *= o_descale acc = acc.to(Out.type.element_ty) if IS_CAUSAL: # noqa: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: @@ -710,7 +745,25 @@ def forward( causal=False, sm_scale=1.0, bias=None, + fp8_scales=None, ): + if fp8_scales is not None: + use_fp8 = True + (q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales + float8 = torch.float8_e4m3fnuz + def check_and_convert(t, scale): + finfo = torch.finfo(float8) + descale = 1.0 / scale + return (t * descale).clamp(min=finfo.min, max=finfo.max).to(float8) \ + if t.dtype != float8 else t + + q = check_and_convert(q, q_scale) + k = check_and_convert(k, k_scale) + v = check_and_convert(v, v_scale) + else: + use_fp8 = False + q_scale = k_scale = v_scale = p_scale = o_scale = 1.0 + if o is None: o = torch.empty_like(q, dtype=v.dtype) @@ -773,12 +826,21 @@ def forward( else: bias_strides = (0, 0, 0, 0) + p_descale = 1.0 / p_scale + o_descale = 1.0 / o_scale + attn_fwd[grid]( q, k, v, bias, sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, None, o, *q_strides, @@ -803,6 +865,7 @@ def forward( BIAS_TYPE=0 if bias is None else 1, ENABLE_DROPOUT=False, RETURN_ENCODED_SOFTMAX=False, + USE_FP8 = use_fp8, ) ctx.grid = grid