Skip to content

Commit

Permalink
prep for viz move to core [pr] (tinygrad#6938)
Browse files Browse the repository at this point in the history
* prep for viz move to core [pr]

* polish
  • Loading branch information
Qazalin authored Oct 7, 2024
1 parent e4c0743 commit 0ecc417
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 105 deletions.
133 changes: 75 additions & 58 deletions viz/serve.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,87 @@
#!/usr/bin/env python3
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib, multiprocessing, functools
from dataclasses import asdict
from urllib.parse import parse_qs, urlparse
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.codegen.kernel import Kernel
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple, Optional
from tinygrad.helpers import getenv, to_function_name, tqdm
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.engine.graph import uops_colors, word_wrap
from viz.spec import GraphRewriteDetails, GraphRewriteMetadata
from tinygrad.engine.graph import word_wrap, uops_colors
from tinygrad.codegen.kernel import Kernel

def reconstruct_graph(ctx:TrackedRewriteContext) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
  """Replay every rewrite recorded in `ctx` and collect the intermediate results.

  Returns a 3-tuple:
    * the sink before any rewrite followed by the sink after each rewrite,
    * one unified diff (list of lines) per rewrite, comparing the matched UOp with its replacement,
    * per rewrite, the ids of the non-CONST UOps in the replacement's `sparents`.

  Raises AssertionError when applying a rewrite returns the very same sink object.
  """
  sinks: List[UOp] = [ctx.sink]
  all_diffs: List[List[str]] = []
  all_changed: List[List[int]] = []
  mapping: Dict[UOp, UOp] = {}
  for i, (first, rewritten, upat) in enumerate(ctx.rewrites):
    mapping[first] = rewritten
    prev_sink = sinks[-1]
    # replace_uop stores every visited node in the dict it is given, so hand it a throwaway copy
    next_sink = replace_uop(prev_sink, dict(mapping))
    if next_sink is prev_sink:
      # the rewrite did not produce a new sink object
      raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
    all_changed.append([id(node) for node in rewritten.sparents if node.op is not UOps.CONST])
    before_lines, after_lines = str(first).splitlines(), str(rewritten).splitlines()
    all_diffs.append(list(difflib.unified_diff(before_lines, after_lines)))
    sinks.append(next_sink)
  return sinks, all_diffs, all_changed
# ** API spec

@dataclass
class GraphRewriteMetadata:
  """Specifies metadata about a single call to graph_rewrite"""
  # (file_path, lineno) of the call
  loc: Tuple[str, int]
  # the Python source line calling graph_rewrite
  code_line: str
  # name of the kernel calling graph_rewrite, None when there is no kernel
  kernel_name: Optional[str]
  # every applied UPat as ((file_path, lineno), printable form)
  upats: List[Tuple[Tuple[str, int], str]]

@dataclass
class GraphRewriteDetails(GraphRewriteMetadata):
  """Full details about a single call to graph_rewrite"""
  # the sink, serialized as {node id: node tuple}, at every step of graph_rewrite
  graphs: List[Dict[int, Tuple[str, str, List[int], str, str]]]
  # .diff style before/after text of the rewritten UOp, one list of lines per step
  diffs: List[List[str]]
  # ids of the nodes that changed at every step of graph_rewrite
  changed_nodes: List[List[int]]
  # the program after all rewrites, None when there is none
  kernel_code: Optional[str]

def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
# ** API functions

def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List[List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]]:
  """Build GraphRewriteMetadata for every tracked rewrite context, grouped by kernel name.

  `contexts` is a list of (key, rewrite contexts) pairs. The group name is the function name derived
  from the key's `name` when the key is a Kernel, otherwise None. Contexts whose sink is a CONST are
  skipped. Returns one list of (key, ctx, metadata) tuples per group, in first-seen order.
  """
  grouped: Dict[Optional[str], List[Tuple[Any, TrackedRewriteContext, GraphRewriteMetadata]]] = {}
  for k, ctxs in contexts:
    name = None
    if isinstance(k, Kernel): name = to_function_name(k.name)
    for ctx in ctxs:
      if ctx.sink.op is UOps.CONST: continue
      upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
      # ctx.loc is (file_path, lineno) with a 1-based lineno, hence the -1
      code_line = lines(ctx.loc[0])[ctx.loc[1]-1].strip()
      grouped.setdefault(name, []).append((k, ctx, GraphRewriteMetadata(ctx.loc, code_line, name, upats)))
  return list(grouped.values())

