prune independent kernels for openpilot [run_process_replay] (tinygrad#6569)

* prune independent kernels for openpilot [run_process_replay]

* new pruning

* prune first, then memory plan
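
The rule the new pruning pass implements, as a minimal standalone sketch (plain-Python stand-ins, not tinygrad's ExecItem/Buffer types; prune_kernels and the (reads, writes) representation here are illustrative only): starting from the JIT input buffers, walk the captured kernels in order and mark every buffer written by a kernel that touches an already-dependent buffer; kernels whose outputs never become dependent are constant with respect to the inputs, so they can run once at capture time and be dropped from the replay cache.

# illustrative sketch of the pruning rule, not tinygrad API
def prune_kernels(kernels, input_bufs):
  # kernels: list of (reads, writes) sets of hashable buffer ids, in execution order
  depends = set(input_bufs)
  for reads, writes in kernels:
    if depends & reads: depends |= writes  # outputs of input-dependent kernels propagate
  kept    = [(r, w) for r, w in kernels if w & depends]        # must re-run on every replay
  onetime = [(r, w) for r, w in kernels if not (w & depends)]  # constant w.r.t. JIT inputs
  return kept, onetime

# the middle kernel only reads a weight buffer, so it is hoisted out of the replay
kernels = [({"x", "w1"}, {"a"}), ({"w2"}, {"c"}), ({"a", "c"}, {"y"})]
kept, onetime = prune_kernels(kernels, input_bufs={"x"})
assert len(kept) == 2 and len(onetime) == 1
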
geohot authored Sep 17, 2024
1 parent 9295bc0 commit 28e565d
Showing 4 changed files with 31 additions and 11 deletions.
13 changes: 7 additions & 6 deletions test/external/external_benchmark_openpilot.py
@@ -21,20 +21,21 @@
   input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
   input_types = {inp.name: tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input}
   new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in input_shapes.items()}
-  new_inputs_np = {k:inp.numpy() for k,inp in new_inputs.items()}
+  new_inputs_junk = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in input_shapes.items()}

   # benchmark
   for _ in range(5):
     GlobalCounters.reset()
     st = time.perf_counter_ns()
-    ret = next(iter(run_onnx(new_inputs).values())).cast(dtypes.float32).numpy()
+    ret = next(iter(run_onnx(new_inputs_junk).values())).cast(dtypes.float32).numpy()
     print(f"unjitted: {(time.perf_counter_ns() - st)*1e-6:7.4f} ms")

-  run_onnx_jit = TinyJit(run_onnx)
+  # NOTE: the inputs to a JIT must be first level arguments
+  run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
   for _ in range(10):
     GlobalCounters.reset()
     st = time.perf_counter_ns()
-    ret = next(iter(run_onnx_jit(new_inputs).values())).cast(dtypes.float32).numpy()
+    ret = next(iter(run_onnx_jit(**new_inputs_junk).values())).cast(dtypes.float32).numpy()
     print(f"jitted: {(time.perf_counter_ns() - st)*1e-6:7.4f} ms")

   suffix = ""
@@ -43,14 +44,14 @@
   path = Path(__file__).parent / "openpilot" / f"{hashlib.md5(OPENPILOT_MODEL.encode()).hexdigest()}{suffix}.npy"

   # validate if we have records
-  tinygrad_out = next(iter(run_onnx_jit(new_inputs).values())).cast(dtypes.float32).numpy()
+  tinygrad_out = next(iter(run_onnx_jit(**new_inputs).values())).cast(dtypes.float32).numpy()
   if getenv("SAVE_OUTPUT"):
     np.save(path, tinygrad_out)
     print(f"saved output to {path}!")
   elif getenv("FUZZ") and path.exists():
     known_good_out = np.load(path)
     for _ in trange(1000):
-      ret = next(iter(run_onnx_jit(new_inputs).values())).cast(dtypes.float32).numpy()
+      ret = next(iter(run_onnx_jit(**new_inputs).values())).cast(dtypes.float32).numpy()
       np.testing.assert_allclose(known_good_out, ret, atol=1e-2, rtol=1e-2)
     print(colored("fuzz validated!", "green"))
   elif path.exists():
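
A note on the benchmark change above: as far as the capture logic goes, TinyJit discovers its input buffers by scanning the Tensors passed directly as positional and keyword arguments, so a dict of Tensors handed over as one argument would not be tracked — hence the NOTE about first level arguments and the lambda that unpacks kwargs. A sketch of the before/after calling convention, reusing run_onnx and new_inputs_junk from the file above:

# before: the whole dict is a single positional argument, so its Tensors are
# invisible to the JIT's input scan and updates to them would not reach the replay
# run_onnx_jit = TinyJit(run_onnx); run_onnx_jit(new_inputs_junk)

# after: every Tensor is a first-level kwarg the JIT can track, and prune=True
# drops kernels that don't depend on any of them from the replay cache
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
out = next(iter(run_onnx_jit(**new_inputs_junk).values()))
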
24 changes: 21 additions & 3 deletions tinygrad/engine/jit.py
@@ -3,7 +3,7 @@
 import functools, itertools, collections
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
-from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, colored, JIT, dedup
+from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, colored, JIT, dedup, partition
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -183,11 +183,12 @@ def _prepare_jit_inputs(args, kwargs):
   return input_buffers, var_vals, names, st_vars_dtype_device

 class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None):
+  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None, prune=False):
     assert fxn or captured, "need either a function or a CapturedJit"
     self.fxn = fxn
     self.captured: Optional[CapturedJit] = captured
     self.cnt: int = 2 if self.fxn is None else 0
+    self.prune = prune

   def add_buffer(self, b:Buffer) -> Buffer:
     if found:=self._buffer_replace.get(b, None): return found
@@ -232,7 +233,8 @@ def __call__(self, *args, **kwargs) -> ReturnType:
       if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
       self._jit_cache: List[ExecItem] = []
       self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
-      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
+      # TODO: should we always disable the memory planner here? it must be off for prune
+      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value), NO_MEMORY_PLANNER=int(self.prune)):
         capturing.append(self)
         try:
           ret = self.fxn(*args, **kwargs)
@@ -253,6 +255,22 @@ def __call__(self, *args, **kwargs) -> ReturnType:
             input_buffers.append(b)
             extra_view_inputs.append((input_buffers.index(b.base), b.offset, b.device, b.size, b.dtype))

+      # prune independent kernels (optional)
+      if self.prune:
+        depends = set(input_buffers)
+        for ei in jit_cache:
+          if any(b in depends for b in ei.bufs):
+            if isinstance(ei.prg, CompiledRunner):
+              for out in ei.prg.p.outs: depends.add(cast(Buffer, ei.bufs[out]))
+        pruned, onetime = partition(jit_cache,
+          lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
+        if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
+        # run the onetime kernels here
+        for ei in onetime:
+          for b in ei.bufs: cast(Buffer, b).ensure_allocated()
+          ei.run(var_vals, jit=True)
+        jit_cache = pruned
+
       # memory planning (optional)
       # Exclude buffers involved in transfer ops to preserve parallelism.
       noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, BufferXfer) for b in ji.bufs}
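
Two helpers do the heavy lifting in the block above: partition from tinygrad.helpers splits an iterable into the items where a predicate holds and the rest, and ensure_allocated gives the one-time kernels real buffers to write into before they run at capture time. A quick standalone illustration of partition:

from tinygrad.helpers import partition

# partition(iterable, predicate) -> (items where predicate is true, items where it is false)
pruned, onetime = partition(range(10), lambda x: x % 2 == 0)
assert pruned == [0, 2, 4, 6, 8] and onetime == [1, 3, 5, 7, 9]
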
3 changes: 2 additions & 1 deletion tinygrad/engine/realize.py
@@ -3,6 +3,7 @@
 from collections import defaultdict
 from dataclasses import dataclass, replace
 from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
+from tinygrad.helpers import NO_MEMORY_PLANNER
 from tinygrad.ops import MetaOps, UOps, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.device import Device, Buffer
@@ -225,7 +226,7 @@ def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, i
 # **************** memory planning ****************

 def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
-  if getenv("NO_MEMORY_PLANNER"): return {}
+  if NO_MEMORY_PLANNER: return {}
   first_appearance, last_appearance = {}, {}
   for i,u in enumerate(buffers):
     for buf in u:
2 changes: 1 addition & 1 deletion tinygrad/helpers.py
@@ -110,7 +110,7 @@ def __lt__(self, x): return self.value < x
 MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
 USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
 FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
-SPLIT_REDUCEOP, AST_REWRITE = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("AST_REWRITE", 1)
+SPLIT_REDUCEOP, AST_REWRITE, NO_MEMORY_PLANNER = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("AST_REWRITE", 1), ContextVar("NO_MEMORY_PLANNER", 0)

 @dataclass(frozen=True)
 class Metadata:
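
Since NO_MEMORY_PLANNER is now a ContextVar rather than a raw getenv lookup, it can be set globally through the environment or flipped temporarily with Context, which is how TinyJit capture disables the planner when prune=True (the TODO in jit.py notes it must be off for prune). A small sketch of both paths:

from tinygrad.helpers import Context, NO_MEMORY_PLANNER

# globally: NO_MEMORY_PLANNER=1 python test/external/external_benchmark_openpilot.py
print(NO_MEMORY_PLANNER.value)      # 0 unless overridden in the environment

# temporarily, the way TinyJit capture now does it when prune=True:
with Context(NO_MEMORY_PLANNER=1):
  print(NO_MEMORY_PLANNER.value)    # 1 -> _internal_memory_planner returns {} and plans nothing
print(NO_MEMORY_PLANNER.value)      # restored when the context exits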
