[Operator] Init NLL_LOSS #269
Conversation
LGTM
benchmark/test_reduction_perf.py
Outdated
@@ -135,6 +141,13 @@ def cumsum_input_fn(shape, cur_dtype, device):
        FLOAT_DTYPES + INT_DTYPES,
        marks=pytest.mark.cumsum,
    ),
    pytest.param(
        "nll_loss",
        torch.nn.NLLLoss,
NLLLoss is a class. Can we use it as the reference function?
Indeed. I've updated it to torch.nn.functional.nll_loss.
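For illustration, a minimal sketch of what the updated benchmark entry could look like; the dtype list and the mark name here are assumptions, not the PR's actual values:

import pytest
import torch

FLOAT_DTYPES = [torch.float16, torch.float32]  # stand-in for the benchmark module's list

nll_loss_param = pytest.param(
    "nll_loss",
    torch.nn.functional.nll_loss,  # plain function, usable directly as the reference op
    FLOAT_DTYPES,
    marks=pytest.mark.nll_loss,    # assumed mark name
)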
src/flag_gems/ops/nllloss.py
Outdated
if weight is None:
    weight = torch.ones(
        [
Use a tuple.
Done.
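A tiny sketch of the change being asked for (the class count is a placeholder):

import torch

C = 8                       # hypothetical number of classes
weight = torch.ones((C,))   # pass the shape as a tuple rather than a list literal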
src/flag_gems/ops/nllloss.py
Outdated
tl.store(inp_grad_ptrs, inp_grad.to(tl.float32), mask=(inp_mask & ignore_mask))


class NLLLoss(torch.autograd.Function):
This function is intended to be used as a substitute for nll_loss, whereas NLLLoss is already taken as the nn module name. We should avoid the name confusion.
Indeed. I've updated the class name.
src/flag_gems/ops/nllloss.py
Outdated
if reduction == 0:
    res = out.to(inp.dtype)
elif reduction == 1:
    ctx.total_weight = sum(w_tgt).item()
Shall we also add dim= args to avoid confusion?
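One possible reading of this suggestion, sketched with a placeholder tensor rather than the PR's actual change:

import torch

w_tgt = torch.rand(8)                    # hypothetical per-sample weights
total_weight = w_tgt.sum(dim=0).item()   # explicit dim= instead of a bare sum()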
LGTM
Could you provide the performance data?
A general result is as below:
It is somewhat poor now. As mentioned before, we may use gather.
src/flag_gems/ops/nllloss.py
Outdated
    inp, tgt, w, w_tgt, out, ignore_index, N, C
)

ctx.save_for_backward(inp, tgt, w)
Only saving tensors and variables when the input requires gradient might decrease the cost.
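A minimal sketch of that suggestion, reusing the names from the snippet above; it would sit inside the autograd Function's forward:

# Only stash tensors for backward when a gradient will actually be needed;
# otherwise the saved inp/tgt/w are kept alive for nothing.
if inp.requires_grad:
    ctx.save_for_backward(inp, tgt, w)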
src/flag_gems/ops/nllloss.py
Outdated
    res = out.to(inp.dtype)
elif reduction == 1:
    ctx.total_weight = sum(w_tgt).item()
    res = sum(out).to(inp.dtype) / ctx.total_weight
Suggest fusing the sum into the forward kernel, referencing cross_entropy_loss.
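A hedged sketch of the idea (not the PR's actual kernel): inside the forward Triton kernel, each program reduces its own tile and accumulates into running numerator/denominator slots with atomics, so no separate sum() pass over out is needed. The variable names below are assumptions.

# inside the @triton.jit forward kernel, after per-sample losses are computed
partial_loss = tl.sum(tl.where(sample_mask, -w_tgt * inp_tgt, 0.0), axis=0)
partial_wgt = tl.sum(tl.where(sample_mask, w_tgt, 0.0), axis=0)
tl.atomic_add(out_ptr, partial_loss)      # running sum of weighted losses
tl.atomic_add(out_ptr + 1, partial_wgt)   # running total_weight for the mean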
src/flag_gems/ops/nllloss.py
Outdated
@libentry()
@triton.autotune(
    configs=[triton.Config({"BLOCK_N": n}, num_warps=4) for n in [256, 512, 1024]],
    configs=[triton.Config({"BLOCK_N": n}, num_warps=4) for n in [1, 16, 256]],
Why does BLOCK_N vary so much?
I think BLOCK = 128 is a good pick.
What about [1, 4, 32, 128]?
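For reference, the suggested sweep would look like this, assuming the same autotune wrapper as above:

import triton

configs = [triton.Config({"BLOCK_N": n}, num_warps=4) for n in [1, 4, 32, 128]]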
src/flag_gems/ops/nllloss.py
Outdated
        for c in [256, 512, 1024]
        for d in [1, 4, 16]
    ],
    configs=[triton.Config({"BLOCK_D": d}, num_warps=4) for d in [1, 4, 16]],
Considering this as a gather/scatter-like kernel, BLOCK = 128 should be good enough.
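A sketch of that simplification, again assuming the same wrapper: a single fixed block size replaces the BLOCK_D sweep.

import triton

configs = [triton.Config({"BLOCK_D": 128}, num_warps=4)]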
LGTM
src/flag_gems/ops/nllloss.py
Outdated
tl.atomic_add(out_ptr + 2, 1, sem="release")  # counter
counter = tl.load(out_ptr + 2)
total_out = tl.load(out_ptr)
total_wgt = tl.load(out_ptr + 1)
tl.store(
    out_ptr + 3, total_out / total_wgt, mask=(counter == tl.num_programs(0))
)
I think it's safer to use a stronger memory order for the counter update in line 49. Then we're ensured that only one CTA does the rest.
Release is enough here. It's safe even if more than one CTA satisfies the condition.
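For comparison, a hedged sketch of the stronger-ordering variant discussed above (an illustration of the alternative, not the code that was merged): using the value returned by the atomic add, with acq_rel ordering, lets exactly one program see the final count and write the result.

old = tl.atomic_add(out_ptr + 2, 1, sem="acq_rel")   # returns the pre-add counter
is_last = old == tl.num_programs(0) - 1              # true for exactly one program
total_out = tl.load(out_ptr)
total_wgt = tl.load(out_ptr + 1)
tl.store(out_ptr + 3, total_out / total_wgt, mask=is_last)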
Looks good now.
A basic implementation of NLL_LOSS has been pushed. Based on the performance testing results summarized earlier, we believe that using the gather operation would lead to a more efficient implementation (judging by the observed latencies, this also seems to be how torch does it), and we will push forward with this optimization.
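As a rough illustration of the gather-based formulation mentioned above (a plain PyTorch sketch that ignores ignore_index; not the planned Triton kernel):

import torch

def nll_loss_via_gather(logp, target, weight=None, reduction="mean"):
    # Pick each sample's target log-probability with a single gather along the class dim.
    picked = logp.gather(1, target.unsqueeze(1)).squeeze(1)            # (N,)
    w = weight[target] if weight is not None else torch.ones_like(picked)
    loss = -w * picked
    if reduction == "mean":
        return loss.sum() / w.sum()    # weighted mean, as in nll_loss
    if reduction == "sum":
        return loss.sum()
    return loss                        # reduction == "none"

On 2D inputs without ignore_index, torch.nn.functional.nll_loss should give matching values for all three reduction modes, which makes it a convenient reference when checking a gather-based kernel.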