[Operator] Init NLL_LOSS (#269)
* [Operator] Init NLL_LOSS

* [Bugfix] Fix NLLLoss accuracy test to put gradient on the same device

* [Chore] Apply minor modifications to NLLLoss

* [Chore] Register NLL_LOSS

* [Chore] Change register name of NLL_LOSS

* [Operator] Optimize nll_loss and achieve

* [Operator] update tiling config

* [Operator] set default tiling size for forward function

* [Operator] set block_size as 128 and simplify the code

* [Format] reformat

* [Format] reformat arg processing in libentry

* [Operator] fix bug in unit test of nll_loss

* [Operator] register operator as nll_loss_nd to support 3d tensor

* [Operator] reimplement nll_loss

* [Operator] fuse nll_loss forward kernels into one

* [Operator] add assertion and modify condition expression

* [Operator] optimize

---------

Co-authored-by: strongspoon <[email protected]>
GwokHiujin and StrongSpoon authored Jan 15, 2025
1 parent 2c77d4e commit 08796d1
Showing 10 changed files with 538 additions and 56 deletions.
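For context: nll_loss computes the negative log likelihood over log-probabilities, so composing it with log_softmax reproduces cross_entropy, which this repo already provides. A minimal reference sketch, with illustrative shapes:

import torch
import torch.nn.functional as F

x = torch.randn(8, 10)                    # raw logits, shape (N, C)
target = torch.randint(0, 10, (8,))       # class indices, shape (N,)

nll = F.nll_loss(F.log_softmax(x, dim=-1), target)
ce = F.cross_entropy(x, target)
assert torch.allclose(nll, ce)            # cross_entropy == log_softmax + nll_loss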
16 changes: 16 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -114,6 +114,15 @@ def cross_entropy_loss_input_fn(shape, cur_dtype, device):
}


def nll_loss_input_fn(shape, cur_dtype, device):
    inp = generate_tensor_input(shape, cur_dtype, device)
    target = torch.randint(0, shape[-1], (shape[0],), device=device)
    yield inp, target
    if Config.bench_level == BenchLevel.COMPREHENSIVE:
        weight = torch.randn(shape[-1], dtype=cur_dtype, device=device)
        yield inp, target, {"weight": weight, "ignore_index": 1, "reduction": "none"}


def cumsum_input_fn(shape, cur_dtype, device):
    inp = generate_tensor_input(shape, cur_dtype, device)
    yield inp, 1
@@ -162,6 +171,13 @@ def cumsum_input_fn(shape, cur_dtype, device):
),
],
),
pytest.param(
    "nll_loss",
    torch.nn.functional.nll_loss,
    nll_loss_input_fn,
    FLOAT_DTYPES,
    marks=pytest.mark.NLLLoss,
),
],
)
def test_generic_reduction_benchmark(op_name, torch_op, input_fn, dtypes):
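As a rough illustration of what nll_loss_input_fn above yields (shapes and values here are made up, not the benchmark's actual configurations): the default case passes the input and targets, and the comprehensive case adds weight, ignore_index, and reduction keyword arguments:

import torch
import torch.nn.functional as F

N, C = 32, 1024
inp = torch.randn(N, C).log_softmax(dim=-1)    # log-probabilities, shape (N, C)
target = torch.randint(0, C, (N,))             # class indices, shape (N,)
weight = torch.randn(C)                        # per-class weights

loss_mean = F.nll_loss(inp, target)            # default: mean reduction, scalar
loss_none = F.nll_loss(
    inp, target, weight=weight, ignore_index=1, reduction="none"
)                                              # per-sample losses, shape (N,)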
4 changes: 4 additions & 0 deletions src/flag_gems/__init__.py
@@ -163,6 +163,10 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
("log_softmax.int", log_softmax, Autograd.enable),
("outer", outer, Autograd.enable),
("cross_entropy_loss", cross_entropy_loss, Autograd.enable),
("nll_loss_forward", nll_loss_forward, Autograd.disable),
("nll_loss_backward", nll_loss_backward, Autograd.disable),
("nll_loss2d_forward", nll_loss2d_forward, Autograd.disable),
("nll_loss2d_backward", nll_loss2d_backward, Autograd.disable),
("scatter.src", scatter, Autograd.disable),
("scatter.reduce", scatter, Autograd.disable),
("gather", gather, Autograd.disable),
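With these four entries registered, torch.nn.functional.nll_loss dispatches to the new Triton kernels once the library is enabled. A hedged usage sketch (flag_gems.enable() is taken from the def enable(...) context shown above; the CUDA device is an assumption):

import torch
import torch.nn.functional as F
import flag_gems

flag_gems.enable()  # registers nll_loss_forward/backward (and the 2d variants) with aten

log_probs = F.log_softmax(torch.randn(8, 10, device="cuda", requires_grad=True), dim=-1)
target = torch.randint(0, 10, (8,), device="cuda")

loss = F.nll_loss(log_probs, target)  # routed to the registered forward kernel
loss.backward()                       # routed to the registered backward kernel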
10 changes: 10 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -76,6 +76,12 @@
from .mv import mv
from .ne import ne, ne_scalar
from .neg import neg
from .nllloss import (
    nll_loss2d_backward,
    nll_loss2d_forward,
    nll_loss_backward,
    nll_loss_forward,
)
from .nonzero import nonzero
from .normal import normal_float_tensor, normal_tensor_float, normal_tensor_tensor
from .ones import ones
@@ -289,4 +295,8 @@
"logical_xor",
"logical_not",
"sort",
"nll_loss_forward",
"nll_loss_backward",
"nll_loss2d_forward",
"nll_loss2d_backward",
]
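The new functions are also importable directly from flag_gems.ops. The sketch below assumes they follow the aten nll_loss_forward signature (self, target, weight, reduction, ignore_index) and return (output, total_weight), since that is the overload they are registered against; treat the exact call as an assumption:

import torch
from flag_gems.ops import nll_loss_forward

log_probs = torch.randn(8, 10, device="cuda").log_softmax(dim=-1)
target = torch.randint(0, 10, (8,), device="cuda")

# reduction follows the aten convention: 0 = none, 1 = mean, 2 = sum
output, total_weight = nll_loss_forward(log_probs, target, None, 1, -100)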
8 changes: 1 addition & 7 deletions src/flag_gems/ops/cross_entropy_loss.py
@@ -462,19 +462,13 @@ def celoss_indices_smooth_bwd(


@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("cross_entropy_loss_sum_and_scale"),
    key=[
        "N",
    ],
)
@triton.jit
def sum_and_scale(
    inp_ptr,
    out_ptr,
    N,
    scalebyw: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_N: tl.constexpr = 128,
    scale=1.0,
    mean_num=None,
):
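The change above drops the autotuner for sum_and_scale and gives BLOCK_N a fixed default of 128. A standalone sketch of a Triton kernel with a defaulted tl.constexpr block size (the kernel and names here are illustrative, not the repository's):

import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK_N: tl.constexpr = 128):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x * scale, mask=mask)


x = torch.randn(1000, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 128),)
scale_kernel[grid](x, out, x.numel(), 0.5)  # BLOCK_N falls back to its default of 128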
16 changes: 10 additions & 6 deletions src/flag_gems/ops/embedding.py
@@ -148,9 +148,11 @@ def backward(ctx, grad_outputs):
grad_inputs = torch.zeros(
    (ctx.num_weights, grad_outputs.shape[-1]),
    device=grad_outputs.device,
    dtype=torch.float32
    if grad_outputs.dtype is torch.bfloat16
    else grad_outputs.dtype,
    dtype=(
        torch.float32
        if grad_outputs.dtype is torch.bfloat16
        else grad_outputs.dtype
    ),
)

if ctx.scale_grad_by_freq:
Expand Down Expand Up @@ -191,9 +193,11 @@ def backward(ctx, grad_outputs):
grad_inputs, indice_freq, ctx.num_weights, ctx.N, BLOCK_SIZE
)
return (
grad_inputs.to(torch.bfloat16)
if grad_outputs.dtype is torch.bfloat16
else grad_inputs,
(
grad_inputs.to(torch.bfloat16)
if grad_outputs.dtype is torch.bfloat16
else grad_inputs
),
None,
None,
None,
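The embedding changes appear to be formatting-only: the existing pattern of accumulating bfloat16 gradients in float32 and casting back on return is simply wrapped in parentheses. A minimal sketch of that pattern (the shapes and the index_add_ accumulation step are placeholders):

import torch

grad_outputs = torch.randn(4, 8, dtype=torch.bfloat16)
acc_dtype = torch.float32 if grad_outputs.dtype is torch.bfloat16 else grad_outputs.dtype

grad_inputs = torch.zeros(16, 8, device=grad_outputs.device, dtype=acc_dtype)
indices = torch.randint(0, 16, (4,))
grad_inputs.index_add_(0, indices, grad_outputs.to(acc_dtype))  # accumulate in float32

result = grad_inputs.to(torch.bfloat16) if grad_outputs.dtype is torch.bfloat16 else grad_inputs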
