Skip to content

Commit

Permalink
un1d -> unravel [pr] (tinygrad#8300)
Browse files Browse the repository at this point in the history
numpy/torch has a similar function called `unravel_index`
  • Loading branch information
chenyuxyz authored Dec 17, 2024
1 parent 66b92b6 commit a9f46eb
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tinygrad/shape/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple

return tuple(reversed(new_mask))

def un1d(shape:Tuple[sint, ...], offset:sint) -> List[sint]:
def unravel(shape:Tuple[sint, ...], offset:sint) -> List[sint]:
# find the position of offset on each dimension based on shape
# similar to unravel_index in numpy/torch
ret = []
for stride in strides_for_shape(shape):
ret.append(offset // stride if stride != 0 else 0)
Expand Down Expand Up @@ -171,12 +172,12 @@ def __add__(self, vm1:View) -> Optional[View]:
if not all_int(vm1.shape): return None

# Project vm1's offset and strides on to vm2.
origin = un1d(vm2.shape, vm1.offset)
origin = unravel(vm2.shape, vm1.offset)
terms: List[List[Tuple[int, sint]]] = [[] for _ in vm2.shape]
strides: List[sint] = [0] * len(vm1.shape)
for d1, st in enumerate(vm1.strides):
if st == 0: continue
for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))):
for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))):
if (s1 := s1 - o) == 0: continue
terms[d2].append((d1, s1))
strides[d1] += s1 * vm2.strides[d2]
Expand Down

0 comments on commit a9f46eb

Please sign in to comment.