From 3b1cf1fdc44017e929ce7c12b548c3686e52a528 Mon Sep 17 00:00:00 2001
From: Bhavya Medishetty <86327865+bmedishe@users.noreply.github.com>
Date: Wed, 8 Nov 2023 18:29:44 -0600
Subject: [PATCH] Diffusers attention script update triton2.1 (#4573)

deepspeed/ops/transformer/inference/triton_ops.py has been updated from
https://github.com/openai/triton/blob/release/2.1.x/python/tutorials/06-fused-attention.py

Inference time (text to image) is reduced from 2.6 sec to 2.49 sec on A100 with
model: stabilityai_stable-diffusion-2

@jithunnair-amd @loadams @rraminen

IS_CAUSAL = False gives the same image output as running without the DeepSpeed
inference engine; IS_CAUSAL = True gives noise as output.

---------

Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
Co-authored-by: Lev Kurilenko
Co-authored-by: Michael Wyatt
---
 .../ops/transformer/inference/triton_ops.py | 108 +++++++++++-------
 1 file changed, 64 insertions(+), 44 deletions(-)

diff --git a/deepspeed/ops/transformer/inference/triton_ops.py b/deepspeed/ops/transformer/inference/triton_ops.py
index 56e98f72a07c..f98f45ef638e 100644
--- a/deepspeed/ops/transformer/inference/triton_ops.py
+++ b/deepspeed/ops/transformer/inference/triton_ops.py
@@ -4,7 +4,7 @@
 # DeepSpeed Team
 """
 Inspired by original Triton implementation:
-https://github.com/openai/triton/blob/b244db06da24a87453a40ad35b085ee37dac3705/python/tutorials/06-fused-attention.py
+https://github.com/openai/triton/blob/release/2.1.x/python/tutorials/06-fused-attention.py
 """
 
 import torch
@@ -44,59 +44,79 @@ def _fwd_kernel(
 ):
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)
+    qvk_offset = off_hz * stride_qh
+    Q_block_ptr = tl.make_block_ptr(base=Q + qvk_offset,
+                                    shape=(N_CTX, BLOCK_DMODEL),
+                                    strides=(stride_qm, stride_qk),
+                                    offsets=(start_m * BLOCK_M, 0),
+                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
+                                    order=(1, 0))
+    K_block_ptr = tl.make_block_ptr(base=K + qvk_offset,
+                                    shape=(BLOCK_DMODEL, N_CTX),
+                                    strides=(stride_kk, stride_kn),
+                                    offsets=(0, 0),
+                                    block_shape=(BLOCK_DMODEL, BLOCK_N),
+                                    order=(0, 1))
+    V_block_ptr = tl.make_block_ptr(base=V + qvk_offset,
+                                    shape=(N_CTX, BLOCK_DMODEL),
+                                    strides=(stride_vk, stride_vn),
+                                    offsets=(0, 0),
+                                    block_shape=(BLOCK_N, BLOCK_DMODEL),
+                                    order=(1, 0))
     # initialize offsets
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
-    offs_d = tl.arange(0, BLOCK_DMODEL)
-    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
-    off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
-    off_v = off_hz * stride_vh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
-    # Initialize pointers to Q, K, V
-    q_ptrs = Q + off_q
-    k_ptrs = K + off_k
-    v_ptrs = V + off_v
     # initialize pointer to m and l
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    # scale sm_scale by log_2(e) and use
+    # 2^x instead of exp in the loop because CSE and LICM
+    # don't work as expected with `exp` in the loop
+    qk_scale = sm_scale * 1.44269504
     # load q: it will stay in SRAM throughout
-    q = tl.load(q_ptrs)
+    q = tl.load(Q_block_ptr)
+    q = (q * qk_scale).to(tl.float16)
     # loop over k, v and update accumulator
-    for start_n in range(0, N_CTX, BLOCK_N):
-        start_n = tl.multiple_of(start_n, BLOCK_N)
-        # -- compute qk ----
-        k = tl.load(k_ptrs + start_n * stride_kn)
-
-        qk = tl.dot(q, tl.trans(k))
-        qk *= sm_scale
-        # -- compute m_ij, p, l_ij
-        m_ij = tl.max(qk, 1)
-        p = tl.exp(qk - m_ij[:, None])
-        l_ij = tl.sum(p, 1)
-        # -- update m_i and l_i
-        m_i_new = tl.maximum(m_i, m_ij)
-        alpha = tl.exp(m_i - m_i_new)
-        beta = tl.exp(m_ij - m_i_new)
-        l_i_new = alpha * l_i + beta * l_ij
-        # -- update output accumulator --
-        # scale p
-        p_scale = beta / l_i_new
-        p = p * p_scale[:, None]
-        # scale acc
-        acc_scale = l_i / l_i_new * alpha
-        acc = acc * acc_scale[:, None]
-        # update acc
-        v = tl.load(v_ptrs + start_n * stride_vk)
-        p = p.to(tl.float16)
-        acc += tl.dot(p, v)
-        # update m_i and l_i
-        l_i = l_i_new
+    lo = 0
+    #hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
+    hi = N_CTX
+    #hi = (start_m + 1) * BLOCK_M
+    for start_n in range(lo, hi, BLOCK_N):
+        # -- load k, v --
+        k = tl.load(K_block_ptr)
+        v = tl.load(V_block_ptr)
+        # -- compute qk ---
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        #if IS_CAUSAL:
+        #qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+        qk += tl.dot(q, k)
+        # -- compute scaling constant ---
+        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
+        alpha = tl.math.exp2(m_i - m_i_new)
+        p = tl.math.exp2(qk - m_i_new[:, None])
+        # -- scale and update acc --
+        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
+        acc *= acc_scale[:, None]
+        acc += tl.dot(p.to(tl.float16), v)
+        # -- update m_i and l_i --
+        l_i = l_i * alpha + tl.sum(p, 1)
         m_i = m_i_new
-    # initialize pointers to output
-    offs_n = tl.arange(0, BLOCK_DMODEL)
-    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
-    out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc)
+        # update pointers
+        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
+        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
+    # write back l and m
+    acc = acc / l_i[:, None]
+    #l_ptrs = L + off_hz * N_CTX + offs_m
+    #tl.store(l_ptrs, m_i + tl.math.log2(l_i))
+    # write back O
+    O_block_ptr = tl.make_block_ptr(base=Out + qvk_offset,
+                                    shape=(N_CTX, BLOCK_DMODEL),
+                                    strides=(stride_om, stride_on),
+                                    offsets=(start_m * BLOCK_M, 0),
+                                    block_shape=(BLOCK_M, BLOCK_DMODEL),
+                                    order=(1, 0))
+    tl.store(O_block_ptr, acc.to(tl.float16))
 
 
 class triton_flash_attn(torch.nn.Module):
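
Note (not part of the patch): the updated loop is the standard online-softmax recurrence used by fused attention kernels, with `exp` replaced by `exp2` via `qk_scale = sm_scale * log2(e)`. The sketch below is a minimal plain-PyTorch re-derivation of that recurrence for a single head with no causal mask (the IS_CAUSAL = False path this PR exercises); `blockwise_attention` is an illustrative helper introduced here for checking the math, not a function in DeepSpeed or Triton.

```python
# Reference sketch (assumption: single head, no causal mask), mirroring the
# per-block bookkeeping of the updated _fwd_kernel loop on CPU in PyTorch.
import math
import torch


def blockwise_attention(q, k, v, sm_scale, block_n=64):
    # q, k, v: (seq_len, d_model) for one head
    n_ctx, d = q.shape
    qk_scale = sm_scale * 1.44269504  # log2(e), as in the kernel
    q_scaled = q * qk_scale
    m_i = torch.full((n_ctx,), float("-inf"))   # running row maxima (log2 domain)
    l_i = torch.zeros(n_ctx)                    # running softmax denominators
    acc = torch.zeros(n_ctx, d)                 # running un-normalized output
    for start in range(0, n_ctx, block_n):
        k_blk = k[start:start + block_n]
        v_blk = v[start:start + block_n]
        qk = q_scaled @ k_blk.T                 # log2-scaled scores for this block
        m_i_new = torch.maximum(m_i, qk.max(dim=1).values)
        alpha = torch.exp2(m_i - m_i_new)       # rescale contributions of earlier blocks
        p = torch.exp2(qk - m_i_new[:, None])
        acc = acc * alpha[:, None] + p @ v_blk
        l_i = l_i * alpha + p.sum(dim=1)
        m_i = m_i_new
    return acc / l_i[:, None]                   # normalize at the end, like the kernel


# Should match dense softmax attention:
torch.manual_seed(0)
q, k, v = (torch.randn(128, 64) for _ in range(3))
sm_scale = 1.0 / math.sqrt(64)
ref = torch.softmax((q @ k.T) * sm_scale, dim=-1) @ v
assert torch.allclose(blockwise_attention(q, k, v, sm_scale), ref, atol=1e-4)
```

Because `2^(x * log2(e)) = exp(x)`, pre-multiplying the scores by `sm_scale * 1.44269504` and using `exp2` gives the same softmax as the old `tl.exp(qk * sm_scale - max)` formulation, while keeping the scaling outside the inner loop.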