Skip to content

Commit

Permalink
add gated read_image count in openpilot compile2 (tinygrad#6546)
Browse files Browse the repository at this point in the history
530 to go
  • Loading branch information
chenyuxyz authored Sep 17, 2024
1 parent 665b420 commit 798be6b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model compile and size
run: |
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=530 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model correctness (float32)
Expand Down
6 changes: 6 additions & 0 deletions examples/openpilot/compile2.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def to_ref(b:Buffer): return struct.pack("Q", id(b)).decode("latin_1")

saved_binaries = set()
binaries = []
gated_read_image_count = 0
GlobalCounters.reset()
with Context(DEBUG=max(DEBUG.value, 2)):
for ei in eis:
Expand All @@ -173,6 +174,7 @@ def to_ref(b:Buffer): return struct.pack("Q", id(b)).decode("latin_1")
jdat['binaries'].append({"name":prg.p.function_name, "length":len(prg.lib)})
binaries.append(prg.lib)
saved_binaries.add(prg.p.function_name)
gated_read_image_count += prg.p.src.count("?read_image")
ei.run()
jdat['kernels'].append({
"name": prg.p.function_name,
Expand All @@ -184,6 +186,10 @@ def to_ref(b:Buffer): return struct.pack("Q", id(b)).decode("latin_1")
"arg_size": [8]*len(ei.bufs),
})

if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", 0)):
assert gated_read_image_count <= allowed_gated_read_image, \
f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"

output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
print(f"saving thneed to {output_fn} with {len(weights)} buffers and {len(binaries)} binaries")
with open(output_fn, "wb") as f:
Expand Down

0 comments on commit 798be6b

Please sign in to comment.