Skip to content

Commit

Permalink
tl.dot(a, b, trans_b=True) is not supported by triton 2.0+; updating this API (microsoft#4541)
Browse files Browse the repository at this point in the history

With this update, Stable Diffusion inference using the DeepSpeed inference
engine works with Triton 2.1 on an A100.

---------

Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
bmedishe and loadams authored Nov 8, 2023
1 parent 00757a1 commit 217b15d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 9 deletions.
9 changes: 1 addition & 8 deletions deepspeed/ops/transformer/inference/triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def _fwd_kernel(
K,
V,
sm_scale,
- TMP,
Out,
stride_qz,
stride_qh,
Expand Down Expand Up @@ -57,7 +56,6 @@ def _fwd_kernel(
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
- t_ptrs = TMP + off_hz * N_CTX + offs_m
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
Expand All @@ -69,8 +67,7 @@ def _fwd_kernel(
# -- compute qk ----
k = tl.load(k_ptrs + start_n * stride_kn)

- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
- qk += tl.dot(q, k, trans_b=True)
+ qk = tl.dot(q, tl.trans(k))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
Expand All @@ -87,8 +84,6 @@ def _fwd_kernel(
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
- tl.store(t_ptrs, acc_scale)
- acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs + start_n * stride_vk)
Expand All @@ -115,15 +110,13 @@ def forward(self, q, k, v, sm_scale, block_128=True):
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
- tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8

_fwd_kernel[grid](
q,
k,
v,
sm_scale,
- tmp,
o,
q.stride(0),
q.stride(1),
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-sd.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
diffusers
- triton==2.0.0.dev20221202
+ triton

0 comments on commit 217b15d

Please sign in to comment.