Skip to content

Commit

Permalink
VIZ prep for the new kernel render (tinygrad#6800)
Browse files Browse the repository at this point in the history
* refactor to list

* remove prints in test_viz

* more cleanup
  • Loading branch information
Qazalin authored Sep 29, 2024
1 parent 01c9653 commit 3c15e64
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 24 deletions.
14 changes: 7 additions & 7 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,13 +533,13 @@ def print_match_stats():
if getenv("VIZ"):
os.environ["VIZ"] = "0"
os.execv(sys.executable, [sys.executable] + [os.path.join(os.path.dirname(__file__), "..", "viz", "serve.py")])
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
ret = [x+y for x,y in zip(ret, v)]
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")

if getenv("PRINT_MATCH_STATS", 1):
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
ret = [x+y for x,y in zip(ret, v)]
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")

# *** simple graph rewrite engine ***

Expand Down
22 changes: 10 additions & 12 deletions viz/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
# **** /graph - detailed UOp + rewrites

# NOTE: UPats in ops.py are spec
# TODO: fix key for uop with buffer
def graph_rewrites(ctx:TrackedRewriteContext):
return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py" and not ("schedule" in ctx.loc[0] and "DEFINE_GLOBAL" in str(x[2]))]
return [x for x in ctx.rewrites if x[2].location[0].split("/")[-1] != "ops.py"]

@dataclass(frozen=True)
class RewriteLocation:
Expand Down Expand Up @@ -61,13 +60,13 @@ def from_ctx(ctx:TrackedRewriteContext) -> UOpRet:
extra.append([str(new_sink)])
return UOpRet(RewriteLocation.from_ctx(ctx), uops, diffs, extra, additions)
def to_json(self) -> Dict:
return {**asdict(self), "loc":self.loc.to_json(), "graphs": list(map(lambda x:uop_to_json(x, self.graphs[0]), self.graphs))}
return {**asdict(self), "loc":self.loc.to_json(), "graphs": list(map(uop_to_json, self.graphs))}

def uop_to_json(x:UOp, base:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
for u in x.sparents:
if u.op is UOps.CONST and u is not base: continue
if u.op is UOps.CONST: continue
label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
for idx,x in enumerate(u.src):
if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}"
Expand All @@ -89,9 +88,8 @@ def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
class KernelRet:
name: str
code: str
ctxs: Dict[Tuple[Tuple, bytes], TrackedRewriteContext]
def to_json(self) -> Dict:
return {"name":self.name, "code":self.code, "ctxs":[RewriteLocation.from_ctx(x).to_json() for x in self.ctxs.values()]}
ctxs: List[TrackedRewriteContext]
def to_json(self) -> Dict: return {"name":self.name, "code":self.code, "ctxs":[RewriteLocation.from_ctx(x).to_json() for x in self.ctxs]}

def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
ret: Dict[str, KernelRet] = {}
Expand All @@ -102,8 +100,8 @@ def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
si_ctx = ScheduleItemContext(bufs=tuple(x.arg for x in ctx.sink.sparents if x.op is UOps.BUFFER))
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink, si_ctx)).p).name, prg.src
elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, [])
ret[k].ctxs.append(ctx)
return list(ret.values())

class Handler(BaseHTTPRequestHandler):
Expand Down Expand Up @@ -131,8 +129,8 @@ def do_GET(self):
self.send_header("Content-type", "application/json")
self.end_headers()
k = kernels[int(query["kernel_idx"][0])]
g = UOpRet.from_ctx(list(k.ctxs.values())[int(query["uop_idx"][0])])
ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs.values()])).encode()
g = UOpRet.from_ctx(k.ctxs[int(query["uop_idx"][0])])
ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs])).encode()
else:
self.send_response(404)
ret = b""
Expand Down
11 changes: 6 additions & 5 deletions viz/test_viz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
import os, itertools
os.environ["TRACK_MATCH_STATS"] = "2"
os.environ["PRINT_MATCH_STATS"] = "0"
from extra.models.resnet import ResNet50
from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule
Expand All @@ -11,7 +12,7 @@
from test.external.process_replay.helpers import print_diff
from viz.serve import KernelRet, UOpRet, load_kernels, uop_to_json

def group_rewrites(kernels:KernelRet): return {k:list(v) for k,v in itertools.groupby(kernels.ctxs.values(), lambda x:x.loc)}
def group_rewrites(kernels:KernelRet): return {k:list(v) for k,v in itertools.groupby(kernels.ctxs, lambda x:x.loc)}

class TestViz(unittest.TestCase):
def tearDown(self) -> None:
Expand Down Expand Up @@ -48,8 +49,8 @@ def test_ctx_groups(self):
list(lower_schedule(schedule2))
ret = load_kernels(contexts)
assert len(ret) == 2
assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc[0]]) != 0 for y in ret)
assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc[0]]) != 0 for y in ret)
assert all(len([x for x in y.ctxs if "schedule" in x.loc[0]]) != 0 for y in ret)
assert all(len([x for x in y.ctxs if "uopgraph" in x.loc[0]]) != 0 for y in ret)

def test_gemm_diff(self):
x = Tensor.empty(64, 64).realize()
Expand Down Expand Up @@ -140,11 +141,11 @@ def test_fold_const_nodes(self):
a = Tensor.empty(4, 4)+2
contexts.clear()
sink = a.schedule()[-1].ast
ret = uop_to_json(sink, base=sink)
for v in ret.values(): print(v)
ret = uop_to_json(sink)
assert not any(v[0].startswith("CONST") for v in ret.values())
assert len([x for x in ret.values() if "CONST" in x[0]]) == 1

@unittest.skip("VIZ for a single CONST isn't supported anymore")
def test_no_fold_single_const(self):
node = UOp(UOps.CONST, dtypes.float, (), 1.0)
ret = uop_to_json(node, base=node)
Expand Down

0 comments on commit 3c15e64

Please sign in to comment.