From 3c15e64273272c4085e4a50d967a61c3a6412f2c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 29 Sep 2024 20:06:31 +0800 Subject: [PATCH] VIZ prep for the new kernel render (#6800) * refactor to list * remove prints in test_viz * more cleanup --- tinygrad/ops.py | 14 +++++++------- viz/serve.py | 22 ++++++++++------------ viz/test_viz.py | 11 ++++++----- 3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index bae3d00b6e295..6cf968e839b68 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -533,13 +533,13 @@ def print_match_stats(): if getenv("VIZ"): os.environ["VIZ"] = "0" os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), "..", "viz", "serve.py")]) - ret = [0,0,0.0,0.0] - for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]): - loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}" - if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable()) - ret = [x+y for x,y in zip(ret, v)] - print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL") - + if getenv("PRINT_MATCH_STATS", 1): + ret = [0,0,0.0,0.0] + for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]): + loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}" + if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable()) + ret = [x+y for x,y in zip(ret, v)] + print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL") # *** simple graph rewrite engine *** diff --git a/viz/serve.py b/viz/serve.py index 63a6a4e0a2a36..bfd1b7dbee461 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -15,9 +15,8 @@ # **** /graph - detailed UOp + rewrites # NOTE: UPats in ops.py are spec -# TODO: fix key for uop with buffer def graph_rewrites(ctx:TrackedRewriteContext): - return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py" and not ("schedule" in ctx.loc[0] and "DEFINE_GLOBAL" in str(x[2]))] + return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"] @dataclass(frozen=True) class RewriteLocation: @@ -61,13 +60,13 @@ def from_ctx(ctx:TrackedRewriteContext) -> UOpRet: extra.append([str(new_sink)]) return UOpRet(RewriteLocation.from_ctx(ctx), uops, diffs, extra, additions) def to_json(self) -> Dict: - return {**asdict(self), "loc":self.loc.to_json(), "graphs": list(map(lambda x:uop_to_json(x, self.graphs[0]), self.graphs))} + return {**asdict(self), "loc":self.loc.to_json(), "graphs": list(map(uop_to_json, self.graphs))} -def uop_to_json(x:UOp, base:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: +def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: assert isinstance(x, UOp) graph: Dict[int, Tuple[str, str, List[int], str, str]] = {} for u in x.sparents: - if u.op is UOps.CONST and u is not base: continue + if u.op is UOps.CONST: continue label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}" for idx,x in enumerate(u.src): if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}" @@ -89,9 +88,8 @@ def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp: class KernelRet: name: str code: str - ctxs: Dict[Tuple[Tuple, bytes], TrackedRewriteContext] - def to_json(self) -> Dict: - return {"name":self.name, "code":self.code, "ctxs":[RewriteLocation.from_ctx(x).to_json() for x in self.ctxs.values()]} + ctxs: List[TrackedRewriteContext] + def to_json(self) -> Dict: return {"name":self.name, "code":self.code, "ctxs":[RewriteLocation.from_ctx(x).to_json() for x in self.ctxs]} def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]: ret: Dict[str, KernelRet] = {} @@ -102,8 +100,8 @@ def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]: si_ctx = ScheduleItemContext(bufs=tuple(x.arg for x in ctx.sink.sparents if x.op is UOps.BUFFER)) with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink, si_ctx)).p).name, prg.src elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, "" - if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {}) - ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx + if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, []) + ret[k].ctxs.append(ctx) return list(ret.values()) class Handler(BaseHTTPRequestHandler): @@ -131,8 +129,8 @@ def do_GET(self): self.send_header("Content-type", "application/json") self.end_headers() k = kernels[int(query["kernel_idx"][0])] - g = UOpRet.from_ctx(list(k.ctxs.values())[int(query["uop_idx"][0])]) - ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs.values()])).encode() + g = UOpRet.from_ctx(k.ctxs[int(query["uop_idx"][0])]) + ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs])).encode() else: self.send_response(404) ret = b"" diff --git a/viz/test_viz.py b/viz/test_viz.py index b030b30bd5af5..0648c6998998a 100644 --- a/viz/test_viz.py +++ b/viz/test_viz.py @@ -1,6 +1,7 @@ import unittest import os, itertools os.environ["TRACK_MATCH_STATS"] = "2" +os.environ["PRINT_MATCH_STATS"] = "0" from extra.models.resnet import ResNet50 from tinygrad import Tensor from tinygrad.engine.realize import lower_schedule @@ -11,7 +12,7 @@ from test.external.process_replay.helpers import print_diff from viz.serve import KernelRet, UOpRet, load_kernels, uop_to_json -def group_rewrites(kernels:KernelRet): return {k:list(v) for k,v in itertools.groupby(kernels.ctxs.values(), lambda x:x.loc)} +def group_rewrites(kernels:KernelRet): return {k:list(v) for k,v in itertools.groupby(kernels.ctxs, lambda x:x.loc)} class TestViz(unittest.TestCase): def tearDown(self) -> None: @@ -48,8 +49,8 @@ def test_ctx_groups(self): list(lower_schedule(schedule2)) ret = load_kernels(contexts) assert len(ret) == 2 - assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc[0]]) != 0 for y in ret) - assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc[0]]) != 0 for y in ret) + assert all(len([x for x in y.ctxs if "schedule" in x.loc[0]]) != 0 for y in ret) + assert all(len([x for x in y.ctxs if "uopgraph" in x.loc[0]]) != 0 for y in ret) def test_gemm_diff(self): x = Tensor.empty(64, 64).realize() @@ -140,11 +141,11 @@ def test_fold_const_nodes(self): a = Tensor.empty(4, 4)+2 contexts.clear() sink = a.schedule()[-1].ast - ret = uop_to_json(sink, base=sink) - for v in ret.values(): print(v) + ret = uop_to_json(sink) assert not any(v[0].startswith("CONST") for v in ret.values()) assert len([x for x in ret.values() if "CONST" in x[0]]) == 1 + @unittest.skip("VIZ for a single CONST isn't supported anymore") def test_no_fold_single_const(self): node = UOp(UOps.CONST, dtypes.float, (), 1.0) ret = uop_to_json(node, base=node)