diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 6eba8993c8e4..842a861b33b4 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -76,8 +76,9 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple return tuple(reversed(new_mask)) -def un1d(shape:Tuple[sint, ...], offset:sint) -> List[sint]: +def unravel(shape:Tuple[sint, ...], offset:sint) -> List[sint]: # find the position of offset on each dimension based on shape + # similar to unravel_index in numpy/torch ret = [] for stride in strides_for_shape(shape): ret.append(offset // stride if stride != 0 else 0) @@ -171,12 +172,12 @@ def __add__(self, vm1:View) -> Optional[View]: if not all_int(vm1.shape): return None # Project vm1's offset and strides on to vm2. - origin = un1d(vm2.shape, vm1.offset) + origin = unravel(vm2.shape, vm1.offset) terms: List[List[Tuple[int, sint]]] = [[] for _ in vm2.shape] strides: List[sint] = [0] * len(vm1.shape) for d1, st in enumerate(vm1.strides): if st == 0: continue - for d2, (o, s1) in enumerate(zip(origin, un1d(vm2.shape, vm1.offset + st))): + for d2, (o, s1) in enumerate(zip(origin, unravel(vm2.shape, vm1.offset + st))): if (s1 := s1 - o) == 0: continue terms[d2].append((d1, s1)) strides[d1] += s1 * vm2.strides[d2]