Skip to content

Commit

Permalink
viz touchups [pr] (tinygrad#7095)
Browse files Browse the repository at this point in the history
* viz touchups [pr]

* check if port busy

* url
  • Loading branch information
Qazalin authored Oct 16, 2024
1 parent 6172b42 commit 568a4b5
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 31 deletions.
9 changes: 0 additions & 9 deletions tinygrad/viz/eslint.config.mjs

This file was deleted.

31 changes: 13 additions & 18 deletions tinygrad/viz/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
}
.node rect {
stroke: #4a4b57;
stroke-width: 1.5px;
stroke-width: 1.4px;
rx: 8px;
ry: 8px;
}
Expand All @@ -81,29 +81,23 @@
width: 100%;
height: 100%;
}
.main-container > * + * {
margin-left: 8px;
}
.container {
background-color: #0f1018;
height: 100%;
border-radius: 8px;
padding: 8px;
position: relative;
}
.container > * + * {
.container > * + *, .rewrite-container > * + * {
margin-top: 12px;
}
.rewrite-container > * + * {
margin-top: 12px;
}
.main-container > * + * {
margin-left: 8px;
}
.kernel-list {
width: 10%;
overflow-y: auto;
}
.kernel-list > ul > ul {
padding-left: 4px;
}
.kernel-list > ul > * + * {
margin-top: 4px;
}
Expand Down Expand Up @@ -153,7 +147,8 @@
</div>
<script>
/* global hljs, dagreD3, d3, DOMPurify */
// extra definitions for UOps

// **** hljs extra definitions for UOps and float4
hljs.registerLanguage("python", (hljs) => ({
...hljs.getLanguage("python"),
case_insensitive: false,
Expand All @@ -165,12 +160,12 @@
...hljs.getLanguage("python").contains,
]
}));
// extra definitions for float4
hljs.registerLanguage("cpp", (hljs) => ({
...hljs.getLanguage('cpp'),
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
}));

// **** D3
function recenterRects(svg, zoom) {
const svgBounds = svg.node().getBoundingClientRect();
for (const rect of svg.node().querySelectorAll("rect")) {
Expand All @@ -182,7 +177,6 @@
}
svg.call(zoom.transform, d3.zoomIdentity)
}

function renderGraph(graph, additions) {
const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
Expand All @@ -209,11 +203,10 @@
render(inner, g);
}

function toPath(loc) {
const [fp, lineno] = loc;
return `${fp.split("/").pop()}:${lineno}`
}
// **** extra helpers
const toPath = ([fp, lineno]) => `${fp.split("/").pop()}:${lineno}`;

// **** main loop
var ret = {};
var cache = {};
var kernels = null;
Expand Down Expand Up @@ -359,6 +352,8 @@
metadata.appendChild(Object.assign(document.createElement("p"), { textContent: `No rewrites in ${toPath(ret.loc)}.` }));
}
}

// **** keyboard shortcuts
document.addEventListener("keydown", async function(event) {
// up and down change the UOp or kernel from the list
if (!expandKernel) {
Expand Down
11 changes: 7 additions & 4 deletions tinygrad/viz/serve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -122,6 +122,9 @@ def reloader():
time.sleep(0.1)

if __name__ == "__main__":
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(((HOST:="http://127.0.0.1").replace("http://", ""), PORT:=getenv("PORT", 8000))) == 0:
raise RuntimeError(f"{HOST}:{PORT} is occupied! use PORT= to change.")
stop_reloader = threading.Event()
multiprocessing.current_process().name = "VizProcess" # disallow opening of devices
st = time.perf_counter()
Expand All @@ -133,12 +136,12 @@ def reloader():
ret = [get_details(*args) for v in tqdm(kernels) for args in v]
print(f"fuzzed {len(ret)} rewrite details")
print("*** loaded kernels")
server = HTTPServer(('', PORT:=getenv("PORT", 8000)), Handler)
server = HTTPServer(('', PORT), Handler)
reloader_thread = threading.Thread(target=reloader)
reloader_thread.start()
print(f"*** started viz on http://127.0.0.1:{PORT}")
print(f"*** started viz on {HOST}:{PORT}")
print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"))
if getenv("BROWSER", 0): webbrowser.open(f"http://127.0.0.1:{PORT}")
if getenv("BROWSER", 0): webbrowser.open(f"{HOST}:{PORT}")
try: server.serve_forever()
except KeyboardInterrupt:
print("*** viz is shutting down...")
Expand Down

0 comments on commit 568a4b5

Please sign in to comment.