Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor onnxslim test and fix bugs #51

Merged
merged 12 commits into from
Nov 28, 2024
5 changes: 2 additions & 3 deletions onnxslim/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,13 @@ def _add_arguments(self):
arg_type = _get_inner_type(field_def.type)
default_value = field_def.default if field_def.default is not field_def.default_factory else None
help_text = field_def.metadata.get("help", "")
nargs = "+" if arg_type == list else None
nargs = "+" if get_origin(arg_type) == list else None
choices = field_def.metadata.get("choices", None)

if choices and default_value is not None and default_value not in choices:
raise ValueError(
f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}."
)

arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type
if arg_type == bool:
self.parser.add_argument(
f"--{field_name.replace('_', '-')}",
Expand Down
31 changes: 23 additions & 8 deletions tests/test_onnxslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,36 @@


class TestFunctional:
def test_basic(self, request):
"""Test the basic functionality of the slim function."""
def __test_command_basic(self, request, in_model_path=FILENAME, out_model_name="resnet18.onnx", arg_str=""):
with tempfile.TemporaryDirectory() as tempdir:
summary = summarize_model(slim(FILENAME), request.node.name)
summary = summarize_model(slim(in_model_path), request.node.name)
print_model_info_as_table(summary)
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name)
slim(FILENAME, output_name, model_check=True)

command = f"onnxslim {FILENAME} {output_name}"
out_model_path = os.path.join(tempdir, out_model_name)
slim(in_model_path, out_model_path)
slim(in_model_path, out_model_path, model_check=True)
command = f'onnxslim {arg_str} "{in_model_path}" "{out_model_path}"'
whyb marked this conversation as resolved.
Show resolved Hide resolved
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0
return in_model_path, out_model_path

def test_basic(self, request):
"""Test the basic functionality of the slim function."""
self.__test_command_basic(request)

def test_input_shape_modification(self, request):
"""Test the modification of input shapes."""
input_shape_arg_str = "--input-shapes input:1,3,224,224"
self.__test_command_basic(request, FILENAME, "resnet18.onnx", input_shape_arg_str)

def test_fp322fp16_conversion(self, request):
"""Test the conversion of an ONNX model from FP32 to FP16 precision."""
dtype_fp16_arg_str = "--dtype fp16"
_, out_model_path = self.__test_command_basic(request, FILENAME, "resnet18_fp16.onnx", dtype_fp16_arg_str)
dtype_fp32_arg_str = "--dtype fp32"
self.__test_command_basic(request, out_model_path, "resnet18_fp32.onnx", dtype_fp32_arg_str)


class TestFeature:
Expand Down
Loading