[Operator] Optimize nll_loss and achieve
StrongSpoon committed Jan 6, 2025
1 parent d36036c commit 81db135
Showing 2 changed files with 55 additions and 57 deletions.
3 changes: 3 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -117,6 +117,9 @@ def nll_loss_input_fn(shape, cur_dtype, device):
     inp = generate_tensor_input(shape, cur_dtype, device)
     target = torch.randint(0, shape[-1], (shape[0],), device=device)
     yield inp, target
+    if Config.bench_level == BenchLevel.COMPREHENSIVE:
+        weight = torch.randn(shape[-1], dtype=cur_dtype, device=device)
+        yield inp, target, {"weight": weight, "ignore_index": 1, "reduction": "none"}


 def cumsum_input_fn(shape, cur_dtype, device):
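
For context, the new comprehensive-level benchmark case boils down to a call like the following (a minimal sketch, assuming (N, C) log-probability inputs on a CUDA device; the sizes and tensor names are illustrative, not taken from the benchmark):

import torch
import torch.nn.functional as F

N, C = 128, 1000
inp = torch.randn(N, C, device="cuda").log_softmax(dim=-1)  # log-probabilities
target = torch.randint(0, C, (N,), device="cuda")
weight = torch.randn(C, device="cuda")

# Same kwargs as the yielded benchmark case: per-class weight,
# ignore_index=1, and no reduction (per-sample losses are returned).
loss = F.nll_loss(inp, target, weight=weight, ignore_index=1, reduction="none")
assert loss.shape == (N,)
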
109 changes: 52 additions & 57 deletions src/flag_gems/ops/nllloss.py
@@ -5,12 +5,12 @@
 import triton.language as tl

 from ..utils import libentry
-from .sum import sum
+from .cross_entropy_loss import sum_and_scale


 @libentry()
 @triton.autotune(
-    configs=[triton.Config({"BLOCK_N": n}, num_warps=4) for n in [256, 512, 1024]],
+    configs=[triton.Config({"BLOCK_N": n}, num_warps=4) for n in [1, 16, 256]],
     key=["N"],
 )
 @triton.jit(do_not_specialize=["ignore_index"])
@@ -31,19 +31,18 @@ def nll_loss_2d_fwd_kernel(
     mask_n = offsets_n < N

     tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
-    ignore_mask = not (tgt == ignore_index)
-    mask_tgt = tgt < C
+    ignore_mask = not (tgt == ignore_index) and mask_n

-    w_ptrs = w_ptr + tgt
-    w_tgt = tl.load(w_ptrs, mask=mask_n, other=0).to(tl.float32)
-    tl.store(w_tgt_ptr + offsets_n, w_tgt, mask=(mask_n & ignore_mask))
+    if w_ptr is None:
+        w_tgt = 1
+    else:
+        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)
+    tl.store(w_tgt_ptr + offsets_n, w_tgt, mask=mask_n)

     inp_tgt_ptrs = inp_ptr + offsets_n * C + tgt
-    inp_tgt = tl.load(inp_tgt_ptrs, mask=mask_n & mask_tgt, other=-float("inf")).to(
-        tl.float32
-    )
+    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
     out = inp_tgt * w_tgt * -1
-    tl.store(out_ptr + offsets_n, out, mask=mask_n & mask_tgt & ignore_mask)
+    tl.store(out_ptr + offsets_n, out, mask=mask_n)


 @libentry()
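
For reference, the per-sample loss that nll_loss_2d_fwd_kernel writes to out follows the standard NLL definition, out[n] = -weight[target[n]] * inp[n, target[n]], with rows whose target equals ignore_index contributing zero. A plain-PyTorch sketch of that quantity (illustrative helper, not part of the commit):

import torch

def nll_loss_rows_reference(inp, target, weight=None, ignore_index=-100):
    # inp: (N, C) log-probabilities, target: (N,) int64 class indices.
    N, C = inp.shape
    keep = target != ignore_index                         # ignore_mask in the kernel
    safe = torch.where(keep, target, target.new_zeros(()))
    w = inp.new_ones(C) if weight is None else weight.to(inp.dtype)
    picked = inp.gather(1, safe.unsqueeze(1)).squeeze(1)  # inp[n, target[n]]
    return torch.where(keep, -w[safe] * picked, inp.new_zeros(()))

# Sanity check against eager PyTorch:
#   torch.allclose(nll_loss_rows_reference(x, y),
#                  torch.nn.functional.nll_loss(x, y, reduction="none"))
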
@@ -93,11 +92,7 @@ def nll_loss_2d_bwd_kernel(

 @libentry()
 @triton.autotune(
-    configs=[
-        triton.Config({"BLOCK_C": c, "BLOCK_D": d}, num_warps=4)
-        for c in [256, 512, 1024]
-        for d in [1, 4, 16]
-    ],
+    configs=[triton.Config({"BLOCK_D": d}, num_warps=4) for d in [1, 4, 16]],
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index"])
@@ -111,29 +106,30 @@ def nll_loss_multi_fwd_kernel(
     N,
     C,
     D,
-    BLOCK_C: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-    pid_n = tl.program_id(0)
-    pid_d = tl.program_id(1)
+    pid_n = tl.program_id(1)
+    pid_d = tl.program_id(0)
     offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

     tgt_ptrs = tgt_ptr + pid_n * D + offset_d
-    mask_tgt = offset_d < D
-    tgt = tl.load(tgt_ptrs, mask=mask_tgt, other=0)
+    mask_d = offset_d < D
+    tgt = tl.load(tgt_ptrs, mask=mask_d, other=0)

-    ignore_mask = not (tgt == ignore_index)
+    ignore_mask = not (tgt == ignore_index) and mask_d

-    w_ptrs = w_ptr + tgt
-    w_tgt = tl.load(w_ptrs, mask=mask_tgt, other=0).to(tl.float32)
+    if w_ptr is None:
+        w_tgt = 1
+    else:
+        w_tgt = tl.load(w_ptr + tgt, mask=ignore_mask, other=0).to(tl.float32)
     w_tgt_ptrs = w_tgt_ptr + pid_n * D + offset_d
-    tl.store(w_tgt_ptrs, w_tgt, mask=(mask_tgt & ignore_mask))
+    tl.store(w_tgt_ptrs, w_tgt, mask=mask_d)

     inp_tgt_ptrs = inp_ptr + pid_n * C * D + tgt * D + offset_d
-    inp_tgt = tl.load(inp_tgt_ptrs, mask=mask_tgt, other=-float("inf")).to(tl.float32)
+    inp_tgt = tl.load(inp_tgt_ptrs, mask=ignore_mask, other=0).to(tl.float32)
     out = inp_tgt * w_tgt * -1
     out_ptrs = out_ptr + pid_n * D + offset_d
-    tl.store(out_ptrs, out, mask=(mask_tgt & ignore_mask))
+    tl.store(out_ptrs, out, mask=mask_d)


 @libentry()
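
The launch axes of this kernel are swapped relative to the old version: program_id(0) now walks blocks of the innermost D dimension and program_id(1) walks the batch dimension N, which pairs with the grid change in forward() further down. A host-side sketch of the resulting tiling (illustrative only, not part of the commit):

import math

def launch_order(N, D, BLOCK_D):
    # Mirrors grid = (triton.cdiv(D, BLOCK_D), N): each program instance handles
    # one row n and one BLOCK_D-wide slice of the D dimension.
    for pid_n in range(N):
        for pid_d in range(math.ceil(D / BLOCK_D)):
            offset_d = list(range(pid_d * BLOCK_D, min((pid_d + 1) * BLOCK_D, D)))
            yield pid_n, pid_d, offset_d

# e.g. list(launch_order(N=2, D=5, BLOCK_D=4)) enumerates the four (row, block) tiles.
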
@@ -227,25 +223,17 @@ def forward(ctx, inp, target, weight, reduction, ignore_index):
         axis = 0 if dim == 1 else 1
         del shape[axis]

         assert ((i >= 0 and i < C) for i in target), "Target is out of bounds"
         assert list(target.shape) == shape, "Invalid target size"
         assert inp.ndim >= 1, "Invalid input ndim"

-        if weight is None:
-            weight = torch.ones(
-                (C),
-                dtype=inp.dtype,
-                device=inp.device,
-            )
-
         inp = inp.contiguous()
         tgt = target.contiguous()
-        w = weight.contiguous()
-        out = torch.zeros(shape, dtype=torch.float32, device=inp.device)
-        w_tgt = torch.zeros(shape, dtype=torch.float32, device=inp.device)
+        w = None if weight is None else weight.contiguous()
+        out = torch.empty(shape, dtype=inp.dtype, device=inp.device)
+        w_tgt = torch.empty(shape, dtype=inp.dtype, device=inp.device)

         if inp.ndim > 2:
-            grid = lambda meta: (N, triton.cdiv(D, meta["BLOCK_D"]))
+            grid = lambda meta: (triton.cdiv(D, meta["BLOCK_D"]), N)
             with torch.cuda.device(inp.device):
                 nll_loss_multi_fwd_kernel[grid](
                     inp, tgt, w, w_tgt, out, ignore_index, N, C, D
@@ -257,24 +245,27 @@ def forward(ctx, inp, target, weight, reduction, ignore_index):
                     inp, tgt, w, w_tgt, out, ignore_index, N, C
                 )

-        ctx.save_for_backward(inp, tgt, w)
-        ctx.N = N
-        ctx.C = C
-        ctx.D = D
-        ctx.ignore_index = ignore_index
-        ctx.total_weight = 1
-        ctx.shape = shape
-
         # redution: 0-None, 1-mean, 2-sum
-        if reduction == 0:
-            res = out.to(inp.dtype)
-        elif reduction == 1:
-            ctx.total_weight = sum(w_tgt).item()
-            res = sum(out).to(inp.dtype) / ctx.total_weight
-        else:
-            res = sum(out).to(inp.dtype)
-
-        return res
+        if reduction == 1:
+            res = torch.empty([], dtype=inp.dtype, device=inp.device)
+            wgt_sum = torch.empty([], dtype=inp.dtype, device=inp.device)
+            sum_and_scale[(1,)](out, res, N * D, True, scale=w_tgt, mean_num=wgt_sum)
+            out = res
+        elif reduction == 2:
+            res = torch.empty([], dtype=inp.dtype, device=inp.device)
+            sum_and_scale[(1,)](out, res, N * D, False)
+            out = res
+
+        if inp.requires_grad:
+            ctx.save_for_backward(inp, tgt, w)
+            ctx.N = N
+            ctx.C = C
+            ctx.D = D
+            ctx.ignore_index = ignore_index
+            ctx.total_weight = wgt_sum if reduction == 1 else 1
+            ctx.shape = shape
+
+        return out

     @staticmethod
     def backward(ctx, out_grad):
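
The single-block sum_and_scale launches above replace the separate sum kernels the old code called per reduction. In plain PyTorch, the reductions they stand in for amount to the following (a sketch, assuming out holds the per-sample losses and w_tgt the gathered target weights produced by the forward kernels):

import torch

def reduce_reference(out, w_tgt, reduction):
    # reduction: 0 = none, 1 = mean, 2 = sum (same encoding as the comment above)
    if reduction == 1:
        wgt_sum = w_tgt.sum()
        return out.sum() / wgt_sum      # weighted mean over non-ignored samples
    if reduction == 2:
        return out.sum()
    return out                          # reduction == 0: keep per-sample losses
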
@@ -284,7 +275,11 @@ def backward(ctx, out_grad):
         C = ctx.C
         D = ctx.D
         ignore_index = ctx.ignore_index
-        total_weight = ctx.total_weight
+        total_weight = (
+            ctx.total_weight.item()
+            if isinstance(ctx.total_weight, torch.Tensor)
+            else ctx.total_weight
+        )
         shape = ctx.shape

         out_grad = out_grad.broadcast_to(shape).contiguous()
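
A quick end-to-end check of the optimized operator (a sketch: it assumes a CUDA device and that flag_gems exposes a use_gems() context manager that routes aten ops to these Triton kernels; adjust to however the library is enabled in your version):

import torch
import torch.nn.functional as F
import flag_gems

inp = torch.randn(32, 1000, device="cuda").log_softmax(dim=-1)
target = torch.randint(0, 1000, (32,), device="cuda")
weight = torch.rand(1000, device="cuda")

ref = F.nll_loss(inp, target, weight=weight, ignore_index=1, reduction="mean")

with flag_gems.use_gems():  # assumption: dispatches F.nll_loss to the kernels above
    res = F.nll_loss(inp, target, weight=weight, ignore_index=1, reduction="mean")

torch.testing.assert_close(res, ref)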
