Skip to content

Commit

Permalink
refactor onnxslim test and fix bugs (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
whyb authored Nov 28, 2024
1 parent 8e0105e commit 987fab1
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 68 deletions.
5 changes: 2 additions & 3 deletions onnxslim/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,13 @@ def _add_arguments(self):
arg_type = _get_inner_type(field_def.type)
default_value = field_def.default if field_def.default is not field_def.default_factory else None
help_text = field_def.metadata.get("help", "")
nargs = "+" if arg_type == list else None
nargs = "+" if get_origin(arg_type) == list else None
choices = field_def.metadata.get("choices", None)

if choices and default_value is not None and default_value not in choices:
raise ValueError(
f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}."
)

arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type
if arg_type == bool:
self.parser.add_argument(
f"--{field_name.replace('_', '-')}",
Expand Down
2 changes: 1 addition & 1 deletion onnxslim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def save(

if model_info:
model_size = model.ByteSize()
model_info["model_size"] = [model_size, model_info["model_size"]]
model_info.model_size = [model_size, model_info.model_size]


def check_result(raw_onnx_output, slimmed_onnx_output):
Expand Down
165 changes: 101 additions & 64 deletions tests/test_onnxslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,95 +2,132 @@
import subprocess
import tempfile

import numpy as np
import pytest

from onnxslim import slim
from onnxslim.utils import print_model_info_as_table, summarize_model
from onnxslim.utils import summarize_model

MODELZOO_PATH = "/data/modelzoo"
FILENAME = f"{MODELZOO_PATH}/resnet18/resnet18.onnx"


class TestFunctional:
def run_basic_test(self, in_model_path, out_model_path, **kwargs):
check_func = kwargs.get("check_func", None)
kwargs_api = kwargs.get("api", {})
kwargs_bash = kwargs.get("bash", "")
summary = summarize_model(slim(in_model_path, **kwargs_api), in_model_path)
if check_func:
check_func(summary)

slim(in_model_path, out_model_path, **kwargs_api)
summary = summarize_model(out_model_path, out_model_path)
if check_func:
check_func(summary)

command = f'onnxslim "{in_model_path}" "{out_model_path}" {kwargs_bash}'

result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0

summary = summarize_model(out_model_path, out_model_path)
if check_func:
check_func(summary)

def test_basic(self, request):
"""Test the basic functionality of the slim function."""
with tempfile.TemporaryDirectory() as tempdir:
summary = summarize_model(slim(FILENAME), request.node.name)
print_model_info_as_table(summary)
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name)
slim(FILENAME, output_name, model_check=True)

command = f"onnxslim {FILENAME} {output_name}"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0

out_model_path = os.path.join(tempdir, "resnet18.onnx")
self.run_basic_test(FILENAME, out_model_path)

class TestFeature:
def test_input_shape_modification(self, request):
"""Test the modification of input shapes."""
summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]), request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].shape == (1, 3, 224, 224)
def check_func(summary):
assert summary.input_info[0].shape == (1, 3, 224, 224)

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"])
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].shape == (1, 3, 224, 224)
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs["bash"] = "--input-shapes input:1,3,224,224"
kwargs["api"] = {"input_shapes": ["input:1,3,224,224"]}
kwargs["check_func"] = check_func
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_fp162fp32_conversion(self, request):
"""Test the conversion of an ONNX model from FP16 to FP32 precision."""
import numpy as np
def test_input_modification(self, request):
def check_func(summary):
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"], dtype="fp16")
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].dtype == np.float16
assert summary.input_info[0].shape == (1, 3, 224, 224)

slim(output_name, output_name, dtype="fp32")
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].dtype == np.float32
assert summary.input_info[0].shape == (1, 3, 224, 224)
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs["bash"] = "--inputs /maxpool/MaxPool_output_0 /layer1/layer1.0/relu/Relu_output_0"
kwargs["api"] = {"inputs": ["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]}
kwargs["check_func"] = check_func
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_output_modification(self, request):
"""Tests output modification."""
summary = summarize_model(slim(FILENAME, outputs=["/Flatten_output_0"]), request.node.name)
print_model_info_as_table(summary)
assert "/Flatten_output_0" in summary.output_maps
def check_func(summary):
assert "/Flatten_output_0" in summary.output_maps

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, outputs=["/Flatten_output_0"])
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert "/Flatten_output_0" in summary.output_maps
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs["bash"] = "--outputs /Flatten_output_0"
kwargs["api"] = {"outputs": ["/Flatten_output_0"]}
kwargs["check_func"] = check_func
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_dtype_conversion(self, request):
def check_func_fp16(summary):
assert summary.input_info[0].dtype == np.float16

def test_input_modification(self, request):
"""Tests input modification."""
summary = summarize_model(
slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]),
request.node.name,
)
print_model_info_as_table(summary)
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps
def check_func_fp32(summary):
assert summary.input_info[0].dtype == np.float32

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"])
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps
out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx")
kwargs = {}
kwargs["bash"] = "--dtype fp16"
kwargs["api"] = {"dtype": "fp16"}
kwargs["check_func"] = check_func_fp16
self.run_basic_test(FILENAME, out_fp16_model_path, **kwargs)

out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx")
kwargs = {}
kwargs["bash"] = "--dtype fp32"
kwargs["api"] = {"dtype": "fp32"}
kwargs["check_func"] = check_func_fp32
self.run_basic_test(out_fp16_model_path, out_fp32_model_path, **kwargs)

def test_save_as_external_data(self, request):
with tempfile.TemporaryDirectory() as tempdir:
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs["bash"] = "--save-as-external-data"
kwargs["api"] = {"save_as_external_data": True}
self.run_basic_test(FILENAME, out_model_path, **kwargs)
assert os.path.getsize(out_model_path) < 1e5

def test_model_check(self, request):
with tempfile.TemporaryDirectory() as tempdir:
out_model_path = os.path.join(tempdir, "resnet18.onnx")
input_data = os.path.join(tempdir, "input.npy")
test_data = np.random.random((1, 3, 224, 224)).astype(np.float32)
np.save(input_data, test_data)
kwargs = {}
kwargs["bash"] = f"--model-check --model-check-inputs input:{input_data}"
kwargs["api"] = {"model_check": True, "model_check_inputs": [f"input:{input_data}"]}
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_inspect(self, request):
with tempfile.TemporaryDirectory():
kwargs = {}
kwargs["bash"] = "--inspect"
kwargs["api"] = {"inspect": True}
self.run_basic_test(FILENAME, FILENAME, **kwargs)


if __name__ == "__main__":
Expand Down

0 comments on commit 987fab1

Please sign in to comment.