Skip to content

Commit

Permalink
viz late to_json [pr] (tinygrad#7070)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalin authored Oct 15, 2024
1 parent 52d8afd commit 1a45e94
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 19 deletions.
19 changes: 5 additions & 14 deletions test/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tinygrad.dtype import PtrDType, dtypes
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \
graph_rewrite, contexts, track_rewrites
from tinygrad.viz.serve import _replace_uop, get_details, get_metadata
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json

@track_rewrites
def rewrite(sink:UOp, pm:PatternMatcher, ctx=None): return graph_rewrite(sink, pm, ctx)
Expand All @@ -12,14 +12,9 @@ def helper_test_viz(sink:UOp, pm:PatternMatcher, ctx=None) -> List[UOp]:
rewrite(sink, pm, ctx)
assert len(contexts) == 1
assert len(contexts[0][1]) == 1
ctx = contexts[0][1][0]
uops = [ctx.sink]
replaces: Dict[UOp, UOp] = {}
for u0,u1,_ in ctx.rewrites:
replaces[u0] = u1
new_sink = _replace_uop(uops[-1], {**replaces})
uops.append(new_sink)
return uops[1:]
k = get_metadata(contexts)[0][0]
g = get_details(*k)
return g.graphs[1:]

class TestViz(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -91,12 +86,8 @@ def do_rewrite(key:str, x:UOp):
self.assertEqual(len(ret), 1)

def test_fold_const(self):
pm = PatternMatcher([
(UPat.var("x")*1, lambda x:x),
])
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
rewrite(a, pm)
graph = get_details(*get_metadata(contexts)[0][0]).graphs[-1]
graph = uop_to_json(a)
assert not any(v[0].startswith("CONST") for v in graph.values())
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1

Expand Down
11 changes: 6 additions & 5 deletions tinygrad/viz/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class GraphRewriteMetadata:
@dataclass
class GraphRewriteDetails(GraphRewriteMetadata):
"""Full details about a single call to graph_rewrite"""
graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]]
graphs: List[UOp]
"""Sink at every step of graph_rewrite"""
diffs: List[List[str]]
""".diff style before and after of the rewritten UOp child"""
Expand All @@ -52,7 +52,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
return list(kernels.values())

def _uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
for u in x.sparents:
Expand All @@ -69,7 +69,7 @@ def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
@functools.lru_cache(None)
def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
g = GraphRewriteDetails(**asdict(metadata), graphs=[_uop_to_json(ctx.sink)], diffs=[], changed_nodes=[], kernel_code=_prg(k))
g = GraphRewriteDetails(**asdict(metadata), graphs=[ctx.sink], diffs=[], changed_nodes=[], kernel_code=_prg(k))
replaces: Dict[UOp, UOp] = {}
sink = ctx.sink
for i,(u0,u1,upat) in enumerate(ctx.rewrites):
Expand All @@ -82,7 +82,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
# update ret data
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
g.diffs.append(list(difflib.unified_diff(str(u0).splitlines(), str(u1).splitlines())))
g.graphs.append(_uop_to_json(sink:=new_sink))
g.graphs.append(sink:=new_sink)
return g

# ** HTTP server
Expand All @@ -101,7 +101,8 @@ def do_GET(self):
self.end_headers()
query = parse_qs(url.query)
if (qkernel:=query.get("kernel")) is not None:
ret = json.dumps(asdict(get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]))).encode()
g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])])
ret = json.dumps({**asdict(g), "graphs": list(map(uop_to_json, g.graphs))}).encode()
else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode()
else:
self.send_response(404)
Expand Down

0 comments on commit 1a45e94

Please sign in to comment.