[Operator] Init NLL_LOSS #269

Merged 17 commits on Jan 15, 2025
16 changes: 16 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -114,6 +114,15 @@ def cross_entropy_loss_input_fn(shape, cur_dtype, device):
}


def nll_loss_input_fn(shape, cur_dtype, device):
inp = generate_tensor_input(shape, cur_dtype, device)
target = torch.randint(0, shape[-1], (shape[0],), device=device)
yield inp, target
if Config.bench_level == BenchLevel.COMPREHENSIVE:
weight = torch.randn(shape[-1], dtype=cur_dtype, device=device)
yield inp, target, {"weight": weight, "ignore_index": 1, "reduction": "none"}


def cumsum_input_fn(shape, cur_dtype, device):
inp = generate_tensor_input(shape, cur_dtype, device)
yield inp, 1
@@ -162,6 +171,13 @@ def cumsum_input_fn(shape, cur_dtype, device):
),
],
),
pytest.param(
"nll_loss",
torch.nn.functional.nll_loss,
nll_loss_input_fn,
FLOAT_DTYPES,
marks=pytest.mark.NLLLoss,
),
],
)
def test_generic_reduction_benchmark(op_name, torch_op, input_fn, dtypes):
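For reference, nll_loss_input_fn yields either (inp, target) or (inp, target, kwargs_dict), mirroring the keyword arguments of torch.nn.functional.nll_loss. A minimal standalone sketch of the equivalent direct calls (shapes and device are illustrative, not taken from the benchmark config):

import torch
import torch.nn.functional as F

# Hypothetical shapes; the benchmark sweeps these via its own shape list.
inp = torch.randn(4096, 1000, device="cuda").log_softmax(dim=-1)
target = torch.randint(0, 1000, (4096,), device="cuda")

# Default-level case: positional (inp, target); reduction defaults to "mean".
loss = F.nll_loss(inp, target)

# COMPREHENSIVE-level case: same extra kwargs as the dict yielded above.
weight = torch.randn(1000, device="cuda")
per_sample = F.nll_loss(inp, target, weight=weight, ignore_index=1, reduction="none")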
4 changes: 4 additions & 0 deletions src/flag_gems/__init__.py
@@ -163,6 +163,10 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
("log_softmax.int", log_softmax, Autograd.enable),
("outer", outer, Autograd.enable),
("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
("nll_loss_forward", nll_loss_forward, Autograd.disable),
("nll_loss_backward", nll_loss_backward, Autograd.disable),
("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable),
("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable),
("scatter.src", scatter, Autograd.disable),
("scatter.reduce", scatter, Autograd.disable),
("gather", gather, Autograd.disable),
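These four entries register the new Triton kernels against the corresponding ATen ops with Autograd.disable, i.e. at the backend level rather than as an autograd.Function; PyTorch's built-in derivative for nll_loss then routes the backward pass through the registered nll_loss_backward/nll_loss2d_backward kernels. A rough usage sketch (device string is illustrative):

import torch
import torch.nn.functional as F
import flag_gems

flag_gems.enable()  # installs the registrations listed above

logp = torch.randn(32, 10, device="cuda", requires_grad=True).log_softmax(dim=-1)
target = torch.randint(0, 10, (32,), device="cuda")

loss = F.nll_loss(logp, target)   # lowers to aten::nll_loss_forward
loss.backward()                   # backward lowers to aten::nll_loss_backward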
10 changes: 10 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -76,6 +76,12 @@
from .mv import mv
from .ne import ne, ne_scalar
from .neg import neg
from .nllloss import (
nll_loss2d_backward,
nll_loss2d_forward,
nll_loss_backward,
nll_loss_forward,
)
from .nonzero import nonzero
from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor
from .ones import ones
@@ -289,4 +295,8 @@
"logical_xor",
"logical_not",
"sort",
"nll_loss_forward",
"nll_loss_backward",
"nll_loss2d_forward",
"nll_loss2d_backward",
]
8 changes: 1 addition & 7 deletions src/flag_gems/ops/cross_entropy_loss.py
@@ -462,19 +462,13 @@ def celoss_indices_smooth_bwd(


@libentry()
-@triton.autotune(
-    configs=runtime.get_tuned_config("cross_entropy_loss_sum_and_scale"),
-    key=[
-        "N",
-    ],
-)
@triton.jit
def sum_and_scale(
inp_ptr,
out_ptr,
N,
scalebyw: tl.constexpr,
-    BLOCK_N: tl.constexpr,
+    BLOCK_N: tl.constexpr = 128,
scale=1.0,
mean_num=None,
):
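This hunk drops the @triton.autotune wrapper for sum_and_scale and instead gives BLOCK_N a tl.constexpr default of 128. A hypothetical standalone kernel (not the real sum_and_scale body) showing the same pattern, i.e. a defaulted constexpr block size that callers may omit:

import torch
import triton
import triton.language as tl

@triton.jit
def sum_and_scale_sketch(inp_ptr, out_ptr, N, scale, BLOCK_N: tl.constexpr = 128):
    # Single-program reduction: sum N elements in BLOCK_N-sized chunks, then scale.
    offs = tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for start in range(0, N, BLOCK_N):
        mask = start + offs < N
        acc += tl.load(inp_ptr + start + offs, mask=mask, other=0.0)
    tl.store(out_ptr, tl.sum(acc) * scale)

x = torch.randn(1000, device="cuda")
out = torch.empty((), device="cuda", dtype=torch.float32)
# BLOCK_N is omitted at the call site and falls back to its default of 128.
sum_and_scale_sketch[(1,)](x, out, x.numel(), 0.001)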
16 changes: 10 additions & 6 deletions src/flag_gems/ops/embedding.py
@@ -148,9 +148,11 @@ def backward(ctx, grad_outputs):
grad_inputs = torch.zeros(
(ctx.num_weights, grad_outputs.shape[-1]),
device=grad_outputs.device,
-            dtype=torch.float32
-            if grad_outputs.dtype is torch.bfloat16
-            else grad_outputs.dtype,
+            dtype=(
+                torch.float32
+                if grad_outputs.dtype is torch.bfloat16
+                else grad_outputs.dtype
+            ),
)

if ctx.scale_grad_by_freq:
@@ -191,9 +193,11 @@ def backward(ctx, grad_outputs):
grad_inputs, indice_freq, ctx.num_weights, ctx.N, BLOCK_SIZE
)
return (
-            grad_inputs.to(torch.bfloat16)
-            if grad_outputs.dtype is torch.bfloat16
-            else grad_inputs,
+            (
+                grad_inputs.to(torch.bfloat16)
+                if grad_outputs.dtype is torch.bfloat16
+                else grad_inputs
+            ),
None,
None,
None,
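Both embedding.py hunks are formatting-only (the conditional expressions are wrapped in parentheses); the behavior they preserve is: accumulate the embedding gradient in float32 when the upstream gradient is bfloat16, then cast back to bfloat16 on return. A minimal standalone sketch of that pattern (names here are illustrative):

import torch

def alloc_and_finalize_grad(grad_outputs: torch.Tensor, num_weights: int) -> torch.Tensor:
    # Accumulate in float32 for bfloat16 upstream gradients to limit rounding
    # error from repeated additions, then cast back at the end.
    acc_dtype = (
        torch.float32 if grad_outputs.dtype is torch.bfloat16 else grad_outputs.dtype
    )
    grad_inputs = torch.zeros(
        (num_weights, grad_outputs.shape[-1]),
        device=grad_outputs.device,
        dtype=acc_dtype,
    )
    # ... scatter/accumulate rows of grad_outputs into grad_inputs here ...
    return (
        grad_inputs.to(torch.bfloat16)
        if grad_outputs.dtype is torch.bfloat16
        else grad_inputs
    )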