Skip to content

Commit

Permalink
symbolic bool raise ValueError when not sure [pr] (tinygrad#6853)
Browse files (browse the repository at this point in the history)
  • Loading branch information
chenyuxyz authored Oct 2, 2024
1 parent: 08850da · commit: c3c93f3
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 26 deletions.
4 changes: 2 additions & 2 deletions extra/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, ma
assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()

keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if isinstance(start_pos, Variable) or start_pos > 0 else xk
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if isinstance(start_pos, Variable) or start_pos > 0 else xv

keys, values = repeat_kv(keys, self.n_rep), repeat_kv(values, self.n_rep)
xq, keys, values = xq.transpose(1, 2), keys.transpose(1, 2), values.transpose(1, 2)
Expand Down
9 changes: 4 additions & 5 deletions test/unit/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,8 @@ def test_dedup(self):
class TestSymbolicMinMax(unittest.TestCase):
def test_min_max_known(self):
a = Variable("a", 1, 8)
assert max(1, a) == max(a, 1) == a
assert min(1, a) == min(a, 1) == 1
assert max(1, a, key=lambda x:x if isinstance(x, int) else x.max) == max(a, 1, key=lambda x:x if isinstance(x, int) else x.max) == a
assert min(1, a, key=lambda x:x if isinstance(x, int) else x.max) == min(a, 1, key=lambda x:x if isinstance(x, int) else x.max) == 1

class TestSymRender(unittest.TestCase):
def test_sym_render(self):
Expand Down Expand Up @@ -518,13 +518,12 @@ def test_node_lt_node(self):
assert create_lt_node(d, a) == NumNode(0)
assert create_lt_node(a, a) == NumNode(0)
assert create_lt_node(a, a) == NumNode(0)
# if it remains as a LtNode, bool is always true and (min, max) == (0, 1)
# if it remains as a LtNode, (min, max) == (0, 1)
a_lt_c = create_lt_node(a, c)
assert isinstance(a_lt_c, LtNode) and a_lt_c.min == 0 and a_lt_c.max == 1
assert a_lt_c
# same when comparing with a constant
a_lt_3 = create_lt_node(a, 3)
assert a_lt_3 and a_lt_3.min == 0 and a_lt_3.max == 1
assert a_lt_3.min == 0 and a_lt_3.max == 1

def test_sumnode_mulnode_lt(self):
a = Variable("a", 1, 2)
Expand Down
4 changes: 2 additions & 2 deletions tinygrad/codegen/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve

# potentially do more upcasts of non reduce axes based on a heuristic
upcasted_axis = set()
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
xb_choices = []
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
# if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
Expand All @@ -555,7 +555,7 @@ def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve

# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
if self.first_reduce < self.first_upcast and (prod(self.full_shape[self.first_upcast:]) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
if resolve((s:=self.full_unupcasted_shape[-1]) <= 32, False) and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
# if it's small, upcast a second reduce dimension too
if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3:
Expand Down
6 changes: 3 additions & 3 deletions tinygrad/engine/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, graph_rewrite
from tinygrad.ops import REDUCE_ALU, MetaOps, ReduceOps, UNSAFE_PAD_OPS, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, graph_rewrite, resolve
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, AST_REWRITE, \
GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
from tinygrad.shape.symbolic import Variable, sint
Expand Down Expand Up @@ -207,10 +207,10 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
if buf is not buf.base:
# fuse some pads
if len(buf.st.views) == 1 and buf.st.views[-1].mask is not None and all_int(buf.base.st.shape) and \
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
resolve(prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask])):
simple_pads[buf.base] = None
# realize all expands
elif prod(buf.base.st.shape) < prod(buf.st.shape):
elif resolve(prod(buf.base.st.shape) < prod(buf.st.shape)):
# this was causing "test_lil_model" to fail
if buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg, ImageDType):
simple_pads[buf.base] = None # don't realize image to image casts. this is part of a larger problem
Expand Down
3 changes: 2 additions & 1 deletion tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
return self.arg.shape if self.op is UOps.SHAPETRACKER else tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
return self.arg.shape if self.op is UOps.SHAPETRACKER else tuple(max(x, key=lambda x: x if isinstance(x, int) else x.max) \
for x in zip(*[x.full_shape for x in self.src if x.has_st]))
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
Expand Down
23 changes: 15 additions & 8 deletions tinygrad/shape/symbolic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
import functools
from math import gcd
from tinygrad.helpers import partition
from tinygrad.helpers import partition, all_int
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping

# NOTE: Python has different behavior for negative mod and floor div than c
Expand All @@ -25,7 +25,9 @@ def key(self) -> str: return self.render(ctx="DEBUG")
def __repr__(self): return self.render(ctx="REPR")
def __str__(self): return "<"+self.key+">"
def __hash__(self): return hash(self.key)
def __bool__(self): return not (self.max == self.min == 0)
def __bool__(self):
if self.max == self.min: return self.max == 1
raise ValueError(f"couldn't resolve boolean expression {self}")
def __eq__(self, other:object) -> bool:
if not isinstance(other, Node): return NotImplemented
return self.key == other.key
Expand All @@ -52,7 +54,8 @@ def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
if isinstance(b, Node):
if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
if self == b: return NumNode(1)
if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
# if isinstance(m:=(b-self).min, int) and m > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
if isinstance(m:=(b-self.max).min, int) and m > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
raise RuntimeError(f"not supported: {self} // {b}")
assert b != 0
if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
Expand All @@ -70,7 +73,8 @@ def __mod__(self, b:Union[Node,int]):
if isinstance(b, Node):
if b.__class__ is NumNode: return self % b.b
if self == b: return NumNode(0)
if (b - self).min > 0 and self.min >= 0: return self # b - self simplifies the node
# if isinstance(m:=(b-self).min, int) and m > 0 and self.min >= 0: return self # b - self simplifies the node
if isinstance(m:=(b-self.max).min, int) and m > 0 and self.min >= 0: return self # b - self simplifies the node
raise RuntimeError(f"not supported: {self} % {b}")
assert b > 0
if b == 1: return NumNode(0)
Expand All @@ -81,7 +85,7 @@ def __mod__(self, b:Union[Node,int]):

@staticmethod
def sum(nodes:List[Node]) -> Node:
nodes = [x for x in nodes if x.max or x.min]
nodes = [x for x in nodes if not (x.max==x.min==0)]
if not nodes: return NumNode(0)
if len(nodes) == 1: return nodes[0]

Expand All @@ -99,7 +103,7 @@ def sum(nodes:List[Node]) -> Node:
def ands(nodes:List[Node]) -> Node:
if not nodes: return NumNode(1)
if len(nodes) == 1: return nodes[0]
if any(not x for x in nodes): return NumNode(0)
if any(x.max==0 for x in nodes): return NumNode(0)

# filter 1s
nodes = [x for x in nodes if x.min != x.max]
Expand Down Expand Up @@ -190,7 +194,9 @@ class LtNode(OpNode):
def get_bounds(self) -> Tuple[int, int]:
if self.a == self.b: return (0, 0)
if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
if all_int([self.a.max, self.b.min]) and self.a.max < self.b.min: return (1, 1)
if all_int([self.a.min, self.b.max]) and self.a.min >= self.b.max: return (0, 0)
return (0, 1)
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))

Expand Down Expand Up @@ -223,7 +229,8 @@ def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
def get_bounds(self) -> Tuple[int, sint]:
assert self.a.min >= 0 and isinstance(self.b, int)
if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
if all_int([self.a.max, self.a.min, self.b]):
if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
return (self.a.min%self.b, self.a.max%self.b)
def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b

Expand Down
10 changes: 7 additions & 3 deletions tinygrad/shape/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,16 +224,20 @@ def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View

@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def pad(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
assert all((b>=0 and e>=0) for b,e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
if any(b>0 or e>0 for b, e in arg):
assert len(arg) == len(self.shape), f"invalid pad {arg} for {self.shape}"
# NOTE: not checking for symbolic arg
for b,e in arg: assert not all_int([b,e]) or b>=0 and e>=0, f"invalid pad {arg} for {self.shape}"
if any(b!=0 or e!=0 for b, e in arg):
zvarg = tuple([(-b,s+e) for s,(b,e) in zip(self.shape, arg)])
mask = tuple([(b,s+b) for s,(b,_) in zip(self.shape, arg)])
return self.__unsafe_resize(zvarg, mask=mask)
return self

@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
assert all((0<=b<=e<=s) for s,(b,e) in zip(self.shape,arg)) and len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
assert len(arg) == len(self.shape), f"invalid shrink {arg} for {self.shape}"
# NOTE: not checking for symbolic arg
for s,(b,e) in zip(self.shape,arg): assert not all_int([s,b,e]) or (0<=b<=e<=s), f"invalid shrink {arg} for {self.shape}"
return self.__unsafe_resize(arg)

@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
Expand Down
5 changes: 3 additions & 2 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
max_dim = max(len(shape) for shape in shapes)
return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
return tuple(0 if 0 in nth_dim_sizes else max(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
return tuple(0 if 0 in nth_dim_sizes else max(nth_dim_sizes, key=lambda x: x if isinstance(x, int) else x.max) \
for nth_dim_sizes in zip(*_pad_left(*shapes)))

ReductionStr = Literal["mean", "sum", "none"]

Expand Down Expand Up @@ -1618,7 +1619,7 @@ def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, corr
"""
squares = (self - self.mean(axis=axis, keepdim=True)).square()
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction))
return squares.sum(axis=axis, keepdim=keepdim).div(max(0, n-correction, key=lambda x:x if isinstance(x, int) else x.max))

def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
"""
Expand Down

0 comments on commit c3c93f3

Please sign in to comment.