
[Operator] Init NLL_LOSS #269

Merged
merged 17 commits into master from nll_loss on Jan 15, 2025

Conversation

GwokHiujin
Collaborator

A basic implementation of NLL_LOSS has been pushed.

Based on the performance test results summarized earlier, we believe that using the gather operation would lead to a more efficient implementation (judging by the latency figures, this also appears to be how torch does it), and we will push forward with this optimization.
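
For context, a minimal PyTorch-level sketch of the gather-based formulation (my illustration of the idea, assuming 2D input, mean/sum/none reductions, and no ignore_index handling; not the Triton kernel in this PR):

import torch

def nll_loss_gather(logp, target, weight=None, reduction="mean"):
    # logp: (N, C) log-probabilities, target: (N,) class indices
    N, C = logp.shape
    if weight is None:
        weight = logp.new_ones((C,))
    w = weight[target]                                        # per-sample weight
    picked = logp.gather(1, target.unsqueeze(1)).squeeze(1)   # logp[i, target[i]]
    loss = -w * picked
    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    return loss.sum() / w.sum()                               # weighted mean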

@tongxin tongxin self-assigned this Nov 10, 2024
@tongxin tongxin requested a review from StrongSpoon November 10, 2024 15:06
tongxin
tongxin previously approved these changes Nov 15, 2024
Contributor

@tongxin tongxin left a comment

LGTM

@@ -135,6 +141,13 @@ def cumsum_input_fn(shape, cur_dtype, device):
        FLOAT_DTYPES + INT_DTYPES,
        marks=pytest.mark.cumsum,
    ),
    pytest.param(
        "nll_loss",
        torch.nn.NLLLoss,
Contributor

NLLLoss is a class. Can we use it as the reference function?

Collaborator Author

Indeed. I've updated it to torch.nn.functional.nll_loss.
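
For reference, the module is a callable class while the functional form is what gets passed as the reference function; a quick generic illustration (not code from this repo):

import torch
import torch.nn.functional as F

logp = torch.randn(4, 3).log_softmax(dim=1)
tgt = torch.tensor([0, 2, 1, 2])

loss_module = torch.nn.NLLLoss()(logp, tgt)  # class: instantiate, then call
loss_func = F.nll_loss(logp, tgt)            # function: call directly
assert torch.allclose(loss_module, loss_func)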


if weight is None:
    weight = torch.ones(
        [
Contributor

Use tuple

Collaborator Author

Done.
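
A tiny sketch of the suggested change (my reading of the nit; names are illustrative):

import torch

C = 8
inp = torch.randn(4, C)
weight = torch.ones((C,), dtype=inp.dtype, device=inp.device)  # tuple shape instead of a list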

tl.store(inp_grad_ptrs, inp_grad.to(tl.float32), mask=(inp_mask & ignore_mask))


class NLLLoss(torch.autograd.Function):
Contributor

This function is intended to be used as a substitute for nll_loss, whereas NLLLoss is already taken as the nn module name. We should avoid the name confusion.

Collaborator Author

Indeed. I've updated the class name.

if reduction == 0:
    res = out.to(inp.dtype)
elif reduction == 1:
    ctx.total_weight = sum(w_tgt).item()
Contributor

Shall we also add dim= args to avoid confusion?
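
One reading of the suggestion (an assumption about the intent, not the merged code): make the reduction dimension explicit instead of relying on the builtin sum() over a tensor.

import torch

w_tgt = torch.rand(1024)
out = torch.rand(1024)

total_weight = torch.sum(w_tgt, dim=0).item()   # explicit dim
res = torch.sum(out, dim=0) / total_weight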

tongxin
tongxin previously approved these changes Dec 12, 2024
Contributor

@tongxin tongxin left a comment

LGTM

@StrongSpoon
Collaborator

could you provide the performance data?

@GwokHiujin
Collaborator Author

> could you provide the performance data?

The general results are as follows:

Operator: nll_loss  Performance Test (dtype=torch.float16, mode=cuda, level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup         Size Detail
------------------------------------------------------------------------------------------
SUCCESS               0.008192            0.430080               0.019          [torch.Size([64, 64]), torch.Size([64])]
SUCCESS               0.014336            0.626688               0.023          [torch.Size([256, 256]), torch.Size([256])]
SUCCESS               0.039936            1.439744               0.028          [torch.Size([1024, 1024]), torch.Size([1024])]
SUCCESS               0.146432            4.753408               0.031          [torch.Size([4096, 4096]), torch.Size([4096])]
SUCCESS               0.039936            1.436672               0.028          [torch.Size([1024, 65536]), torch.Size([1024])]
SUCCESS               0.037888            1.439744               0.026          [torch.Size([1024, 1]), torch.Size([1024])]
SUCCESS               0.034816            1.430528               0.024          [torch.Size([1024, 16]), torch.Size([1024])]
SUCCESS               0.043008            1.440768               0.030          [torch.Size([1024, 256]), torch.Size([1024])]
SUCCESS               0.039936            1.432576               0.028          [torch.Size([1024, 4096]), torch.Size([1024])]


Operator: nll_loss  Performance Test (dtype=torch.float32, mode=cuda, level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup         Size Detail
------------------------------------------------------------------------------------------
SUCCESS               0.008192            0.413696               0.020          [torch.Size([64, 64]), torch.Size([64])]
SUCCESS               0.014336            0.601088               0.024          [torch.Size([256, 256]), torch.Size([256])]
SUCCESS               0.039936            1.414144               0.028          [torch.Size([1024, 1024]), torch.Size([1024])]
SUCCESS               0.146432            4.697088               0.031          [torch.Size([4096, 4096]), torch.Size([4096])]
SUCCESS               0.038912            1.420288               0.027          [torch.Size([1024, 65536]), torch.Size([1024])]
SUCCESS               0.036864            1.413120               0.026          [torch.Size([1024, 1]), torch.Size([1024])]
SUCCESS               0.034816            1.412096               0.025          [torch.Size([1024, 16]), torch.Size([1024])]
SUCCESS               0.043008            1.417216               0.030          [torch.Size([1024, 256]), torch.Size([1024])]
SUCCESS               0.038912            1.411072               0.028          [torch.Size([1024, 4096]), torch.Size([1024])]


Operator: nll_loss  Performance Test (dtype=torch.bfloat16, mode=cuda, level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup         Size Detail
------------------------------------------------------------------------------------------
SUCCESS               0.008192            0.431104               0.019          [torch.Size([64, 64]), torch.Size([64])]
SUCCESS               0.014336            0.621568               0.023          [torch.Size([256, 256]), torch.Size([256])]
SUCCESS               0.039936            1.448960               0.028          [torch.Size([1024, 1024]), torch.Size([1024])]
SUCCESS               0.146432            4.698112               0.031          [torch.Size([4096, 4096]), torch.Size([4096])]
SUCCESS               0.039936            1.448960               0.028          [torch.Size([1024, 65536]), torch.Size([1024])]
SUCCESS               0.037888            1.452032               0.026          [torch.Size([1024, 1]), torch.Size([1024])]
SUCCESS               0.034816            1.444864               0.024          [torch.Size([1024, 16]), torch.Size([1024])]
SUCCESS               0.043008            1.442816               0.030          [torch.Size([1024, 256]), torch.Size([1024])]
SUCCESS               0.038912            1.446912               0.027          [torch.Size([1024, 4096]), torch.Size([1024])]

The performance is somewhat poor for now. As mentioned before, we may use gather later for optimization.

inp, tgt, w, w_tgt, out, ignore_index, N, C
)

ctx.save_for_backward(inp, tgt, w)
Collaborator

Only saving tensors and variables when the input requires gradients might decrease the cost.
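
A minimal sketch of the idea (my illustration using a plain PyTorch forward with mean reduction, no weights, and no ignore_index; not the PR's Triton-backed code):

import torch

class _NllLossSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, tgt):
        # inp: (N, C) log-probabilities, tgt: (N,) class indices
        picked = inp.gather(1, tgt.unsqueeze(1)).squeeze(1)
        if inp.requires_grad:              # skip the bookkeeping when no grad is needed
            ctx.save_for_backward(tgt)
            ctx.shape = inp.shape
        return (-picked).mean()

    @staticmethod
    def backward(ctx, grad_out):
        (tgt,) = ctx.saved_tensors
        N, C = ctx.shape
        grad = grad_out.new_zeros(N, C)
        grad[torch.arange(N, device=tgt.device), tgt] = -grad_out / N
        return grad, None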

    res = out.to(inp.dtype)
elif reduction == 1:
    ctx.total_weight = sum(w_tgt).item()
    res = sum(out).to(inp.dtype) / ctx.total_weight
Collaborator

Suggest fusing the sum into the forward kernel, referencing cross_entropy_loss.
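
A rough sketch of the fused-sum idea (my illustration, not the merged kernel): instead of writing per-sample losses and summing them on the host, each program block accumulates its partial loss and weight sums into global accumulators with atomic adds, directly in the forward pass. The layout and names below are assumptions.

import triton
import triton.language as tl

@triton.jit
def nll_loss_fwd_fused_sum(logp_ptr, tgt_ptr, acc_ptr, N, C, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < N
    tgt = tl.load(tgt_ptr + offs, mask=mask, other=0)
    # logp is (N, C) row-major; pick logp[i, tgt[i]] for each sample in the block
    logp = tl.load(logp_ptr + offs * C + tgt, mask=mask, other=0.0)
    # acc_ptr[0] accumulates the loss sum, acc_ptr[1] the weight sum (unit weights here)
    tl.atomic_add(acc_ptr, tl.sum(-logp, axis=0))
    tl.atomic_add(acc_ptr + 1, tl.sum(mask.to(tl.float32), axis=0))

The mean reduction is then acc[0] / acc[1], computed either on the host or, as in the later revision of this PR, by a final masked store inside the kernel.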

@StrongSpoon
Collaborator

performance after optimization:
[screenshot: performance results after optimization, 2025-01-06 14:26:09]



@libentry()
@triton.autotune(
configs=[triton.Config({"BLOCK_N": n}, num_warps=4) for n in [256, 512, 1024]],
configs=[triton.Config({"BLOCK_N": n}, num_warps=4) for n in [1, 16, 256]],
Contributor

Why does BLOCK_N vary so much?

Contributor

I think BLOCK = 128 is a good pick.

Collaborator

what about [1, 4, 32, 128]?

for c in [256, 512, 1024]
for d in [1, 4, 16]
],
configs=[triton.Config({"BLOCK_D": d}, num_warps=4) for d in [1, 4, 16]],
Contributor

Considering this is a gather/scatter-like kernel, BLOCK = 128 should be good enough.
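
A sketch of that suggestion (my illustration with assumed kernel names, not the merged code): for a gather/scatter-like kernel, a single fixed block size can replace a wide autotune sweep.

import triton
import triton.language as tl

@triton.autotune(
    configs=[triton.Config({"BLOCK_N": 128}, num_warps=4)],  # one fixed block size
    key=["N"],
)
@triton.jit
def gather_like_kernel(src_ptr, idx_ptr, dst_ptr, N, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < N
    idx = tl.load(idx_ptr + offs, mask=mask, other=0)
    val = tl.load(src_ptr + idx, mask=mask, other=0.0)
    tl.store(dst_ptr + offs, val, mask=mask)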

tongxin
tongxin previously approved these changes Jan 9, 2025
Contributor

@tongxin tongxin left a comment

LGTM

@StrongSpoon
Collaborator

the latest performance:
[screenshot: latest performance results, 2025-01-15 09:43:19]

Comment on lines 49 to 55
tl.atomic_add(out_ptr + 2, 1, sem="release")  # counter
counter = tl.load(out_ptr + 2)
total_out = tl.load(out_ptr)
total_wgt = tl.load(out_ptr + 1)
tl.store(
    out_ptr + 3, total_out / total_wgt, mask=(counter == tl.num_programs(0))
)
Contributor

I think it's safer to use a stronger memory order for the counter update in line 49. Then we're ensured that only one CTA does the rest.

Collaborator

Release is enough here. It's safe even if more than one CTA satisfies the condition.
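
For context, a sketch of the variant where exactly one CTA finalizes the reduction, using the old value returned by the atomic itself (an illustration of the pattern under discussion, not the merged kernel; whether release ordering suffices is precisely the point debated above):

import triton
import triton.language as tl

@triton.jit
def finalize_mean_kernel(acc_ptr):
    # assumed layout: acc[0] = total_out, acc[1] = total_wgt, acc[2] = counter, acc[3] = result
    prev = tl.atomic_add(acc_ptr + 2, 1, sem="acq_rel")   # returns the pre-increment value
    is_last = prev == tl.num_programs(0) - 1              # exactly one program sees this
    total_out = tl.load(acc_ptr)
    total_wgt = tl.load(acc_ptr + 1)
    tl.store(acc_ptr + 3, total_out / total_wgt, mask=is_last)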

Contributor

@tongxin tongxin left a comment

Looks good now.

@StrongSpoon StrongSpoon merged commit 08796d1 into master Jan 15, 2025
8 of 9 checks passed
@StrongSpoon StrongSpoon deleted the nll_loss branch January 15, 2025 09:33