Skip to content

Commit

Permalink
minor changes to views add [pr]
Browse files Browse the repository at this point in the history
naming / style / comments before logic change
  • Loading branch information
chenyuxyz committed Dec 17, 2024
1 parent e373176 commit b0a4781
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions tinygrad/shape/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ def _reshape_mask(_mask:Optional[Tuple[Tuple[sint, sint], ...]], old_shape:Tuple

return tuple(reversed(new_mask))

def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
result = []
def un1d(shape:Tuple[sint, ...], offset:sint) -> List[sint]:
# find the position of offset on each dimension based on shape
ret = []
for stride in strides_for_shape(shape):
here = offs // stride if stride != 0 else 0
result.append(here)
offs -= here * stride
return result
ret.append(offset // stride if stride != 0 else 0)
offset -= ret[-1] * stride
return ret

@dataclass(frozen=True)
class View:
Expand Down Expand Up @@ -123,7 +123,7 @@ def create(shape:Tuple[sint, ...], strides:Optional[Tuple[sint, ...]]=None, offs
strides = canonicalize_strides(shape, strides) if strides else strides_for_shape(shape)
# canonicalize 0 in shape
if 0 in shape: return View(shape, (0,) * len(shape), offset=0, mask=None, contiguous=True)
# canonicalize empty mask
# canonicalize no-op mask
if mask is not None and all(m == (0,s) for m,s in zip(mask, shape)): mask = None
# if any dimension has size >1, but is masked such that only one index in the dimension is unmasked
# then its stride can also be set to 0, albeit with a corresponding adjustment required to the offset
Expand Down Expand Up @@ -167,12 +167,12 @@ def __add__(self, vm1:View) -> Optional[View]:
if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
if vm1.mask:
for b,e in vm1.mask:
if resolve(b >= e, False): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
if not resolve(b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

# Project vm1's offset and strides on to vm2.
origin = un1d(vm2.shape, vm1.offset)
terms: List[List[Tuple[int, sint]]] = [[] for _ in origin]
terms: List[List[Tuple[int, sint]]] = [[] for _ in vm2.shape]
strides: List[sint] = [0] * len(vm1.shape)
for d1, st in enumerate(vm1.strides):
if st == 0: continue
Expand All @@ -195,8 +195,7 @@ def __add__(self, vm1:View) -> Optional[View]:
merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
if resolve(merged_term != 0): return None
if (vm2_shape := tuple(s for s,_ in reversed(extents))) != vm2.shape:
reshaped_vm2 = vm2.reshape(vm2_shape)
if reshaped_vm2 is None: return None
if (reshaped_vm2 := vm2.reshape(vm2_shape)) is None: return None
if reshaped_vm2.shape != vm2.shape: return reshaped_vm2 + vm1

if vm2.mask:
Expand All @@ -212,7 +211,7 @@ def __add__(self, vm1:View) -> Optional[View]:
else: bad = True
continue
d1, s1 = term[0]
if not isinstance(s1, int) or not isinstance(newe[d1], int):
if not all_int([s1, newe[d1]]):
bad = True
continue
newb[d1] = max(newb[d1], math.ceil((b - o if s1 > 0 else e - o - 1) / s1))
Expand Down

0 comments on commit b0a4781

Please sign in to comment.