Skip to content

Commit

Permalink
move checks into compile3, delete compile2 [pr] (tinygrad#8127)
Browse files Browse the repository at this point in the history
* move checks into compile3 [pr]

* test_vs_onnx

* test v torch works

* float16 won't compile on compile3

* actually delete compile2
  • Loading branch information
geohot authored Dec 9, 2024
1 parent 3582879 commit f83d715
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 242 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,6 @@ jobs:
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: openpilot compile 0.9.4
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python examples/openpilot/compile2.py | tee openpilot_compile_0_9_4.txt
- name: openpilot compile 0.9.7
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx | tee openpilot_compile_0_9_7.txt
- name: validate openpilot 0.9.7
run: PYTHONPATH=. FLOAT16=0 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx | tee openpilot_image_0_9_7.txt
- name: benchmark openpilot 0.9.4
Expand Down
15 changes: 4 additions & 11 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -293,22 +293,15 @@ jobs:
PYTHONPATH="." GPU=1 IMAGE=2 python -m pytest -n=auto test/test_ops.py --durations=20
PYTHONPATH="." GPU=1 IMAGE=2 python3 test/models/test_end2end.py TestEnd2End.test_linear_mnist
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model compile and size
name: Test openpilot model kernel count and gate usage
run: |
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot compile3
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot alt model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot fastvits model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'onnx' }}
name: Test ONNX (GPU)
run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
Expand Down
211 changes: 0 additions & 211 deletions examples/openpilot/compile2.py

This file was deleted.

65 changes: 51 additions & 14 deletions examples/openpilot/compile3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device
from tinygrad.helpers import DEBUG, getenv
from tinygrad.tensor import _from_np_dtype
from tinygrad.engine.realize import CompiledRunner

import onnx
from onnx.helper import tensor_dtype_to_np_dtype
Expand All @@ -16,12 +17,11 @@
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
OUTPUT = "/tmp/openpilot.pkl"

def compile():
def compile(onnx_file):
onnx_model = onnx.load(onnx_file)
Tensor.no_grad = True
Tensor.training = False

onnx_bytes = fetch(OPENPILOT_MODEL)
onnx_model = onnx.load(onnx_bytes)
run_onnx = get_run_onnx(onnx_model)
print("loaded model")

Expand All @@ -48,23 +48,29 @@ def compile():
np.testing.assert_equal(test_val, ret, "JIT run failed")
print("jit run validated")

# checks from compile2
kernel_count = 0
gated_read_image_count = 0
for ei in run_onnx_jit.captured.jit_cache:
if isinstance(ei.prg, CompiledRunner):
kernel_count += 1
gated_read_image_count += ei.prg.p.src.count("?read_image")
print(f"kernel_count: {kernel_count} gated_read_image_count: {gated_read_image_count}")
assert kernel_count <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"
if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
assert gated_read_image_count <= allowed_gated_read_image, \
f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"

with open(OUTPUT, "wb") as f:
pickle.dump(run_onnx_jit, f)
mdl_sz = os.path.getsize(onnx_bytes)
mdl_sz = os.path.getsize(onnx_file)
pkl_sz = os.path.getsize(OUTPUT)
print(f"mdl size is {mdl_sz/1e6:.2f}M")
print(f"pkl size is {pkl_sz/1e6:.2f}M")
print("**** compile done ****")
return test_val

def test(test_val=None):
with open(OUTPUT, "rb") as f:
run = pickle.load(f)

# same randomness as above
Tensor.manual_seed(100)
new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
sorted(zip(run.captured.expected_names, run.captured.expected_st_vars_dtype_device))}
def test_vs_compile(run, new_inputs, test_val=None):
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}

# create fake "from_blob" tensors for the inputs, and wrapped NPY tensors for the numpy inputs (these have the same underlying memory)
Expand All @@ -88,8 +94,39 @@ def test(test_val=None):
out = run(**inputs)
changed_val = out.numpy()
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
return val

def test_vs_onnx(new_inputs, test_val, onnx_file):
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
onnx_model = onnx.load(onnx_file)

if getenv("ORT"):
# test with onnxruntime
import onnxruntime as ort
onnx_session = ort.InferenceSession(onnx_file)
onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
new_torch_out = onnx_output[0]
print("got ort outputs")
else:
# test with torch
from test.models.test_onnx import run_onnx_torch
# NOTE: we have to correct the order here
new_torch_out = run_onnx_torch(onnx_model, {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}).numpy()
print("got torch outputs")

np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
print("test vs onnx passed")

if __name__ == "__main__":
test_val = compile() if not getenv("RUN") else None
test(test_val)
onnx_file = fetch(OPENPILOT_MODEL)
test_val = compile(onnx_file) if not getenv("RUN") else None

with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)

# same randomness as compile
Tensor.manual_seed(100)
new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))}

test_val = test_vs_compile(pickle_loaded, new_inputs, test_val)
if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file)
2 changes: 0 additions & 2 deletions examples/openpilot/go.sh

This file was deleted.

0 comments on commit f83d715

Please sign in to comment.