def _uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
  """Serialize the graph reachable from `x` into a dict keyed by node id.

  Each value is (label, dtype string, ids of the non-CONST sources, arg string, color).
  CONST nodes get no entry of their own; a CONST source is appended to its consumer's label instead.
  """
  assert isinstance(x, UOp)
  nodes: Dict[int, Tuple[str, str, List[int], str, str]] = {}
  for u in x.sparents:
    if u.op is UOps.CONST: continue
    # str(u.op)[5:] drops the first five characters of the op name (presumably the "UOps." prefix)
    arg_text = "" if u.arg is None else " "+word_wrap(str(u.arg).replace(':', ''))
    label = f"{str(u.op)[5:]}{arg_text}\n{str(u.dtype)}"
    src_ids: List[int] = []
    for idx, src in enumerate(u.src):
      if src.op is UOps.CONST: label += f"\nCONST{idx} {src.arg:g}"
      else: src_ids.append(id(src))
    if getenv("WITH_SHAPE"):
      with contextlib.suppress(Exception): # if the UOp is indexed already it's fine
        if u.st is not None: label += f"\n{u.st.shape}"
    nodes[id(u)] = (label, str(u.dtype), src_ids, str(u.arg), uops_colors.get(u.op, "#ffffff"))
  return nodes

# Rebuilds `base` bottom-up: a UOp found as a key in `replaces` is swapped for its mapped value,
# otherwise the node is re-created with its sources rewritten recursively.
# `replaces` doubles as the memo table (every visited node is stored back into it), which is why the
# callers in this file pass a copy of their mapping.
# NOTE(review): the paired def lines and paired recursive lines look like the before/after of the
# replace_uop -> _replace_uop rename shown in this diff; confirm which pair is live.
def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
if (found:=replaces.get(base)) is not None: return found
replaces[base] = ret = base.replace(src=tuple(replace_uop(x, replaces) for x in base.src))
replaces[base] = ret = base.replace(src=tuple(_replace_uop(x, replaces) for x in base.src))
return ret

def load_kernels(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]
                 ) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, TrackedRewriteContext, Any]]]:
  """Group tracked rewrite contexts by kernel name.

  The group key is the function name derived from the key's `name` when the key is a Kernel, otherwise
  None. Contexts whose sink is a CONST are skipped. Each entry is a (metadata, ctx, key) tuple.
  """
  by_name = defaultdict(list)
  for k, rewrites in contexts:
    name = to_function_name(k.name) if isinstance(k, Kernel) else None
    for ctx in rewrites:
      if ctx.sink.op is UOps.CONST: continue
      upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
      # ctx.loc is (file_path, lineno) with a 1-based lineno, hence the -1
      code_line = lines(ctx.loc[0])[ctx.loc[1]-1].strip()
      by_name[name].append((GraphRewriteMetadata(ctx.loc, code_line, name, upats), ctx, k))
  return by_name

# Returns `k.to_program().src` when `k` is a Kernel and None for anything else.
# lru_cache(None) memoizes with no size limit, so a repeated call with the same `k` reuses the cached result.
# NOTE(review): get_src and _prg look like the pre- and post-rename forms of the same helper in this
# diff; confirm which definition the decorator is meant to cover.
@functools.lru_cache(None)
def get_src(k) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
def _prg(k:Optional[Kernel]) -> Optional[str]: return k.to_program().src if isinstance(k, Kernel) else None
def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
  """Expand `metadata` into full GraphRewriteDetails by replaying the rewrites recorded in `ctx`.

  `graphs` starts with the serialized original sink and gains one entry per rewrite; `diffs` and
  `changed_nodes` gain one entry per rewrite; `kernel_code` comes from `_prg(k)`.
  Raises AssertionError when applying a rewrite returns the very same sink object.
  """
  details = GraphRewriteDetails(**asdict(metadata), graphs=[_uop_to_json(ctx.sink)], diffs=[], changed_nodes=[], kernel_code=_prg(k))
  seen: Dict[UOp, UOp] = {}
  current = ctx.sink
  for i,(u0,u1,upat) in enumerate(ctx.rewrites):
    # apply this rewrite together with every rewrite seen before it
    seen[u0] = u1
    # _replace_uop stores every visited node in the dict it is given, so hand it a throwaway copy
    rewritten_sink = _replace_uop(current, dict(seen))
    if rewritten_sink is current:
      # the rewrite did not produce a new sink object
      raise AssertionError(f"rewritten sink wasn't rewritten! {i} {upat.location}")
    details.changed_nodes.append([id(node) for node in u1.sparents if node.op is not UOps.CONST])
    before_lines, after_lines = str(u0).splitlines(), str(u1).splitlines()
    details.diffs.append(list(difflib.unified_diff(before_lines, after_lines)))
    current = rewritten_sink
    details.graphs.append(_uop_to_json(current))
  return details

