
Commit

things that are only used in one place don't belong in helpers [pr] (tinygrad#6878)

* things that are only used in one place don't belong in helpers [pr]

* pretty print moved
geohot authored Oct 4, 2024
1 parent f4ec39f commit cdff1d7
Showing 9 changed files with 81 additions and 79 deletions.
4 changes: 2 additions & 2 deletions extra/ops.py
@@ -3,8 +3,8 @@
import functools, hashlib
from enum import Enum, auto
from dataclasses import dataclass
-from tinygrad.helpers import dedup, pretty_print, prod
-from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps
+from tinygrad.helpers import dedup, prod
+from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps, pretty_print
from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
4 changes: 2 additions & 2 deletions test/test_profiler.py
@@ -1,7 +1,7 @@
import unittest, struct, contextlib, tempfile, pathlib, json, time, atexit, random
from tinygrad import Device, Tensor, dtypes, TinyJit
-from tinygrad.helpers import CI, getenv, Context, ProfileLogger
-from tinygrad.device import Buffer, BufferOptions, HCQCompiled
+from tinygrad.helpers import CI, getenv, Context
+from tinygrad.device import Buffer, BufferOptions, ProfileLogger, HCQCompiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner

4 changes: 3 additions & 1 deletion test/unit/test_helpers.py
@@ -1,7 +1,9 @@
import gzip, unittest
from PIL import Image
from tinygrad.helpers import Context, ContextVar
-from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
+from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
+from tinygrad.tensor import get_shape
+from tinygrad.codegen.lowerer import get_contraction
from tinygrad.shape.symbolic import Variable, NumNode
import numpy as np

4 changes: 2 additions & 2 deletions tinygrad/codegen/kernel.py
@@ -10,13 +10,13 @@
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import ImageDType, PtrDType
-from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, get_contraction, to_function_name, diskcache_put
+from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
-from tinygrad.codegen.lowerer import ast_to_uop
+from tinygrad.codegen.lowerer import ast_to_uop, get_contraction

class OptOps(Enum):
TC = auto(); UPCAST = auto(); UPCASTMID = auto(); UNROLL = auto(); LOCAL = auto() # noqa: E702
11 changes: 9 additions & 2 deletions tinygrad/codegen/lowerer.py
@@ -1,14 +1,21 @@
# the job of the lowerer is to do indexing
from __future__ import annotations
-import functools
+import functools, itertools, operator
from dataclasses import dataclass
from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
from tinygrad.shape.symbolic import sint
from tinygrad.dtype import dtypes
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve
from tinygrad.renderer import Renderer
-from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten
+from tinygrad.helpers import all_int, prod, partition, flatten

+# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
+def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
+  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
+  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
+  except ValueError: return None
+  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
+
# ***** indexing *****

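For intuition, a hypothetical check of what the relocated get_contraction computes: when new_shape can be formed by merging adjacent axes of old_shape, it returns which old axes fold into each new axis, and None otherwise (the shapes below are illustrative, not from this diff).

from tinygrad.codegen.lowerer import get_contraction

# (2,3,4) -> (6,4): new axis 0 merges old axes 0 and 1; new axis 1 is old axis 2
assert get_contraction((2, 3, 4), (6, 4)) == [[0, 1], [2]]
# a size-1 axis in new_shape maps to an empty group of old axes
assert get_contraction((4,), (1, 4)) == [[], [0]]
# (2,3) cannot be regrouped into (3,2) without reordering elements
assert get_contraction((2, 3), (3, 2)) is None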
44 changes: 41 additions & 3 deletions tinygrad/device.py
@@ -1,10 +1,10 @@
from __future__ import annotations
-import multiprocessing, decimal, statistics, random
+import multiprocessing, decimal, statistics, random, json
from dataclasses import dataclass, replace
from collections import defaultdict
-from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Iterator
+from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Iterator, Union
import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
-from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
+from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILEPATH, PROFILE
from tinygrad.dtype import DType, ImageDType
from tinygrad.renderer import Renderer

@@ -495,6 +495,44 @@ def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), loca
if wait: self.device.timeline_signal.wait(self.device.timeline_value - 1)
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None

+class ProfileLogger:
+  writers: int = 0
+  mjson: List[Dict] = []
+  actors: Dict[Union[str, Tuple[str, str]], int] = {}
+
+  def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
+
+  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
+
+  def _ensure_actor(self, actor_name, subactor_name):
+    if actor_name not in self.actors:
+      self.actors[actor_name] = (pid:=len(self.actors))
+      self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
+
+    if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
+      self.actors[subactor_key] = (tid:=len(self.actors))
+      self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
+
+    return self.actors[actor_name], self.actors.get(subactor_key, -1)
+
+  def __del__(self):
+    # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
+    for name, st, et, actor_name, subactor_name, args in self.events:
+      pid, tid = self._ensure_actor(actor_name,subactor_name)
+      args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
+      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
+
+    for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
+      dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
+      pid, tid = self._ensure_actor(actor_name,subactor_name)
+      self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
+      self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
+
+    ProfileLogger.writers -= 1
+    if ProfileLogger.writers == 0 and len(self.mjson) > 0:
+      with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
+      print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
+
class HCQCompiled(Compiled):
"""
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
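A rough usage sketch for the relocated ProfileLogger (the event names and timestamps below are invented; real callers are gated by PROFILE and write to PROFILEPATH): each instance buffers events, and the class flushes one merged Chrome-trace JSON when the last live instance is finalized.

from tinygrad.device import ProfileLogger

logger = ProfileLogger()
# actor/subactor become perfetto process/thread rows; ts/dur land in the trace as-is
logger.add_event("E_kernel", 1000, 1250, actor="DEV:0", subactor="compute")
logger.add_event("copy_in", 1250, 1400, actor="DEV:0", subactor="dma")
del logger  # last writer gone (CPython refcount): dumps {"traceEvents": [...]} to PROFILEPATH.value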
66 changes: 1 addition & 65 deletions tinygrad/helpers.py
@@ -1,11 +1,10 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
-import itertools, urllib.request, subprocess, shutil, math, json, contextvars, types, copyreg, inspect, importlib
+import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
from dataclasses import dataclass
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
-from tinygrad.shape.shapetracker import sint

T = TypeVar("T")
U = TypeVar("U")
@@ -68,21 +67,6 @@ def get_child(obj, key):
else: obj = getattr(obj, k)
return obj

-def get_shape(x) -> Tuple[int, ...]:
-  if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str): return ()
-  if (aapi := (hasattr(x, "shape") and x.shape == ())): return ()
-  subs = [get_shape(xi) for xi in x]
-  if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
-  slen = 1 if aapi else len(subs)
-  return (slen,) + (subs[0] if subs else ())
-
-# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
-def get_contraction(old_shape:Tuple[sint, ...], new_shape:Tuple[sint, ...]) -> Optional[List[List[int]]]:
-  acc_old, acc_new = list(itertools.accumulate(old_shape, operator.mul)), list(itertools.accumulate(new_shape, operator.mul))
-  try: split = [acc_old.index(acc)+1 if acc != 1 else 0 for acc in acc_new]
-  except ValueError: return None
-  return [list(range(st,ed)) for st,ed in zip([0]+split[:-1], split[:-1]+[len(old_shape)])]
-
@functools.lru_cache(maxsize=None)
def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)])
@functools.lru_cache(maxsize=None)
@@ -171,44 +155,6 @@ def __exit__(self, *exc):
colored(_format_fcn(fcn).ljust(50), "yellow"),
colored(f"<- {(scallers[0][1][2]/tottime)*100:3.0f}% {_format_fcn(scallers[0][0])}", "BLACK") if scallers else '')

-class ProfileLogger:
-  writers: int = 0
-  mjson: List[Dict] = []
-  actors: Dict[Union[str, Tuple[str, str]], int] = {}
-
-  def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
-
-  def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
-
-  def _ensure_actor(self, actor_name, subactor_name):
-    if actor_name not in self.actors:
-      self.actors[actor_name] = (pid:=len(self.actors))
-      self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
-
-    if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
-      self.actors[subactor_key] = (tid:=len(self.actors))
-      self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
-
-    return self.actors[actor_name], self.actors.get(subactor_key, -1)
-
-  def __del__(self):
-    # perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
-    for name, st, et, actor_name, subactor_name, args in self.events:
-      pid, tid = self._ensure_actor(actor_name,subactor_name)
-      args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
-      self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
-
-    for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
-      dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
-      pid, tid = self._ensure_actor(actor_name,subactor_name)
-      self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
-      self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
-
-    ProfileLogger.writers -= 1
-    if ProfileLogger.writers == 0 and len(self.mjson) > 0:
-      with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
-      print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
-
# *** universal database cache ***

_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
@@ -363,16 +309,6 @@ def SI(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))
class trange(tqdm):
def __init__(self, n:int, **kwargs): super().__init__(iterable=range(n), total=n, **kwargs)

-def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
-  def dfs(x:Any, cache:dict):
-    for s in srcfn(x) or []:
-      cache.setdefault(s, [len(cache), 0, False])[1] += 1
-      if cache[s][1] == 1: dfs(s, cache)
-  if cache is None: dfs(x, cache:={})
-  if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
-  cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
-  return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
-
# *** universal support for code object pickling

def _reconstruct_code(*args): return types.CodeType(*args)
13 changes: 12 additions & 1 deletion tinygrad/ops.py
@@ -6,7 +6,7 @@
from dataclasses import dataclass, field
from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
-from tinygrad.helpers import ContextVar, pretty_print, prod, getenv, all_same
+from tinygrad.helpers import ContextVar, prod, getenv, all_same
if TYPE_CHECKING:
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker
@@ -159,6 +159,17 @@ def resolve(x, default:bool=True):
def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.vmax)
def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop

+# used for UOp and UPat
+def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->str:
+  def dfs(x:Any, cache:dict):
+    for s in srcfn(x) or []:
+      cache.setdefault(s, [len(cache), 0, False])[1] += 1
+      if cache[s][1] == 1: dfs(s, cache)
+  if cache is None: dfs(x, cache:={})
+  if (cx:=cache.setdefault(x, [0,0,False]))[2]: return f"{' '*d} x{cx[0]}"
+  cx[2], srcs = True, ('None' if srcfn(x) is None else ''.join(f'\n{pretty_print(s, rep, srcfn, cache, d+2)},' for s in srcfn(x)))
+  return f"{' '*d}{f'x{cx[0]}:=' * (cx[1]>1)}{rep(x)}" % srcs
+
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
class UOp(MathTrait):
def __reduce__(self): return UOp, (self.op, self.dtype, self.src, self.arg)
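To see the pretty_print contract without constructing UOps, a toy sketch (the Node class is invented for illustration): rep must return a string with a %s slot where the rendered children are spliced in, and any node referenced more than once is printed once and bound to an x<n>:= label.

from tinygrad.ops import pretty_print

class Node:
  def __init__(self, name, src=()): self.name, self.src = name, src

a = Node("a")
root = Node("root", (Node("b", (a, a)),))
print(pretty_print(root, lambda n: f"Node({n.name!r}, src=(%s))"))
# Node('root', src=(
#   Node('b', src=(
#     x1:=Node('a', src=()),
#      x1,)),))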
10 changes: 9 additions & 1 deletion tinygrad/tensor.py
@@ -6,7 +6,7 @@
from collections import defaultdict

from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
-from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
+from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
@@ -57,6 +57,14 @@ def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noq
del ret.srcs
return ret

+def get_shape(x) -> Tuple[int, ...]:
+  if not hasattr(x, "__len__") or not hasattr(x, "__getitem__") or isinstance(x, str): return ()
+  if (aapi := (hasattr(x, "shape") and x.shape == ())): return ()
+  subs = [get_shape(xi) for xi in x]
+  if not all_same(subs): raise ValueError(f"inhomogeneous shape from {x}")
+  slen = 1 if aapi else len(subs)
+  return (slen,) + (subs[0] if subs else ())
+
def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x
else:
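A quick hypothetical check of the relocated get_shape: it walks nested Python sequences, treats scalars and strings as shape (), and rejects ragged input.

from tinygrad.tensor import get_shape

assert get_shape([[1, 2, 3], [4, 5, 6]]) == (2, 3)
assert get_shape(42) == ()  # scalars (and strings) have shape ()
try: get_shape([[1, 2], [3]])
except ValueError as e: print(e)  # inhomogeneous shape from [[1, 2], [3]]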
