Skip to content

Commit

Permalink
Add slice parameter type checking to disallow Tensor usage for slices (
Browse files Browse the repository at this point in the history
…tinygrad#6967)

* add support for single el tensors for slices

* rm trailing spaces

* cleanup long lines

* remove tensor in slice support, add comprehensive err msg

* cleanup getitem, add slice type check

* Edit err message
  • Loading branch information
mnovosad1095 authored Oct 11, 2024
1 parent b0dd407 commit 8831c69
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,7 @@ def test_slice_errors(self):
with self.assertRaisesRegex(IndexError, "out of bounds"): a[1, -4]
with self.assertRaisesRegex(IndexError, "single ellipsis"): a[..., ...] # IndexError: only single ellipsis
with self.assertRaises(ValueError): a[::0, 1] # no 0 strides
with self.assertRaises(TypeError): a[:Tensor([3]), 1] # Tensor can't be used as a slice parameter
with self.assertRaises(IndexError): b[:] # slice cannot be applied to a 0-dim tensor

def test_slice_ellipsis(self):
Expand Down
3 changes: 3 additions & 0 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,9 @@ def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1)
for dim in type_dim[slice]:
if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step")
if not all(isinstance(x, (int, type(None))) for x in (index.start, index.stop, index.step)):
raise TypeError(f"Unsupported slice for dimension {dim}. Expected slice with integers or None, got slice("
f"{', '.join(type(x).__name__ for x in (index.start, index.stop, index.step))}).")
s, e, st = index.indices(self.shape[dim])
indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st)
# skip all Tensor dims for basic indexing
Expand Down

0 comments on commit 8831c69

Please sign in to comment.