Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor onnxslim test and fix bugs #51

Merged
merged 12 commits into from
Nov 28, 2024
5 changes: 2 additions & 3 deletions onnxslim/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,13 @@ def _add_arguments(self):
arg_type = _get_inner_type(field_def.type)
default_value = field_def.default if field_def.default is not field_def.default_factory else None
help_text = field_def.metadata.get("help", "")
nargs = "+" if arg_type == list else None
nargs = "+" if get_origin(arg_type) == list else None
choices = field_def.metadata.get("choices", None)

if choices and default_value is not None and default_value not in choices:
raise ValueError(
f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}."
)

arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type
if arg_type == bool:
self.parser.add_argument(
f"--{field_name.replace('_', '-')}",
Expand Down
43 changes: 31 additions & 12 deletions tests/test_onnxslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,40 @@


class TestFunctional:
def __test_command_basic(self, request, in_model_path, out_model_path, arg_str=""):
summary = summarize_model(slim(in_model_path), request.node.name)
print_model_info_as_table(summary)
slim(in_model_path, out_model_path)
slim(in_model_path, out_model_path, model_check=True)
command = f'onnxslim "{in_model_path}" "{out_model_path}" {arg_str}'
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0

def test_basic(self, request):
"""Test the basic functionality of the slim function."""
with tempfile.TemporaryDirectory() as tempdir:
summary = summarize_model(slim(FILENAME), request.node.name)
print_model_info_as_table(summary)
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name)
slim(FILENAME, output_name, model_check=True)

command = f"onnxslim {FILENAME} {output_name}"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0
out_model_path = os.path.join(tempdir, "resnet18.onnx")
self.__test_command_basic(request, FILENAME, out_model_path, "")

def test_input_shape_modification(self, request):
"""Test the modification of input shapes."""
with tempfile.TemporaryDirectory() as tempdir:
out_model_path = os.path.join(tempdir, "resnet18.onnx")
input_shape_arg_str = "--input-shapes input:1,3,224,224"
self.__test_command_basic(request, FILENAME, out_model_path, input_shape_arg_str)

def test_fp322fp16_conversion(self, request):
"""Test the conversion of an ONNX model from FP32 to FP16 precision."""
with tempfile.TemporaryDirectory() as tempdir:
out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx")
out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx")
dtype_fp16_arg_str = "--dtype fp16"
dtype_fp32_arg_str = "--dtype fp32"
self.__test_command_basic(request, FILENAME, out_fp16_model_path, dtype_fp16_arg_str)
self.__test_command_basic(request, out_fp16_model_path, out_fp32_model_path, dtype_fp32_arg_str)


class TestFeature:
Expand Down