diff --git a/test/test_viz.py b/test/test_viz.py index bfac68b08481..aa4e6e3a6101 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -3,7 +3,7 @@ from tinygrad.dtype import PtrDType, dtypes from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \ graph_rewrite, contexts, track_rewrites -from tinygrad.viz.serve import _replace_uop, get_details, get_metadata +from tinygrad.viz.serve import get_details, get_metadata, uop_to_json @track_rewrites def rewrite(sink:UOp, pm:PatternMatcher, ctx=None): return graph_rewrite(sink, pm, ctx) @@ -12,14 +12,9 @@ def helper_test_viz(sink:UOp, pm:PatternMatcher, ctx=None) -> List[UOp]: rewrite(sink, pm, ctx) assert len(contexts) == 1 assert len(contexts[0][1]) == 1 - ctx = contexts[0][1][0] - uops = [ctx.sink] - replaces: Dict[UOp, UOp] = {} - for u0,u1,_ in ctx.rewrites: - replaces[u0] = u1 - new_sink = _replace_uop(uops[-1], {**replaces}) - uops.append(new_sink) - return uops[1:] + k = get_metadata(contexts)[0][0] + g = get_details(*k) + return g.graphs[1:] class TestViz(unittest.TestCase): def setUp(self): @@ -91,12 +86,8 @@ def do_rewrite(key:str, x:UOp): self.assertEqual(len(ret), 1) def test_fold_const(self): - pm = PatternMatcher([ - (UPat.var("x")*1, lambda x:x), - ]) a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0))) - rewrite(a, pm) - graph = get_details(*get_metadata(contexts)[0][0]).graphs[-1] + graph = uop_to_json(a) assert not any(v[0].startswith("CONST") for v in graph.values()) assert len([x for x in graph.values() if "CONST" in x[0]]) == 1 diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 4ee84f4f7409..4b73b6c2ce03 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -30,7 +30,7 @@ class GraphRewriteMetadata: @dataclass class GraphRewriteDetails(GraphRewriteMetadata): """Full details about a single call to graph_rewrite""" - graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]] + graphs: List[UOp] """Sink at every step of graph_rewrite""" diffs: List[List[str]] """.diff style before and after of the rewritten UOp child""" @@ -52,7 +52,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats))) return list(kernels.values()) -def _uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: +def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: assert isinstance(x, UOp) graph: Dict[int, Tuple[str, str, List[int], str, str]] = {} for u in x.sparents: @@ -69,7 +69,7 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp: @functools.lru_cache(None) def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails: - g = GraphRewriteDetails(**asdict(metadata), graphs=[_uop_to_json(ctx.sink)], diffs=[], changed_nodes=[], kernel_code=_prg(k)) + g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k)) replaces: Dict[UOp, UOp] = {} sink = ctx.sink for i,(u0,u1,upat) in enumerate(ctx.rewrites): @@ -82,7 +82,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) # update ret data g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST]) g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines()))) - g.graphs.append(_uop_to_json(sink:=new_sink)) + g.graphs.append(sink:=new_sink) return g # ** HTTP server @@ -101,7 +101,8 @@ def do_GET(self): self.end_headers() query = parse_qs(url.query) if (qkernel:=query.get("kernel")) is not None: - ret = json.dumps(asdict(get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]))).encode() + g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]) + ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs))}).encode() else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode() else: self.send_response(404)