# ** HTTP server

class Handler(BaseHTTPRequestHandler):
def do_GET(self):
Expand All @@ -83,17 +103,15 @@ def do_GET(self):
self.end_headers()
query = parse_qs(url.query)
if (qkernel:=query.get("kernel")) is not None:
metadata, ctx, k = list(kernels.values())[int(qkernel[0])][int(query["idx"][0])]
graphs, diffs, changed_nodes = reconstruct_graph(ctx)
ret = json.dumps(asdict(GraphRewriteDetails(**asdict(metadata), graphs=list(map(uop_to_json, graphs)),
diffs=diffs, changed_nodes=changed_nodes, kernel_code=get_src(k)))).encode()
else: ret = json.dumps([list(map(lambda x:asdict(x[0]), v)) for v in kernels.values()]).encode()
ret = json.dumps(asdict(get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]))).encode()
else: ret = json.dumps([list(map(lambda x:asdict(x[2]), v)) for v in kernels]).encode()
else:
self.send_response(404)
ret = b""
return self.wfile.write(ret)

BROWSER = getenv("BROWSER", 1)
# ** main loop

stop_reloader = threading.Event()
def reloader():
mtime = os.stat(__file__).st_mtime
Expand All @@ -108,18 +126,17 @@ def reloader():
print("*** viz is starting")
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[Tuple[Any, List[TrackedRewriteContext]]] = pickle.load(f)
print("*** unpickled saved rewrites")
kernels = load_kernels(contexts)
kernels = get_metadata(contexts)
if getenv("FUZZ_VIZ"):
for v in tqdm(kernels.values()):
for _,ctx,_ in v: reconstruct_graph(ctx)
ret = [get_details(*args) for v in tqdm(kernels) for args in v]
print(f"fuzzed {len(ret)} rewrite details")
print("*** loaded kernels")
server = HTTPServer(('', 8000), Handler)
st = time.perf_counter()
reloader_thread = threading.Thread(target=reloader)
reloader_thread.start()
if BROWSER: webbrowser.open("http://localhost:8000")
try:
server.serve_forever()
if getenv("BROWSER", 1): webbrowser.open("http://localhost:8000")
try: server.serve_forever()
except KeyboardInterrupt:
print("*** viz is shutting down...")
stop_reloader.set()
26 changes: 0 additions & 26 deletions viz/spec.py

This file was deleted.

39 changes: 18 additions & 21 deletions viz/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
import os, itertools
os.environ["TRACK_MATCH_STATS"] = "2"
os.environ["PRINT_MATCH_STATS"] = "0"
from tinygrad import Tensor
from tinygrad import Tensor, dtypes
from tinygrad.engine.realize import lower_schedule
from tinygrad.dtype import PtrDType
from tinygrad.helpers import Context, all_same, getenv
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps, track_rewrites
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import Context, all_same, DEBUG, getenv
from tinygrad.codegen.uopgraph import sym, devectorize, float4_folding
from test.external.process_replay.helpers import print_diff
from viz.serve import reconstruct_graph, uop_to_json, load_kernels
from viz.spec import GraphRewriteMetadata
from viz.serve import GraphRewriteMetadata, get_metadata, get_details, _uop_to_json

def group_rewrites(kernels:List[GraphRewriteMetadata]):
  """Group rewrite metadata by call-site location.

  Returns a dict mapping each distinct `loc` to the list of entries that share it. Keys are in
  first-seen order and entries keep their input order.

  Entries are collected in a dict rather than with itertools.groupby: groupby only merges *adjacent*
  items, so on unsorted input a later run with the same `loc` would overwrite the earlier one and
  entries would be dropped without notice.
  """
  groups = {}
  for x in kernels: groups.setdefault(x.loc, []).append(x)
  return groups

