Skip to content

Commit

Permalink
use more resolve in View merge add [pr] (tinygrad#7055)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz authored Oct 14, 2024
1 parent 8428244 commit 0d2462c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tinygrad/shape/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __add__(self, vm1:View) -> Optional[View]:
if vm1.contiguous and vm1.size() == vm2.size() and (ret := vm2.reshape(vm1.shape)) is not None: return ret
if vm1.mask:
for b,e in vm1.mask:
if not (b < e): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
if resolve(b >= e, False): return View.create(vm1.shape, (0,) * len(vm1.shape), 0, ((0,0),) * len(vm1.shape))
return (merged := vm2 + vm1.shrink(vm1.mask)) and merged.pad(tuple((b,s-e) for (b,e),s in zip(vm1.mask, vm1.shape)))

# Project vm1's offset and strides on to vm2.
Expand Down Expand Up @@ -193,7 +193,7 @@ def __add__(self, vm1:View) -> Optional[View]:
# Try to project vm2's mask on to vm1.
newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
if not (t.vmin < b or t.vmax >= e): continue
if resolve(b <= t.vmin and t.vmax < e, False): continue
if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
bad = True
continue
Expand Down

0 comments on commit 0d2462c

Please sign in to comment.