From 713461129b04afcd7e6cb8850fc5c962056ad287 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 21 Oct 2024 00:47:58 +0300 Subject: [PATCH] scheduler ast rewrite reorders from big graph [pr] (#7168) * scheduler ast rewrite reorders from big graph [pr] * update test_uops.py --- test/test_uops.py | 4 ++-- tinygrad/engine/schedule.py | 28 ++++++++++++++++------------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/test/test_uops.py b/test/test_uops.py index 9cf3053290386..e1417bca4c79d 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Device from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401 from tinygrad.renderer import Program -from tinygrad.engine.schedule import create_schedule, enumerate_bufs +from tinygrad.engine.schedule import create_schedule, append_bufs from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.uopgraph import full_graph_rewrite, sym @@ -444,7 +444,7 @@ def test_simple_order_with_special(self): class TestUPatHelpers(unittest.TestCase): def test_location(self): self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py") - self.assertEqual(enumerate_bufs.patterns[0][0].location[0].split("/")[-1], "schedule.py") + self.assertEqual(append_bufs.patterns[0][0].location[0].split("/")[-1], "schedule.py") self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py") with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*? test_upat = UPat(UOps.CONST, dtypes.bool) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 44594f224664d..161cdbf718293 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,7 +1,7 @@ import sys, atexit from collections import defaultdict, deque from dataclasses import dataclass -from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast +from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast from tinygrad.ops import BUFFER_UOPS, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \ graph_rewrite, track_rewrites, sint from tinygrad.helpers import DEBUG, FUSE_CONV_BW, FUSE_ARANGE, Metadata, all_same, colored, diskcache_put, prod, dedup, all_int, merge_dicts, \ @@ -124,31 +124,35 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) -def simplify_and_unbind(ctx, x:UOp) -> Optional[UOp]: - if (st:=unwrap(x.st)) in ctx[2]: return None +def _append_st_vars(ctx:Tuple[Dict[Variable, int], Set[ShapeTracker], Tuple[int, ...]], x:UOp) -> Optional[UOp]: + if (st:=unwrap(x.st)) in ctx[1]: return None st, var_vals = st.simplify().unbind() ctx[0].update(var_vals) - ctx[2].add(st) + ctx[1].add(st) return st.to_uop() if st != x.st else None -append_vars = PatternMatcher([(UPat(UOps.VIEW, name="x"), simplify_and_unbind)]) +append_st_vars = PatternMatcher([(UPat(UOps.VIEW, name="x"), _append_st_vars)]) + +def _append_buf(ctx:Tuple[Dict[Variable, int], Set[ShapeTracker], Tuple[int, ...]], x:UOp) -> UOp: + return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx[2].index(x.arg[0])) +append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf),]) + to_ast = PatternMatcher([ (UPat(UOps.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), (UPat(UOps.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda x: x.replace(src=())), ]) -enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx[1].index(x.arg[0]))),]) PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, Tuple[int, ...], UOp]] = [] -if getenv("RUN_PROCESS_REPLAY"): - @atexit.register - def save_process_replay(): - for base_sink,ctx,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret)) - def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...], var_vals:Dict[Variable, int]) -> UOp: sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right) - ret = graph_rewrite(graph_rewrite(sink, to_ast), append_vars+enumerate_bufs, (var_vals, bufs, set())) + ret = graph_rewrite(graph_rewrite(sink, to_ast), append_st_vars+append_bufs, (var_vals, set(), bufs)) PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret)) return ret +if getenv("RUN_PROCESS_REPLAY"): + @atexit.register + def save_process_replay(): + for base_sink,ctx,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret)) + # *** List[LazyBuffer] lowering to ScheduleItem *** def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], cache:Dict[LazyBuffer, UOp]) -> UOp: