Skip to content

Commit

Permalink
handle arbitrary multiplication case (tinygrad#7019)
Browse files Browse the repository at this point in the history
* handle arbitrary multiplication case

* remove count restriction
  • Loading branch information
geohot authored Oct 12, 2024
1 parent 23faeac commit 5ce224c
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,17 +354,20 @@ def _min_max(self) -> Tuple[ConstType, ConstType]:
if self.op is UOps.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
if self.op is UOps.CONST: return self.arg, self.arg
if self.op is UOps.VCONST: return (min(self.arg), max(self.arg))
if self.op is UOps.ALU and self.dtype.count == 1:
if self.op is UOps.ALU:
s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
if self.arg is BinaryOps.MUL:
# both are non-positive
if (s0.vmax <= 0 and s1.vmax <= 0): return s0.vmax*s1.vmax, s0.vmin*s1.vmin
# at lease one is non-negative
# at least one is non-negative
if (s0.vmin >= 0 or s1.vmin >= 0):
Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin)
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin)
return Lmin*Rmin, Lmax*Rmax
# arbitrary
products = [s0.vmin * s1.vmin, s0.vmin * s1.vmax, s0.vmax * s1.vmin, s0.vmax * s1.vmax]
return min(products), max(products)
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
Expand Down

0 comments on commit 5ce224c

Please sign in to comment.