Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move DEFINE_VAR min/max from src to arg [run_process_replay] #414

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions test/test_uop_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,17 @@ def test_depth_2_fold(self):
self.assertEqual(nout.src[1].arg, 3.0)

def test_consts_go_last(self):
a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1))
b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1))
c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1))
d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1))
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('a', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('b', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('c', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('d', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
for out in outs:
sink = graph_rewrite(out, constant_folder)
print(sink)
self.assertEqual(sink.op, UOps.ALU)
self.assertEqual(sink.src[1].op, UOps.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)

class TestUOpGraph(unittest.TestCase):
def test_add_constant_fold(self):
Expand All @@ -155,7 +155,7 @@ def test_add_constant_fold(self):
self.assertEqual(out.arg, 3.0)

def test_where_same_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp(UOps.CONST, dtypes.int, (), 0), UOp(UOps.CONST, dtypes.int, (), 1)), arg=Variable('tmp', 0, 1))
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
Expand Down Expand Up @@ -249,7 +249,7 @@ def test_wmma_vectorize_fold(self):
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
Expand All @@ -258,7 +258,7 @@ def test_wmma_vectorize_fold(self):
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
Expand All @@ -268,37 +268,37 @@ def test_wmma_vectorize_no_fold(self):
for i in [4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)

for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)

for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)

for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
Expand All @@ -324,13 +324,13 @@ def test_double_cast_fold(self):
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)

def test_depth_2_const_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1))
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
uops = to_uops_list([out])
self.assertEqual(len(uops), 5)
self.assertEqual(len(uops), 3)
out = uops[-1]
self.assertEqual(out.op, UOps.ALU)
self.assertEqual(out.arg, BinaryOps.ADD)
Expand Down Expand Up @@ -575,8 +575,8 @@ def test_simple_load_fold_gated(self):

def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = float4_rewrite(sink)
Expand All @@ -591,7 +591,7 @@ def test_simple_store_fold(self):

def test_simple_store_fold_gate(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.SINK, None, tuple(load))
sink = float4_rewrite(sink)
Expand All @@ -602,8 +602,8 @@ def test_simple_store_fold_gate(self):

def test_simple_store_dont_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.SINK, None, tuple(load))
sink = float4_rewrite(sink)
Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_uop_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def Variable(expr, nmin, nmax):
# TODO: fix DEFINE_VAR to not need this
class TempVar:
def __init__(self, x): self.expr = x
return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(TempVar(expr), UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
class Node:
@staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
Expand Down
12 changes: 7 additions & 5 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def __hash__(self): return id(self)
@functools.cached_property
def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0].expr) if self.op is not UOps.ALU else \
self.arg.value, self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
@functools.cached_property
Expand Down Expand Up @@ -364,7 +364,7 @@ def const(cls, dtype:Optional[DType], b:ConstType|Variable): return cls._const(d
@classmethod
def _const(cls, dtype:Optional[DType], b:ConstType|Variable):
# TODO: fix dtype of b.max after Variable is just an UOp
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, (cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))), b)
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
if dtype is not None and dtype != (sdtype := dtype.scalar()):
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
Expand All @@ -386,7 +386,7 @@ def full_shape(self) -> Tuple[sint, ...]:
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
return sorted(set.union(*st_vars, set([x.arg[0] for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
Expand All @@ -412,7 +412,8 @@ def vmax(self) -> UOp:
@functools.cached_property
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
# NOTE: returned UOp is assumed to be CONST
if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
# TODO: fix DEFINE_VAR arg in tests and remove checking len(self.arg)
if self.op is UOps.DEFINE_VAR and self.arg and len(self.arg) > 1: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else None
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else None
Expand Down Expand Up @@ -473,7 +474,8 @@ def truncate_fp16(x):
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))

def uop_alu_resolve(u:UOp) -> sint:
if u.op in {UOps.CONST, UOps.DEFINE_VAR}: return u.arg
if u.op is UOps.CONST: return u.arg
if u.op is UOps.DEFINE_VAR: return u.arg[0]
if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
raise RuntimeError(f"ALU resolve fail @ {u.op}")

Expand Down
2 changes: 1 addition & 1 deletion tinygrad/renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __post_init__(self):
if not self._ran_post_init and self.uops is not None:
# single pass through the uops
for u in self.uops:
if u.op is UOps.DEFINE_VAR: self.vars.append(u.arg)
if u.op is UOps.DEFINE_VAR: self.vars.append(u.arg[0])
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
if u.op is UOps.SPECIAL:
Expand Down
6 changes: 3 additions & 3 deletions tinygrad/renderer/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@ def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
r[u] = "%" + args[0]
kernel = [f".reg .u32 %{args[0]};"] + kernel
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, dtype))
r[u] = f"%{args.expr}"
kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
bufs.append((args[0].expr, dtype))
r[u] = f"%{args[0].expr}"
kk(*self.render_load(args[0].expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
elif uop is UOps.LOAD:
Expand Down
8 changes: 4 additions & 4 deletions tinygrad/renderer/cstyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ def ssa(prefix:str, u:Optional[UOp]=None):
kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */")
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
assert args.expr not in seen_vars, f"duplicate variable {args.expr}"
seen_vars.add(args.expr)
bufs[u] = (args.expr, (dtype,False))
r[u] = args.expr
assert args[0].expr not in seen_vars, f"duplicate variable {args[0].expr}"
seen_vars.add(args[0].expr)
bufs[u] = (args[0].expr, (dtype,False))
r[u] = args[0].expr
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/shape/shapetracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if is
ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \
UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max)), self),
UOp(UOps.DEFINE_VAR, dtypes.int, arg=(self, UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

Expand Down
Loading