Skip to content

Commit

Permalink
scheduler ast rewrite reorders from big graph [pr] (tinygrad#7168)
Browse files Browse the repository at this point in the history
* scheduler ast rewrite reorders from big graph [pr]

* update test_uops.py
  • Loading branch information
Qazalin authored Oct 20, 2024
1 parent 8134921 commit 7134611
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
4 changes: 2 additions & 2 deletions test/test_uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tinygrad.device import Buffer, Device
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, enumerate_bufs
from tinygrad.engine.schedule import create_schedule, append_bufs
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
Expand Down Expand Up @@ -444,7 +444,7 @@ def test_simple_order_with_special(self):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
self.assertEqual(enumerate_bufs.patterns[0][0].location[0].split("/")[-1], "schedule.py")
self.assertEqual(append_bufs.patterns[0][0].location[0].split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(UOps.CONST, dtypes.bool)
Expand Down
28 changes: 16 additions & 12 deletions tinygrad/engine/schedule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys, atexit
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import BUFFER_UOPS, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
graph_rewrite, track_rewrites, sint
from tinygrad.helpers import DEBUG, FUSE_CONV_BW, FUSE_ARANGE, Metadata, all_same, colored, diskcache_put, prod, dedup, all_int, merge_dicts, \
Expand Down Expand Up @@ -124,31 +124,35 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])

def simplify_and_unbind(ctx, x:UOp) -> Optional[UOp]:
if (st:=unwrap(x.st)) in ctx[2]: return None
def _append_st_vars(ctx:Tuple[Dict[Variable, int], Set[ShapeTracker], Tuple[int, ...]], x:UOp) -> Optional[UOp]:
if (st:=unwrap(x.st)) in ctx[1]: return None
st, var_vals = st.simplify().unbind()
ctx[0].update(var_vals)
ctx[2].add(st)
ctx[1].add(st)
return st.to_uop() if st != x.st else None
append_vars = PatternMatcher([(UPat(UOps.VIEW, name="x"), simplify_and_unbind)])
append_st_vars = PatternMatcher([(UPat(UOps.VIEW, name="x"), _append_st_vars)])

def _append_buf(ctx:Tuple[Dict[Variable, int], Set[ShapeTracker], Tuple[int, ...]], x:UOp) -> UOp:
return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx[2].index(x.arg[0]))
append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf),])

to_ast = PatternMatcher([
(UPat(UOps.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(UOps.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda x: x.replace(src=())),
])
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx[1].index(x.arg[0]))),])

PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, Tuple[int, ...], UOp]] = []
if getenv("RUN_PROCESS_REPLAY"):
@atexit.register
def save_process_replay():
for base_sink,ctx,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret))

def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...], var_vals:Dict[Variable, int]) -> UOp:
sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
ret = graph_rewrite(graph_rewrite(sink, to_ast), append_vars+enumerate_bufs, (var_vals, bufs, set()))
ret = graph_rewrite(graph_rewrite(sink, to_ast), append_st_vars+append_bufs, (var_vals, set(), bufs))
PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret))
return ret

if getenv("RUN_PROCESS_REPLAY"):
@atexit.register
def save_process_replay():
for base_sink,ctx,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret))

# *** List[LazyBuffer] lowering to ScheduleItem ***

def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], cache:Dict[LazyBuffer, UOp]) -> UOp:
Expand Down

0 comments on commit 7134611

Please sign in to comment.