Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor onnxslim test and fix bugs #51

Merged
merged 12 commits into from
Nov 28, 2024
5 changes: 2 additions & 3 deletions onnxslim/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,13 @@ def _add_arguments(self):
arg_type = _get_inner_type(field_def.type)
default_value = field_def.default if field_def.default is not field_def.default_factory else None
help_text = field_def.metadata.get("help", "")
nargs = "+" if arg_type == list else None
nargs = "+" if get_origin(arg_type) == list else None
choices = field_def.metadata.get("choices", None)

if choices and default_value is not None and default_value not in choices:
raise ValueError(
f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}."
)

arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type
if arg_type == bool:
self.parser.add_argument(
f"--{field_name.replace('_', '-')}",
Expand Down
2 changes: 1 addition & 1 deletion onnxslim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def save(

if model_info:
model_size = model.ByteSize()
model_info["model_size"] = [model_size, model_info["model_size"]]
model_info.model_size = [model_size, model_info.model_size]


def check_result(raw_onnx_output, slimmed_onnx_output):
Expand Down
165 changes: 101 additions & 64 deletions tests/test_onnxslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,95 +2,132 @@
import subprocess
import tempfile

import numpy as np
import pytest

from onnxslim import slim
from onnxslim.utils import print_model_info_as_table, summarize_model
from onnxslim.utils import summarize_model

MODELZOO_PATH = "/data/modelzoo"
FILENAME = f"{MODELZOO_PATH}/resnet18/resnet18.onnx"


class TestFunctional:
def run_basic_test(self, in_model_path, out_model_path, **kwargs):
check_func = kwargs.get("check_func", None)
kwargs_api = kwargs.get("api", {})
kwargs_bash = kwargs.get("bash", "")
summary = summarize_model(slim(in_model_path, **kwargs_api), in_model_path)
if check_func:
check_func(summary)

slim(in_model_path, out_model_path, **kwargs_api)
summary = summarize_model(out_model_path, out_model_path)
if check_func:
check_func(summary)

command = f'onnxslim "{in_model_path}" "{out_model_path}" {kwargs_bash}'

result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0

summary = summarize_model(out_model_path, out_model_path)
if check_func:
check_func(summary)

def test_basic(self, request):
"""Test the basic functionality of the slim function."""
with tempfile.TemporaryDirectory() as tempdir:
summary = summarize_model(slim(FILENAME), request.node.name)
print_model_info_as_table(summary)
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name)
slim(FILENAME, output_name, model_check=True)

command = f"onnxslim {FILENAME} {output_name}"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0

out_model_path = os.path.join(tempdir, "resnet18.onnx")
self.run_basic_test(FILENAME, out_model_path)

class TestFeature:
def test_input_shape_modification(self, request):
"""Test the modification of input shapes."""
summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]), request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].shape == (1, 3, 224, 224)
def check_func(summary):
assert summary.input_info[0].shape == (1, 3, 224, 224)

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"])
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].shape == (1, 3, 224, 224)
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs["bash"] = "--input-shapes input:1,3,224,224"
kwargs["api"] = {"input_shapes": ["input:1,3,224,224"]}
kwargs["check_func"] = check_func
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_fp162fp32_conversion(self, request):
"""Test the conversion of an ONNX model from FP16 to FP32 precision."""
import numpy as np
def test_input_modification(self, request):
def check_func(summary):
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"], dtype="fp16")
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].dtype == np.float16
assert summary.input_info[0].shape == (1, 3, 224, 224)

slim(output_name, output_name, dtype="fp32")
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].dtype == np.float32
assert summary.input_info[0].shape == (1, 3, 224, 224)
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs["bash"] = "--inputs /maxpool/MaxPool_output_0 /layer1/layer1.0/relu/Relu_output_0"
kwargs["api"] = {"inputs": ["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]}
kwargs["check_func"] = check_func
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_output_modification(self, request):
"""Tests output modification."""
summary = summarize_model(slim(FILENAME, outputs=["/Flatten_output_0"]), request.node.name)
print_model_info_as_table(summary)
assert "/Flatten_output_0" in summary.output_maps
def check_func(summary):
assert "/Flatten_output_0" in summary.output_maps

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, outputs=["/Flatten_output_0"])
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert "/Flatten_output_0" in summary.output_maps
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs["bash"] = "--outputs /Flatten_output_0"
kwargs["api"] = {"outputs": ["/Flatten_output_0"]}
kwargs["check_func"] = check_func
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_dtype_conversion(self, request):
def check_func_fp16(summary):
assert summary.input_info[0].dtype == np.float16

def test_input_modification(self, request):
"""Tests input modification."""
summary = summarize_model(
slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]),
request.node.name,
)
print_model_info_as_table(summary)
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps
def check_func_fp32(summary):
assert summary.input_info[0].dtype == np.float32

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"])
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps
out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx")
kwargs = {}
kwargs["bash"] = "--dtype fp16"
kwargs["api"] = {"dtype": "fp16"}
kwargs["check_func"] = check_func_fp16
self.run_basic_test(FILENAME, out_fp16_model_path, **kwargs)

out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx")
kwargs = {}
kwargs["bash"] = "--dtype fp32"
kwargs["api"] = {"dtype": "fp32"}
kwargs["check_func"] = check_func_fp32
self.run_basic_test(out_fp16_model_path, out_fp32_model_path, **kwargs)

def test_save_as_external_data(self, request):
with tempfile.TemporaryDirectory() as tempdir:
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs["bash"] = "--save-as-external-data"
kwargs["api"] = {"save_as_external_data": True}
self.run_basic_test(FILENAME, out_model_path, **kwargs)
assert os.path.getsize(out_model_path) < 1e5

def test_model_check(self, request):
with tempfile.TemporaryDirectory() as tempdir:
out_model_path = os.path.join(tempdir, "resnet18.onnx")
input_data = os.path.join(tempdir, "input.npy")
test_data = np.random.random((1, 3, 224, 224)).astype(np.float32)
np.save(input_data, test_data)
kwargs = {}
kwargs["bash"] = f"--model-check --model-check-inputs input:{input_data}"
kwargs["api"] = {"model_check": True, "model_check_inputs": [f"input:{input_data}"]}
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_inspect(self, request):
with tempfile.TemporaryDirectory():
kwargs = {}
kwargs["bash"] = "--inspect"
kwargs["api"] = {"inspect": True}
self.run_basic_test(FILENAME, FILENAME, **kwargs)


if __name__ == "__main__":
Expand Down