Add support for torch.gather (#252)
Implements gather prim, adds support for torch.gather, and implements its grad transform.

Fixes #223.
rdspring1 authored Apr 24, 2024
1 parent 0e0f305 commit 34d21e1
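
For illustration, a minimal sketch of how the new support could be exercised end to end. This is not part of the change itself; it assumes thunder.jit as the compilation entry point and compares against eager torch.gather:

```python
import torch
import thunder

def fn(a, index):
    # torch.gather is now traced through the new gather prim
    return torch.gather(a, 1, index)

jfn = thunder.jit(fn)  # assumed entry point, not part of this commit

a = torch.randn(4, 5)
index = torch.tensor([[0, 2], [1, 1], [4, 0], [3, 2]])
assert torch.equal(jfn(a, index), torch.gather(a, 1, index))
```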
Showing 6 changed files with 102 additions and 1 deletion.
thunder/clang/__init__.py: 6 additions, 0 deletions
@@ -1047,6 +1047,12 @@ def take_along_axis(a: TensorProxy, /, indices: TensorProxy, dim: int) -> Tensor
    return prims.take_along_axis(a, indices, dim)


@clangop()
def gather(a: TensorProxy, /, indices: TensorProxy, dim: int) -> TensorProxy:
    dim = utils.canonicalize_dim(a.ndim, dim)
    return prims.gather(a, indices, dim)


@clangop()
def scatter_add(a: TensorProxy, /, indices: TensorProxy, value: TensorProxy, dim: int) -> TensorProxy:
    dim = utils.canonicalize_dim(a.ndim, dim)
thunder/core/prims.py: 24 additions, 0 deletions
@@ -236,6 +236,7 @@ class PrimIDs(Enum):
    ARGMIN = auto()
    TOPK = auto()
    # Scatter and gather prims (Experimental!)
    GATHER = auto()
    INDEX_ADD = auto()
    INDEX_PUT = auto()
    SCATTER_ADD = auto()
@@ -3082,6 +3083,29 @@ def take_along_axis_meta(a: TensorProxy, /, index: TensorProxy, dim: int) -> Ten
take_along_axis = make_prim(PrimIDs.TAKE_ALONG_AXIS, "take_along_axis", meta=take_along_axis_meta)


def gather_meta(a: TensorProxy, /, index: TensorProxy, dim: int) -> TensorProxy:
    utils.check_type(a, TensorProxy)
    utils.check_type(index, TensorProxy)
    utils.check_type(dim, int)
    utils.check_same_device(a, index)
    utils.check(utils.is_integer_dtype(index.dtype), lambda: f"index dtype={index.dtype} was not an integer dtype")
    utils.check(
        index.ndim == a.ndim, lambda: f"Expected index (rank={index.ndim}) to have the same rank as a (rank={a.ndim})"
    )
    utils.validate_idx(a.ndim, dim)

    for idx, l in enumerate(index.shape):
        if idx != dim:
            utils.check(
                index.shape[idx] <= a.shape[idx],
                lambda: f"Expected 'index' size on all dimensions to be <= 'a', except `dim`. Found dim {idx}, where 'index' has {index.shape[idx]} and 'a' has {a.shape[idx]}",
            )
    return TensorProxy(like=a, shape=index.shape)


gather = make_prim(PrimIDs.GATHER, "gather", meta=gather_meta)
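
For reference, the meta function above encodes the same shape contract as PyTorch's gather: index must have the same rank as a, may not exceed a's size on any dimension other than dim, and the output takes index's shape. A small PyTorch-only sketch of that contract (illustration, not part of the diff):

```python
import torch

a = torch.randn(4, 5, 3)
# Same rank as a; sizes <= a's on dims other than dim=1; values in [0, a.size(1))
index = torch.randint(0, 5, (3, 2, 3))
out = torch.gather(a, 1, index)
assert out.shape == index.shape  # output shape follows index, as in gather_meta
```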


def scatter_add_meta(a: TensorProxy, /, index: TensorProxy, value: TensorProxy, dim: int) -> TensorProxy:
    utils.check_type(a, TensorProxy)
    utils.check_type(index, TensorProxy)
thunder/core/transforms.py: 18 additions, 1 deletion
@@ -815,12 +815,29 @@ def _take_prim_grad(a: TensorProxy, index: TensorProxy, dim: int) -> TensorProxy
register_grad(pids.TAKE, _take_prim_grad)


@torchctx
def _gather_prim_grad(a: TensorProxy, index: TensorProxy, dim: int) -> TensorProxy:
    fwd = prims.gather(a, index, dim)

    g = get_grad(fwd)
    # NOTE Intentionally not calling zeros_like to avoid preserving TensorProxy a.
    # TODO Update to call ltorch.zeros
    zeros = prims.full(a.shape, fill_value=0, device=a.device, dtype=a.dtype)
    a_grad = prims.scatter_add(zeros, index, g, dim)
    put_grad(a, a_grad)

    return fwd


register_grad(pids.GATHER, _gather_prim_grad)
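
As a sanity check of the rule registered above (gather's backward scatter-adds the upstream gradient into a zeros tensor at the gathered positions), a small PyTorch-only sketch comparing against autograd (illustration, not part of the diff):

```python
import torch

a = torch.randn(4, 5, requires_grad=True)
index = torch.randint(0, 5, (4, 2))
torch.gather(a, 1, index).backward(torch.ones(4, 2))

# Mirror the grad transform: scatter-add the upstream gradient (ones here)
# into a zeros tensor along the same dim and index.
expected = torch.zeros(4, 5).scatter_add(1, index, torch.ones(4, 2))
assert torch.allclose(a.grad, expected)
```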


@torchctx
def _take_along_axis_prim_grad(a: TensorProxy, index: TensorProxy, dim: int) -> TensorProxy:
    fwd = prims.take_along_axis(a, index, dim)

    g = get_grad(fwd)
-    # NOTE Intentionally not calling zeros_like to avoid preserving a
+    # NOTE Intentionally not calling zeros_like to avoid preserving TensorProxy a.
    # TODO Update to call ltorch.zeros
    zeros = prims.full(a.shape, fill_value=0, device=a.device, dtype=a.dtype)
    a_grad = prims.scatter_add(zeros, index, g, dim)
thunder/executors/torchex.py: 13 additions, 0 deletions
@@ -1099,6 +1099,7 @@ def _topk_transform(
# Scatter and gather operations
#

gather = _register_torch_operation("gather")
index_add = _register_torch_operation("index_add")
index_put = _register_torch_operation("index_put")
scatter_add = _register_torch_operation("scatter_add")
@@ -1117,6 +1118,16 @@ def _index_put_prim_transform(
    return index_put(a, indices, values, accumulate)


@langctx(Languages.TORCH)
def _gather_prim_transform(a: TensorProxy, /, index: TensorProxy, dim: int) -> TensorProxy:
    return gather(a, dim, index)


@langctx(Languages.TORCH)
def _gather_transform(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike:
    return gather(a, dim, index)
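
Note the argument-order difference handled by these transforms: the Thunder prim is gather(a, index, dim), while torch.gather is gather(input, dim, index), so both transforms above pass dim before index when dispatching to the registered PyTorch operation.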


# NOTE torch.compile has a compilation issue with scatter add in bfloat16,
# hence the special case here.
# NOTE The scatter add transforms must set the torch language context explicitly so the .to() method
@@ -1152,6 +1163,7 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
    return take_along_dim(a, index, dim)


_register_implementation(prims.gather, checker=_always_executable, execution_transform=_gather_prim_transform)
_register_implementation(prims.index_add, checker=_always_executable, execution_transform=_index_add_prim_transform)
_register_implementation(prims.index_put, checker=_always_executable, execution_transform=_index_put_prim_transform)
_register_implementation(prims.scatter_add, checker=_always_executable, execution_transform=_scatter_add_prim_transform)
@@ -1160,6 +1172,7 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim:
    prims.take_along_axis, checker=_always_executable, execution_transform=_take_along_axis_prim_transform
)

_register_implementation(ltorch.gather, checker=_always_executable, execution_transform=_gather_transform)
_register_implementation(ltorch.index_add, index_add, checker=_always_executable)
_register_implementation(ltorch.index_put, index_put, checker=_always_executable)
_register_implementation(ltorch.index_select, index_select, checker=_always_executable)
thunder/tests/opinfos.py: 36 additions, 0 deletions
@@ -4421,6 +4421,42 @@ def take_along_axis_sample_generator(op, device, dtype, requires_grad, **kwargs)
shape_ops.append(take_along_axis_opinfo)


def gather_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # torch.gather expects index to be long but not int
    # Index is not differentiable! Marking requires_grad as False
    make_index = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)

    for shape_a, dim, shape_b in take_along_axis_cases:
        canonicalized_dim = dim if dim >= 0 else dim + len(shape_a)
        a = make(shape_a)
        b = make_index(shape_b, low=0, high=shape_a[dim])
        yield SampleInput(a, index=b, dim=dim)

    # Note that gather doesn't have the broadcast requirement, it only requires
    # 1. a.shape[i] >= index.shape[i] for i != dim
    #
    # a.shape, dim, index.shape
    scatter_add_cases = (
        ((4, 5, 3), 0, (3, 2, 3)),
        ((4, 5, 3), 1, (3, 5, 2)),
        ((4, 5, 3), 2, (3, 2, 8)),
    )
    for shape_a, dim, shape_b in scatter_add_cases:
        a = make(shape_a)
        b = make_index(shape_b, low=0, high=shape_a[dim])
        yield SampleInput(a, index=b, dim=dim)


gather_opinfo = OpInfo(
    ltorch.gather,
    supports_grad=True,
    sample_input_generator=gather_sample_generator,
    torch_reference=torch.gather,
)
shape_ops.append(gather_opinfo)


def scatter_add_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # torch.scatter_add expects index to be long but not int
thunder/torch/__init__.py: 5 additions, 0 deletions
@@ -1965,6 +1965,11 @@ def index_select(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike:
    return clang.take(a, index, dim)


@torchsymbol(torch.gather)
def gather(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike:
    return clang.gather(a, indices=index, dim=dim)


# NOTE PyTorch's scatter_add has a parameter named 'src', not 'source'
@torchsymbol(torch.scatter_add)
def scatter_add(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike:
