Skip to content

Commit

Permalink
Diffusers attention script update triton2.1 (microsoft#4573)
Browse files Browse the repository at this point in the history
deepspeed/ops/transformer/inference/triton_ops.py updated from
https://github.com/openai/triton/blob/release/2.1.x/python/tutorials/06-fused-attention.py

Inference time (text to image) reduced from 2.6 sec to 2.49 sec on A100
model: stabilityai_stable-diffusion-2

@jithunnair-amd  @loadams @rraminen 
IS_CAUSAL = False gives the same image output as running without the DeepSpeed
inference engine;
IS_CAUSAL = True gives noise as output.

---------

Co-authored-by: Lev Kurilenko <[email protected]>
Co-authored-by: Lev Kurilenko <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
  • Loading branch information
4 people authored Nov 9, 2023
1 parent 8ad50cf commit 3b1cf1f
Showing 1 changed file with 64 additions and 44 deletions.
108 changes: 64 additions & 44 deletions deepspeed/ops/transformer/inference/triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# DeepSpeed Team
"""
Inspired by original Triton implementation:
https://github.com/openai/triton/blob/b244db06da24a87453a40ad35b085ee37dac3705/python/tutorials/06-fused-attention.py
https://github.com/openai/triton/blob/release/2.1.x/python/tutorials/06-fused-attention.py
"""

import torch
Expand Down Expand Up @@ -44,59 +44,79 @@ def _fwd_kernel(
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
Q_block_ptr = tl.make_block_ptr(base=Q + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
K_block_ptr = tl.make_block_ptr(base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1))
V_block_ptr = tl.make_block_ptr(base=V + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0))
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
off_v = off_hz * stride_vh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(tl.float16)
# loop over k, v and update accumulator
for start_n in range(0, N_CTX, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)

qk = tl.dot(q, tl.trans(k))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
p = p.to(tl.float16)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
lo = 0
#hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
hi = N_CTX
#hi = (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
#if IS_CAUSAL:
#qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
qk += tl.dot(q, k)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
p = tl.math.exp2(qk - m_i_new[:, None])
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
acc += tl.dot(p.to(tl.float16), v)
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
# write back l and m
acc = acc / l_i[:, None]
#l_ptrs = L + off_hz * N_CTX + offs_m
#tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
O_block_ptr = tl.make_block_ptr(base=Out + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
tl.store(O_block_ptr, acc.to(tl.float16))


class triton_flash_attn(torch.nn.Module):
Expand Down

0 comments on commit 3b1cf1f

Please sign in to comment.