Added support for F.one_hot (#128)
shaharelys authored Apr 7, 2024
1 parent aef1f4c commit cd80d08
Showing 3 changed files with 47 additions and 0 deletions.
2 changes: 2 additions & 0 deletions thunder/executors/torchex.py
@@ -1202,6 +1202,7 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
dropout = _register_torch_operation("dropout", module=torch.nn.functional)
embedding = _register_torch_operation("embedding", module=torch.nn.functional)
embedding_backward = _register_torch_operation("torch.ops.aten.embedding_backward", like=ltorch.embedding_backward)
one_hot = _register_torch_operation("one_hot", module=torch.nn.functional)
group_norm = _register_torch_operation("group_norm", module=torch.nn.functional)
interpolate = _register_torch_operation("interpolate", module=torch.nn.functional)
linear = _register_torch_operation("linear", module=torch.nn.functional)
@@ -1447,6 +1448,7 @@ def _pad_prim_impl(
_register_implementation(ltorch.dropout, dropout, checker=_always_executable)
_register_implementation(ltorch.embedding, embedding, checker=_always_executable)
_register_implementation(ltorch.embedding_backward, embedding_backward, checker=_always_executable)
_register_implementation(ltorch.one_hot, one_hot, checker=_always_executable)
_register_implementation(ltorch.group_norm, group_norm, checker=_always_executable)
_register_implementation(ltorch.interpolate, interpolate, checker=_interpolate_checker)
_register_implementation(ltorch.linear, linear, checker=_always_executable)
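With the two lines added above, the torch executor maps thunder's ltorch.one_hot symbol straight to torch.nn.functional.one_hot. A rough end-to-end sketch of what that enables (illustrative only, not part of the commit; it assumes thunder.jit is the public compilation entry point):

import torch
import torch.nn.functional as F
import thunder

def fn(labels: torch.Tensor) -> torch.Tensor:
    # traced as ltorch.one_hot; with the registration above the torch
    # executor runs it via torch.nn.functional.one_hot
    return F.one_hot(labels, num_classes=5)

jfn = thunder.jit(fn)
labels = torch.tensor([0, 2, 4, 1])
assert torch.equal(jfn(labels), F.one_hot(labels, num_classes=5))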
27 changes: 27 additions & 0 deletions thunder/tests/opinfos.py
@@ -6152,6 +6152,33 @@ def sample_generator(op, device, dtype, requires_grad, **kwargs):
nn_ops.append(max_pool3d_opinfo)


def one_hot_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    test_shapes = [
        (0, 512),
        (10,),
        (5, 10),
        (3, 5, 10),
    ]

    max_value = 9
    for shape in test_shapes:
        for num_classes in range(1, max_value + 1):
            a = make(shape, low=0, high=num_classes - 1)  # use non-negative integers

            yield SampleInput(a, num_classes=num_classes)


one_hot_opinfo = OpInfo(
    ltorch.one_hot,
    sample_input_generator=one_hot_sample_generator,
    torch_reference=torch.nn.functional.one_hot,
    dtypes=(datatypes.int64,),  # akin to torch.long; F.one_hot expects a LongTensor input
)
nn_ops.append(one_hot_opinfo)


def group_norm_sample_generator(op, device, dtype, requires_grad, **kwargs):
    # NOTE: we set low/high to -+ 1 to avoid numerical issues with reduced float types.
    make = partial(make_tensor, low=-1, high=+1, device=device, dtype=dtype, requires_grad=requires_grad)
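For illustration only (not part of the commit), the (input, num_classes) pairs the generator above yields can be checked directly against the torch reference wired into the OpInfo; the upper bound is clamped here so torch.randint stays valid when num_classes == 1:

import torch
import torch.nn.functional as F

for shape in [(0, 512), (10,), (5, 10), (3, 5, 10)]:
    for num_classes in range(1, 10):
        # non-negative class indices strictly below num_classes
        a = torch.randint(0, max(num_classes - 1, 1), shape, dtype=torch.int64)
        out = F.one_hot(a, num_classes=num_classes)
        assert out.shape == (*shape, num_classes)
        assert out.dtype == torch.int64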
18 changes: 18 additions & 0 deletions thunder/torch/__init__.py
@@ -3489,6 +3489,24 @@ def embedding_backward(grad, indices, num_weights, padding_idx, scale_grad_by_fr
    return result


@torchsymbol(torch.nn.functional.one_hot, id="torch.nn.functional.one_hot", is_method=False)
def one_hot(a: TensorLike, /, num_classes: int) -> TensorLike:
    # TODO: refactor when we're ready to support auto-inference for `num_classes = -1` using `.item()`
    utils.check(
        num_classes >= 1,
        lambda: f"Currently supports only positive input for num_classes, got num_classes={num_classes}",
        exception_type=NotImplementedError,
    )
    # TODO: would we want to implement this check in the future?
    # utils.check((a >= 0).all(), lambda: "input tensor should have non-negative values", exception_type=ValueError)

    canvas = zeros(*a.shape, num_classes, device=a.device, dtype=dtypes.int64)
    index = a.unsqueeze(-1)
    src = ones_like(index, device=a.device, dtype=dtypes.int64)

    return scatter_add(canvas, dim=-1, index=index, src=src)


@torchsymbol(torch.group_norm, torch.nn.functional.group_norm, id="torch.nn.functional.group_norm", is_method=False)
def group_norm(
    a: TensorProxy,
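The new symbol decomposes one_hot into primitives thunder already supports: allocate a zero canvas with an extra trailing class dimension, then scatter-add ones into it at the index positions. A plain-PyTorch sketch of the same construction (illustrative; the real symbol operates on thunder proxies, not torch tensors):

import torch
import torch.nn.functional as F

def one_hot_via_scatter(a: torch.Tensor, num_classes: int) -> torch.Tensor:
    # zero canvas of shape (*a.shape, num_classes)
    canvas = torch.zeros(*a.shape, num_classes, device=a.device, dtype=torch.int64)
    # per-element class index, aligned with the trailing class dimension
    index = a.unsqueeze(-1)
    src = torch.ones_like(index, dtype=torch.int64)
    # write a single 1 along the last dimension at each index
    return canvas.scatter_add(dim=-1, index=index, src=src)

a = torch.tensor([[0, 2], [1, 3]])
assert torch.equal(one_hot_via_scatter(a, num_classes=4), F.one_hot(a, num_classes=4))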
