Skip to content

Commit

Permalink
move cmp tuple out [run_process_replay] (tinygrad#6825)
Browse files Browse the repository at this point in the history
* move cmp tuple out [run_process_replay]

* was unneeded
  • Loading branch information
geohot authored Oct 1, 2024
1 parent a1dee0e commit 50dd6bd
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
2 changes: 1 addition & 1 deletion test/test_linearizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ def test_simple_unroll_no_between_phi_dependencies(self):
def test_grouped_dims(self):
def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs]))
loop_idxs = dedup(flatten([[y for y in x.sparents if y.op is UOps.SPECIAL] for x in idxs]))
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
sizes = [x.arg[1] for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
Expand Down
2 changes: 2 additions & 0 deletions test/test_uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def test_bitshift_right(self):
self.assertEqual(uops[-2].arg, BinaryOps.IDIV)

class TestUOpMethod(unittest.TestCase):
@unittest.skip("uops lt no longer ordered")
def test_compare_alu_same_src_different_arg(self):
a = UOp(UOps.CONST, dtypes.float, (), 2.0)
b = UOp(UOps.CONST, dtypes.float, (), 3.0)
Expand Down Expand Up @@ -397,6 +398,7 @@ def test_vectorized_str(self):
vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
assert str(eval(str(vec))) == str(vec)

@unittest.skip("uop no longer has order like this")
class TestIndexingOrdering(unittest.TestCase):
# NOTE: these tests skip type_verify since they add dtype to STORE
@unittest.expectedFailure
Expand Down
17 changes: 13 additions & 4 deletions tinygrad/codegen/uopgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, Tuple, Dict, List, Set, cast, TYPE_CHECKING, Any, DefaultDict, Callable
import functools, itertools, heapq, math, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType, DType
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite, TernaryOps
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
Expand Down Expand Up @@ -802,16 +802,25 @@ def fix_priority(u:UOp, lowest_priority):
for x in u.src: fix_priority(x, priorities[u])
fix_priority(sink, 0)

queue:List[Tuple[int, UOp]] = []
def push(u:UOp): heapq.heappush(queue, (priorities[u], u))
@functools.lru_cache(None)
def tuplize(u:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
  """Recursively flatten a UOp into a plain, orderable tuple (op value, normalized arg, dtype, tuplized srcs)."""
  # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
  arg = u.arg
  if u.op is UOps.ALU: arg = arg.value
  elif u.op is UOps.DEFINE_VAR: arg = arg[0]
  return (u.op.value, arg, u.dtype, tuple(map(tuplize, u.src)))

# NOTE: the compare should never make it all the way to u
# min-heap of ready-to-schedule uops ordered by priority; tuplize(u) breaks
# priority ties so heapq should not fall back to comparing UOp objects directly
queue:List[Tuple[int, Tuple, UOp]] = []
def push(u:UOp): heapq.heappush(queue, (priorities[u], tuplize(u), u))

for u in children:
if in_degree[u] == 0: push(u)

scope_end: Dict[UOp, UOp] = {}
_uops: List[UOp] = []
while queue:
p,x = heapq.heappop(queue)
p,_,x = heapq.heappop(queue)
if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
if x in scope_children: scope_end[x] = x
if x.op is UOps.DEFINE_ACC:
Expand Down
8 changes: 0 additions & 8 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,6 @@ def st(self) -> Optional[ShapeTracker]:
from tinygrad.shape.shapetracker import ShapeTracker
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is UOps.REDUCE_AXIS else src_sts[0]
@functools.cached_property
def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
  """Orderable view of this UOp (op value, normalized arg, dtype, srcs); consumed by __lt__."""
  # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
  if self.op is UOps.DEFINE_VAR: arg = self.arg[0]
  elif self.op is UOps.ALU: arg = self.arg.value
  else: arg = self.arg
  return (self.op.value, arg, self.dtype, self.src)
# total ordering of UOps via their structural cmp_tuple (src tuples compare element-wise)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
@functools.cached_property
def key(self) -> bytes:
  """Stable content hash of this UOp: sha256 over (op, dtype, arg) followed by every source's key."""
  h = hashlib.sha256(str((self.op, self.dtype, self.arg)).encode())
  for s in self.src: h.update(s.key)
  return h.digest()
def __repr__(self):
  """Render this UOp as a pretty-printed tree; each node shows op, dtype, arg, and nested srcs."""
  def fmt(x): return f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))"
  return pretty_print(self, fmt)
Expand Down

0 comments on commit 50dd6bd

Please sign in to comment.