Skip to content

Commit

Permalink
Tensor.roll touchup
Browse files Browse the repository at this point in the history
simplified a bit.
it might be able to write it with only movements, but the backward would contain a reduce.
  • Loading branch information
chenyuxyz committed Sep 6, 2024
1 parent 2e01efc commit 80b26ba
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,7 +1321,7 @@ def unflatten(self, dim:int, sizes:Tuple[int,...]):

def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ...]]) -> Tensor:
"""
Roll the tensor along specified dimension(s).
Rolls the tensor along specified dimension(s).
The rolling operation is circular, meaning that elements that go beyond the edge are wrapped around to the beginning of the dimension.
```python exec="true" source="above" session="tensor" result="python"
Expand All @@ -1331,13 +1331,11 @@ def roll(self, shifts:Union[int, Tuple[int, ...]], dims:Union[int, Tuple[int, ..
print(Tensor.rand(3, 4, 1).roll(shifts=-1, dims=0))
```
"""
dims, shifts = (dims,) if isinstance(dims, int) else dims, (shifts,) if isinstance(shifts, int) else shifts
dims = tuple(i % len(self.shape) for i in dims)
all_shifts = [shifts[dims.index(i)] % self.shape[i] if i in dims else 0 for i in range(len(self.shape))]
rolled = self
for i, shift in enumerate(all_shifts):
rolled = Tensor.cat(rolled[tuple(slice(None) if j != i else slice(-shift, None) for j in range(len(rolled.shape)))],
rolled[tuple(slice(None) if j != i else slice(None, -shift) for j in range(len(rolled.shape)))], dim=i)
dims, rolled = tuple(self._resolve_dim(d) for d in make_pair(dims, 1)), self
for dim, shift in zip(dims, make_pair(shifts, 1)):
shift = shift % self.shape[dim]
rolled = Tensor.cat(rolled[tuple(slice(None) if i != dim else slice(-shift, None) for i in range(rolled.ndim))],
rolled[tuple(slice(None) if i != dim else slice(None, -shift) for i in range(rolled.ndim))], dim=dim)
return rolled

# ***** reduce ops *****
Expand Down

0 comments on commit 80b26ba

Please sign in to comment.