
Validating CrossEntropyLoss Performance #278

Closed · kevinstephano opened this issue May 4, 2023 · 3 comments

kevinstephano (Collaborator) commented May 4, 2023

I made this code snippet to show the performance of CrossEntropyLoss.

import torch

class MyLoss(torch.nn.Module):
    def __init__(self):
        super(MyLoss, self).__init__()
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        out = self.loss(inputs, targets)
        return out

inputs = torch.randn(8192, 32768, device='cuda')
targets = torch.randint(32767, (8192,), device='cuda')

model = torch.compile(MyLoss())

for _ in range(5):
    out = model(inputs, targets)

Test command:

nsys nvprof --print-gpu-trace python my_loss.py

Sample output on A100:

Tensor Sizes:
inputs = [8192, 32768]
targets = [8192]

Kernel1: 1.222ms
Kernel2: 37.1us

4202242471        1222152    1059  8192     1     1   256     1     1       39         0.000         0.001                                                     NVIDIA A100 80GB PCIe (0)    1     7  triton__0d1d2d3d4d                                                                                  
4203465679          37121    1072     1     1     1   256     1     1      184         0.000         0.008                                                     NVIDIA A100 80GB PCIe (0)    1     7  triton__0d1d2d3d4d56d
csarofeen (Collaborator) commented
What's the effective bandwidth of the kernels?

naoyam (Collaborator) commented May 8, 2023

Kernel1 is around 880 GB/s. The other one is just a few GB/s and is negligible.
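
For reference, a rough back-of-envelope check of that figure (my assumption, not stated in the trace: Kernel1's traffic is dominated by a single read of the fp32 inputs tensor):

# Back-of-envelope effective bandwidth for Kernel1, assuming its traffic
# is dominated by one full read of the fp32 inputs tensor.
bytes_moved = 8192 * 32768 * 4            # inputs: [8192, 32768], float32
kernel_time_s = 1.222e-3                  # 1.222 ms from the trace above
print(bytes_moved / kernel_time_s / 1e9)  # ~879 GB/s, consistent with ~880 GB/s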

naoyam added a commit that referenced this issue May 12, 2023
For #278, our schedulers currently can't fuse these ops into a single
kernel, so segmentation needs to happen. Currently, the fusion is
segmented between softmax and take_along_axis, simply because of the
ordering of the segmenter. However, we want the take_along_axis op to be
fused with the preceding softmax, since the temporary output from the
first segment would then be much smaller, reducing gmem access overhead.
In the case of `[8192, 32768]`, the (logical) I/O cost of the intermediate
would be `1 / 32768` of what it is now.

This PR introduces a simple mechanism to allow preferred fusions in the
segmentation steps. Currently, the only preference is to fuse select-like
ops with their producers. "Select-like" here also includes index_select
and torch_gather to size-one domains. In those ops, the consumer tensor
is guaranteed to be no larger than the lookup tensor, so it makes sense
to fuse them with their producers.

Currently, it's only tested with the cross-entropy loss case. The
overall segmentation algorithm would need to go through significant
refactoring, so I don't think making this interface super robust is
worth doing at this moment, and it's likely to be redesigned. For now,
this is very important for the cross-entropy performance.
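
As an aside (not part of the commit message above): the segmentation being discussed maps onto the usual decomposition of cross-entropy into a log-softmax followed by a gather (take_along_axis) along the class dimension. A minimal PyTorch sketch of that decomposition, using the same [8192, 32768] shapes as the repro, shows why the intermediate between the two segments is so much larger than the gathered result; it is only an illustration, not necessarily the exact op sequence nvFuser produces:

import torch

inputs = torch.randn(8192, 32768, device='cuda')
targets = torch.randint(32767, (8192,), device='cuda')

# Segment 1: log-softmax over the class dimension.
# Its output is [8192, 32768] -- as large as the input itself.
log_probs = torch.log_softmax(inputs, dim=-1)

# Segment 2: gather (take_along_axis) of the target class per row, then
# the mean reduction. Its output is only [8192, 1], i.e. 1/32768 of the
# intermediate, which is why fusing the gather into the first segment
# shrinks what has to round-trip through gmem.
picked = torch.take_along_dim(log_probs, targets.unsqueeze(-1), dim=-1)
loss = -picked.mean()

# Matches torch.nn.functional.cross_entropy(inputs, targets) up to
# floating-point error.
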
kevinstephano (Collaborator, Author) commented
Closing, old.
