add Tensor.scatter (tinygrad#7737)
* working I think

* where are my onnx scatter tests??

* forward_only for now

* try if nan hack fix NV

* looks like issue is different... CUDA WHY

* oops that was wrong. Try if this fixes CUDA

* simpler multiply

* actually finish this up tmrw morning :x

* fix tests?

* improve tests

* improve test and implementation

* fix ruff

* complete but lots of expected failure...

* reviewed tests

* add onnx tests

* is this a processing op?

* add return type to indicate that it's not in-place

* final cleanups

* use or and improve tests a little

* add masked_index_select

* call it masked_setitem instead

* try

* FIXED

---------

Co-authored-by: chenyu <[email protected]>
geohotstan and chenyuxyz authored Nov 27, 2024
1 parent 38f34ca commit cea5853
Showing 5 changed files with 114 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/tensor/ops.md
@@ -31,6 +31,7 @@
::: tinygrad.Tensor.triu
::: tinygrad.Tensor.tril
::: tinygrad.Tensor.interpolate
::: tinygrad.Tensor.scatter

## Neural Network (functional)

5 changes: 5 additions & 0 deletions extra/onnx_ops.py
@@ -347,7 +347,12 @@ def Gather(x: Tensor, indices: Tensor, axis=0):
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
# NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated

def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Optional[str]=None):
if reduction in {"min", "max"}: raise NotImplementedError("min and max reduction not supported")
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.scatter(axis, indices, updates, reduction)
def GatherElements(x: Tensor, indices: Tensor, axis):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(axis, indices)
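For context on the shared indexing trick: both `ScatterElements` and `GatherElements` normalize negative ONNX indices by adding the axis size to negative entries. A minimal NumPy sketch of that mapping (illustrative only, not part of the diff):

```python
import numpy as np

# ONNX allows indices in [-s, s-1] along the axis; adding the axis size s to the
# negative entries maps them into [0, s-1], mirroring
# `(indices < 0).where(x.shape[axis], 0) + indices` above.
indices = np.array([0, -1, 2, -3])
axis_size = 5
normalized = np.where(indices < 0, axis_size, 0) + indices
print(normalized)  # [0 4 2 2]
```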
4 changes: 3 additions & 1 deletion test/external/external_test_onnx_backend.py
@@ -142,7 +142,7 @@ def supports_device(cls, device: str) -> bool:
backend_test.exclude('test_basic_deform_conv_*')
backend_test.exclude('test_deform_conv_*')
backend_test.exclude('test_lppool_*')
backend_test.exclude('test_scan*')
backend_test.exclude('test_scan_*')
backend_test.exclude('test_split_to_sequence_*')
backend_test.exclude('test_resize_downsample_scales_cubic_*') # unsure how to implement cubic
backend_test.exclude('test_resize_downsample_sizes_cubic_*') # unsure how to implement cubic
@@ -157,6 +157,8 @@ def supports_device(cls, device: str) -> bool:
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_cpu') # bad data type string
backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad data type string
backend_test.exclude('test_group_normalization_*') # numerical inaccuracy problem. Current Group Normalization OP fails test
backend_test.exclude('test_scatter_elements_with_reduction_min_cpu') # min not yet supported
backend_test.exclude('test_scatter_elements_with_reduction_max_cpu') # max not yet supported

if Device.DEFAULT in ['GPU', 'METAL']:
backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu')
66 changes: 66 additions & 0 deletions test/test_ops.py
@@ -2316,6 +2316,72 @@ def test_gather(self):
lambda x: x.gather(dim=0, index=Tensor([2, 1, 0, 1, 2])),
vals=[[1., 2., 3.]])

def test_scatter(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src),
lambda x,src: x.scatter(dim=dim, index=a, src=src), forward_only=True)

helper_test_op([(3,4,5), (3,4,5)], lambda x,src: x.scatter(dim=1, index=b, src=src),
lambda x,src: x.scatter(dim=1, index=a, src=src), forward_only=True)
helper_test_op([(10,3,10), (10,10,10)], lambda x,src: x.scatter(dim=1, index=b, src=src),
lambda x,src: x.scatter(dim=1, index=a, src=src), forward_only=True)
self.helper_test_exception([(2,3,10), (10,10,10)], lambda x,src: x.scatter(dim=1, index=b, src=src),
lambda x,src: x.scatter(dim=1, index=a, src=src), expected=(RuntimeError, AssertionError))
self.helper_test_exception([(10,3,10), (10,3,10)], lambda x,src: x.scatter(dim=1, index=b, src=src),
lambda x,src: x.scatter(dim=1, index=a, src=src), expected=(RuntimeError, AssertionError))
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=3), lambda x: x.scatter(dim=1, index=a, src=3), forward_only=True)
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf")),
lambda x: x.scatter(dim=1, index=a, src=float("inf")), forward_only=True)

# overlapping indices with 0s
b = torch.tensor([0,0], requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op(None,
lambda x,src: x.scatter(0, b, src),
lambda x,src: x.scatter(0, a, src), forward_only=True,
vals=[[1.,2.,3.,4.], [1.,0.]])

def test_scatter_add(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="add"),
lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="add"), forward_only=True)

b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="add"),
lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="add"), forward_only=True)

# TODO: fails for webgpu
if Device.DEFAULT != "WEBGPU":
helper_test_op([(4,5,6)],
lambda x: x.scatter(1, b, float("nan"), reduce="add"),
lambda x: x.scatter(1, a, float("nan"), reduce="add"), forward_only=True,)

def test_scatter_mul(self):
b = torch.randint(3, size=[3,4,5], dtype=torch.int64, requires_grad=False)
a = Tensor(b.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False)
for dim in (0,1,2,-1,-2,-3):
helper_test_op([(4,5,6), (4,5,6)], lambda x,src: x.scatter(dim=dim, index=b, src=src, reduce="multiply"),
lambda x,src: x.scatter(dim=dim, index=a, src=src, reduce="multiply"), forward_only=True)

helper_test_op([(4,5,6)], lambda x: x.scatter(dim=1, index=b, value=float("inf"), reduce="multiply"),
lambda x: x.scatter(dim=1, index=a, src=float("inf"), reduce="multiply"), forward_only=True)

# TODO: fails for webgpu
if Device.DEFAULT != "WEBGPU":
helper_test_op([(4,5,6)],
lambda x: x.scatter(1, b, float("nan"), reduce="multiply"),
lambda x: x.scatter(1, a, float("nan"), reduce="multiply"), forward_only=True,)

x = Tensor.zeros([4,5,6]).float()
y = torch.zeros([4,5,6]).float()
helper_test_op([(4,5,6)], lambda src: y.scatter(dim=1, index=b, src=src, reduce="multiply"),
lambda src: x.scatter(dim=1, index=a, src=src, reduce="multiply"), forward_only=True)

def test_scaled_product_attention(self):
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)],
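The overlapping-indices case at the end of `test_scatter` relies on the last write to a repeated index winning (matching the `_masked_setitem` behaviour in the tensor.py diff below). A minimal sketch of what that case checks, assuming the `Tensor.scatter` API added in this commit:

```python
from tinygrad import Tensor

x = Tensor([1., 2., 3., 4.])
idx = Tensor([0, 0])   # both updates target position 0
src = Tensor([1., 0.])
# the later update (0.) overwrites the earlier one (1.); untouched positions keep x's values
print(x.scatter(0, idx, src).numpy())  # [0. 2. 3. 4.]
```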
48 changes: 39 additions & 9 deletions tinygrad/tensor.py
@@ -97,6 +97,15 @@ def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))

def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:Tuple[int, ...]):
# apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
values = values * mask
for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
# remove extra dims from reduce
for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
# select from values for each True element in mask else select from self
return mask.where(values, target)
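For intuition, the `functools.reduce` in `_masked_setitem` folds per-slice (mask, values) pairs so that where two slices both claim a position, the later slice's value survives. A standalone NumPy sketch of that combine step (illustrative only; in the real helper the values are pre-masked to zero where the mask is False):

```python
import numpy as np
from functools import reduce

mask   = np.array([[True, False], [True, True]])  # one row per slice along the reduced axis
values = np.array([[1., 2.], [3., 4.]])

def combine(acc, nxt):
    acc_mask, acc_vals = acc
    nxt_mask, nxt_vals = nxt
    # OR the masks; where the newer slice is active, its value replaces the accumulated one
    return acc_mask | nxt_mask, np.where(nxt_mask, nxt_vals, acc_vals)

folded_mask, folded_vals = reduce(combine, zip(mask, values))
print(folded_mask)  # [ True  True]
print(folded_vals)  # [3. 4.]  -- column 0 was claimed twice; the later value (3.) wins
```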

ReductionStr = Literal["mean", "sum", "none"]

class Tensor(SimpleMathTrait):
@@ -1199,15 +1208,8 @@ def calc_dim(tensor_dim:int) -> int:
vb = v.cast(self.dtype)._broadcast_to(_broadcast_shape(ret.shape, v.shape))
# add back reduced dims from sum
for dim in sum_axis: vb = vb.unsqueeze(dim)
# axis to be reduced to match self.shape
axis = tuple(range(first_dim, first_dim + len(big_shape)))
# apply mask to vb(broadcasted) and reduce such that if mask contains repeated indices the last one remains
vb = vb * mask
for dim in axis: mask, vb = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), vb.split(1, dim)))
# remove extra dims from reduce
for dim in reversed(axis): mask, vb = mask.squeeze(dim), vb.squeeze(dim)
# select from vb for each True element in mask else select from self
ret = mask.where(vb, self)
# run _masked_setitem over the tuple of axes that are reduced to match self.shape
ret = _masked_setitem(self, vb, mask, tuple(range(first_dim, first_dim + len(big_shape))))

return ret

@@ -2331,6 +2333,34 @@ def interpolate(self, size:Tuple[int, ...], mode:str="linear", align_corners:bool
x = x.gather(i, index)
return x.cast(self.dtype)

def scatter(self, dim:int, index:Tensor, src:Union[Tensor, ConstType], reduce:Union[None, Literal['multiply'], Literal['add']] = None) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Applies an `add` or `multiply` reduction when `reduce` is given.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2], [3, 4]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.scatter(dim=1, index=Tensor([[0, 0], [1, 0]]), src=9).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.scatter(dim=1, index=Tensor([[0, 0], [1, 0]]), src=Tensor([[3, 3], [9, 9]]), reduce="add").numpy())
```
"""
index, dim = index.to(self.device), self._resolve_dim(dim)
src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.ndim must all be equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or se >= ind) and sr >= ind for d,(se,ind,sr) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
mask = (index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)).transpose(-1, dim)
src = src.unsqueeze(-1).expand((None,)*src.ndim + (self.shape[dim],)).transpose(-1, dim).shrink(tuple((0,s) for s in mask.shape))
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
return _masked_setitem(self, src, mask, (-1,))

# ***** unary ops *****

def logical_not(self):
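For readers new to the masked formulation in `Tensor.scatter`: the index tensor is compared against `arange(self.shape[dim])` to build a one-hot mask, and each reduction becomes a masked select followed by a sum or product over the update axis. A 1-D NumPy analogue of the `reduce="add"` path (a sketch, not the tinygrad kernel):

```python
import numpy as np

x = np.array([10., 20., 30.])
index = np.array([2, 0, 2])
src = np.array([1., 2., 3.])

# one-hot mask: mask[i, j] is True where index[i] == j (compare against arange)
mask = index[:, None] == np.arange(len(x))[None, :]
# keep each src value only in its target column, then reduce over the update axis
out = np.where(mask, src[:, None], 0.).sum(axis=0) + x
print(out)  # [12. 20. 34.]
```

The `reduce="multiply"` path is the same shape of computation with `where(mask, src, 1).prod(...)`, and the default (no reduce) path routes through `_masked_setitem`, so repeated indices keep the last write.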
