Skip to content

Commit

Permalink
[Operator] add assertation and modify condition expression
Browse files Browse the repository at this point in the history
  • Loading branch information
StrongSpoon committed Jan 15, 2025
1 parent 7b65d50 commit ed84cc8
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/flag_gems/ops/nllloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def nll_loss_forward_kernel(
mask_n = offsets_n < N

tgt = tl.load(tgt_ptr + offsets_n, mask=mask_n, other=0)
assert tgt >= 0 and tgt < C, "Invalid target value"
ignore_mask = not (tgt == ignore_index) and mask_n

if wgt_ptr is None:
Expand All @@ -50,9 +51,8 @@ def nll_loss_forward_kernel(
counter = tl.load(out_ptr + 2)
total_out = tl.load(out_ptr)
total_wgt = tl.load(out_ptr + 1)
tl.store(
out_ptr + 3, total_out / total_wgt, mask=(counter == tl.num_programs(0))
)
if counter == tl.num_programs(0):
tl.store(out_ptr + 3, total_out / total_wgt)
# sum
else:
total_out = tl.sum(out)
Expand Down Expand Up @@ -124,6 +124,7 @@ def nll_loss2d_forward_kernel(

tgt_ptrs = tgt_ptr + offset_n * D + offset_d
tgt = tl.load(tgt_ptrs, mask=mask_block, other=0)
assert tgt >= 0 and tgt < C, "Invalid target value"
ignore_mask = not (tgt == ignore_index) and mask_block

if wgt_ptr is None:
Expand All @@ -149,9 +150,8 @@ def nll_loss2d_forward_kernel(
counter = tl.load(out_ptr + 2)
total_out = tl.load(out_ptr)
total_wgt = tl.load(out_ptr + 1)
tl.store(
out_ptr + 3, total_out / total_wgt, mask=(counter == tl.num_programs(0))
)
if counter == tl.num_programs(0):
tl.store(out_ptr + 3, total_out / total_wgt)
# sum
else:
total_out = tl.sum(out)
Expand Down

0 comments on commit ed84cc8

Please sign in to comment.