From cd80d08e6a1b54019dbdee5dc82931f280812331 Mon Sep 17 00:00:00 2001
From: Shahar Elyashiv <73384589+shaharelys@users.noreply.github.com>
Date: Sun, 7 Apr 2024 16:53:35 +0300
Subject: [PATCH] Added support for `F.one_hot` (#128)

---
 thunder/executors/torchex.py |  2 ++
 thunder/tests/opinfos.py     | 27 +++++++++++++++++++++++++++
 thunder/torch/__init__.py    | 18 ++++++++++++++++++
 3 files changed, 47 insertions(+)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 7317b0587c..53f485c8b6 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1202,6 +1202,7 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
 dropout = _register_torch_operation("dropout", module=torch.nn.functional)
 embedding = _register_torch_operation("embedding", module=torch.nn.functional)
 embedding_backward = _register_torch_operation("torch.ops.aten.embedding_backward", like=ltorch.embedding_backward)
+one_hot = _register_torch_operation("one_hot", module=torch.nn.functional)
 group_norm = _register_torch_operation("group_norm", module=torch.nn.functional)
 interpolate = _register_torch_operation("interpolate", module=torch.nn.functional)
 linear = _register_torch_operation("linear", module=torch.nn.functional)
@@ -1447,6 +1448,7 @@ def _pad_prim_impl(
 _register_implementation(ltorch.dropout, dropout, checker=_always_executable)
 _register_implementation(ltorch.embedding, embedding, checker=_always_executable)
 _register_implementation(ltorch.embedding_backward, embedding_backward, checker=_always_executable)
+_register_implementation(ltorch.one_hot, one_hot, checker=_always_executable)
 _register_implementation(ltorch.group_norm, group_norm, checker=_always_executable)
 _register_implementation(ltorch.interpolate, interpolate, checker=_interpolate_checker)
 _register_implementation(ltorch.linear, linear, checker=_always_executable)
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index 3e9c6b68ca..e95b1e72a2 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -6152,6 +6152,33 @@ def sample_generator(op, device, dtype, requires_grad, **kwargs):
 nn_ops.append(max_pool3d_opinfo)
 
 
+def one_hot_sample_generator(op, device, dtype, requires_grad, **kwargs):
+    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    test_shapes = [
+        (0, 512),
+        (10,),
+        (5, 10),
+        (3, 5, 10),
+    ]
+
+    max_value = 9
+    for shape in test_shapes:
+        for num_classes in range(1, max_value + 1):
+            a = make(shape, low=0, high=num_classes - 1)  # use non-negative integers
+
+            yield SampleInput(a, num_classes=num_classes)
+
+
+one_hot_opinfo = OpInfo(
+    ltorch.one_hot,
+    sample_input_generator=one_hot_sample_generator,
+    torch_reference=torch.nn.functional.one_hot,
+    dtypes=(datatypes.int64,),  # akin to torch.long; F.one_hot expects a LongTensor input
+)
+nn_ops.append(one_hot_opinfo)
+
+
 def group_norm_sample_generator(op, device, dtype, requires_grad, **kwargs):
     # NOTE: we set low/high to -+ 1 to avoid numerical issues with reduced float types.
     make = partial(make_tensor, low=-1, high=+1, device=device, dtype=dtype, requires_grad=requires_grad)
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index a408dc768a..bde8c091ba 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3489,6 +3489,24 @@ def embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_fr
     return result
 
 
+@torchsymbol(torch.nn.functional.one_hot, id="torch.nn.functional.one_hot", is_method=False)
+def one_hot(a: TensorLike, /, num_classes: int) -> TensorLike:
+    # TODO: refactor when we're ready to support auto-inference for `num_classes = -1` using `.item()`
+    utils.check(
+        num_classes >= 1,
+        lambda: f"Currently supports only positive input for num_classes, got num_classes={num_classes}",
+        exception_type=NotImplementedError,
+    )
+    # TODO: would we want to implement this check in the future?
+    # utils.check(a.any() >= 0, lambda: f"input tensor should have non-negative values", exception_type=ValueError)
+
+    canvas = zeros(*a.shape, num_classes, device=a.device, dtype=dtypes.int64)
+    index = a.unsqueeze(-1)
+    src = ones_like(index, device=a.device, dtype=dtypes.int64)
+
+    return scatter_add(canvas, dim=-1, index=index, src=src)
+
+
 @torchsymbol(torch.group_norm, torch.nn.functional.group_norm, id="torch.nn.functional.group_norm", is_method=False)
 def group_norm(
     a: TensorProxy,
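
For reference, the `one_hot` decomposition added above builds its result by scattering ones into a
zero-filled canvas along a new trailing class dimension. A minimal sketch of the same steps in plain
PyTorch (illustrative only, not part of the diff; assumes only `torch` is available), which can be
used to sanity-check the logic against the reference `torch.nn.functional.one_hot`:

    import torch

    a = torch.tensor([0, 2, 1])  # LongTensor of class indices
    num_classes = 3

    canvas = torch.zeros(*a.shape, num_classes, dtype=torch.int64)  # zeros(..., num_classes)
    index = a.unsqueeze(-1)       # trailing dim holding each element's class index
    src = torch.ones_like(index)  # ones to scatter into the canvas
    out = canvas.scatter_add(-1, index, src)

    assert torch.equal(out, torch.nn.functional.one_hot(a, num_classes))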