Expand All @@ -22,7 +20,7 @@ def tearDown(self) -> None:

def assert_valid_ctx(self, contexts:List[Tuple[Any,List[TrackedRewriteContext]]]):
assert len(contexts) != 0
load_kernels(contexts)
get_metadata(contexts)

def assert_valid_graph(self, t):
contexts.clear()
Expand All @@ -41,10 +39,10 @@ def test_ctx_groups(self):
schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
list(lower_schedule(schedule1))
list(lower_schedule(schedule2))
with Context(TRACK_MATCH_STATS=0): ret = list(load_kernels(contexts).values())
with Context(TRACK_MATCH_STATS=0): ret = get_metadata(contexts)
assert len(ret) == 3
assert all(len([x for x,_,_ in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
assert all(len([x for x,_,_ in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])
assert all(len([x for _,_,x in y if "schedule" in x.loc[0]]) == 0 for y in ret[1:])
assert all(len([x for _,_,x in y if "uopgraph" in x.loc[0]]) != 0 for y in ret[1:])

def test_gemm_diff(self):
x = Tensor.empty(64, 64).realize()
Expand All @@ -64,10 +62,10 @@ def test_removed_node(self):
@track_rewrites
def f(k): return graph_rewrite(sink, pm)
ret = f("test_rewrite")
if DEBUG >= 4: print_diff(sink, ret)
graphs,_,_ = reconstruct_graph(contexts[0][1][0])
assert graphs[-1].key == ret.key
self.assert_valid_ctx(contexts)
args = get_metadata(contexts)[0][0]
g = get_details(*args)
assert g.graphs[-1] == _uop_to_json(ret)

def test_devectorize_viz(self):
sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=1, upcasted=1, dont_use_locals=False), src=(
Expand Down Expand Up @@ -96,8 +94,7 @@ def test_devectorize_viz(self):
pm = sym+(devectorize+float4_folding)
@track_rewrites
def f(k): return graph_rewrite(sink, pm)
new_sink = f("test_rewrite")
if DEBUG >= 4: print_diff(sink, new_sink, unified=0)
f("test_rewrite")
self.assert_valid_ctx(contexts)
assert all(ctx.loc[0].split("/")[-1] == __file__.split("/")[-1] for _,ctxs in contexts for ctx in ctxs)

Expand All @@ -111,9 +108,9 @@ def test_dedup_ast(self):
a = Tensor.empty(4, 4).contiguous().realize()+2
b = Tensor.empty(4, 4).contiguous().realize()+2
Tensor.schedule(a, b)
with Context(TRACK_MATCH_STATS=0): kernels = load_kernels(contexts)
with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)
self.assertEqual(len(kernels), 1)
rewrites = [x[0] for x in list(kernels.values())[0]]
rewrites = [x[2] for x in kernels[0]]
assert all(len(v) == 1 for k,v in group_rewrites(rewrites).items() if "schedule.py" in k)

def test_no_dedup_different_opts(self):
Expand All @@ -122,23 +119,23 @@ def test_no_dedup_different_opts(self):
s = a.schedule()
with Context(NOOPT=1): list(lower_schedule(s.copy()))
with Context(NOOPT=0): list(lower_schedule(s.copy()))
with Context(TRACK_MATCH_STATS=0): kernels = list(load_kernels(contexts).values())[1:]
with Context(TRACK_MATCH_STATS=0): kernels = get_metadata(contexts)[1:]
self.assertEqual(len(kernels), 2)
rewrites = [x[0] for x in kernels[0]]
rewrites = [x[2] for x in kernels[0]]
assert all(len(v) == 1 for _,v in group_rewrites(rewrites).items())

def test_fold_const_nodes(self):
a = Tensor.empty(4, 4)+2
contexts.clear()
sink = a.schedule()[-1].ast
ret = uop_to_json(sink)
ret = _uop_to_json(sink)
assert not any(v[0].startswith("CONST") for v in ret.values())
assert len([x for x in ret.values() if "CONST" in x[0]]) == 1

@unittest.skip("VIZ for a single CONST isn't supported anymore")
def test_no_fold_single_const(self):
node = UOp(UOps.CONST, dtypes.float, (), 1.0)
ret = uop_to_json(node, base=node)
ret = _uop_to_json(node, base=node)
assert len(ret) == 1

if __name__ == "__main__":
Expand Down

0 comments on commit 0ecc417

Please sign in to comment.