diff --git a/onnxslim/argparser.py b/onnxslim/argparser.py index 95ee4fc..32ac43f 100644 --- a/onnxslim/argparser.py +++ b/onnxslim/argparser.py @@ -136,14 +136,13 @@ def _add_arguments(self): arg_type = _get_inner_type(field_def.type) default_value = field_def.default if field_def.default is not field_def.default_factory else None help_text = field_def.metadata.get("help", "") - nargs = "+" if arg_type == list else None + nargs = "+" if get_origin(arg_type) == list else None choices = field_def.metadata.get("choices", None) - if choices and default_value is not None and default_value not in choices: raise ValueError( f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}." ) - + arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type if arg_type == bool: self.parser.add_argument( f"--{field_name.replace('_', '-')}", diff --git a/onnxslim/utils.py b/onnxslim/utils.py index 4b8c3e4..a4631f3 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -504,7 +504,7 @@ def save( if model_info: model_size = model.ByteSize() - model_info["model_size"] = [model_size, model_info["model_size"]] + model_info.model_size = [model_size, model_info.model_size] def check_result(raw_onnx_output, slimmed_onnx_output): diff --git a/tests/test_onnxslim.py b/tests/test_onnxslim.py index 515f65e..dd7a961 100644 --- a/tests/test_onnxslim.py +++ b/tests/test_onnxslim.py @@ -2,95 +2,132 @@ import subprocess import tempfile +import numpy as np import pytest from onnxslim import slim -from onnxslim.utils import print_model_info_as_table, summarize_model +from onnxslim.utils import summarize_model MODELZOO_PATH = "/data/modelzoo" FILENAME = f"{MODELZOO_PATH}/resnet18/resnet18.onnx" class TestFunctional: + def run_basic_test(self, in_model_path, out_model_path, **kwargs): + check_func = kwargs.get("check_func", None) + kwargs_api = kwargs.get("api", {}) + kwargs_bash = kwargs.get("bash", "") + summary = summarize_model(slim(in_model_path, **kwargs_api), in_model_path) + if check_func: + check_func(summary) + + slim(in_model_path, out_model_path, **kwargs_api) + summary = summarize_model(out_model_path, out_model_path) + if check_func: + check_func(summary) + + command = f'onnxslim "{in_model_path}" "{out_model_path}" {kwargs_bash}' + + result = subprocess.run(command, shell=True, capture_output=True, text=True) + output = result.stderr.strip() + # Assert the expected return code + print(output) + assert result.returncode == 0 + + summary = summarize_model(out_model_path, out_model_path) + if check_func: + check_func(summary) + def test_basic(self, request): - """Test the basic functionality of the slim function.""" with tempfile.TemporaryDirectory() as tempdir: - summary = summarize_model(slim(FILENAME), request.node.name) - print_model_info_as_table(summary) - output_name = os.path.join(tempdir, "resnet18.onnx") - slim(FILENAME, output_name) - slim(FILENAME, output_name, model_check=True) - - command = f"onnxslim {FILENAME} {output_name}" - result = subprocess.run(command, shell=True, capture_output=True, text=True) - output = result.stderr.strip() - # Assert the expected return code - print(output) - assert result.returncode == 0 - + out_model_path = os.path.join(tempdir, "resnet18.onnx") + self.run_basic_test(FILENAME, out_model_path) -class TestFeature: def test_input_shape_modification(self, request): - """Test the modification of input shapes.""" - summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]), request.node.name) - print_model_info_as_table(summary) - assert summary.input_info[0].shape == (1, 3, 224, 224) + def check_func(summary): + assert summary.input_info[0].shape == (1, 3, 224, 224) with tempfile.TemporaryDirectory() as tempdir: - output_name = os.path.join(tempdir, "resnet18.onnx") - slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"]) - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) - assert summary.input_info[0].shape == (1, 3, 224, 224) + out_model_path = os.path.join(tempdir, "resnet18.onnx") + kwargs = {} + kwargs["bash"] = "--input-shapes input:1,3,224,224" + kwargs["api"] = {"input_shapes": ["input:1,3,224,224"]} + kwargs["check_func"] = check_func + self.run_basic_test(FILENAME, out_model_path, **kwargs) - def test_fp162fp32_conversion(self, request): - """Test the conversion of an ONNX model from FP16 to FP32 precision.""" - import numpy as np + def test_input_modification(self, request): + def check_func(summary): + assert "/maxpool/MaxPool_output_0" in summary.input_maps + assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps with tempfile.TemporaryDirectory() as tempdir: - output_name = os.path.join(tempdir, "resnet18.onnx") - slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"], dtype="fp16") - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) - assert summary.input_info[0].dtype == np.float16 - assert summary.input_info[0].shape == (1, 3, 224, 224) - - slim(output_name, output_name, dtype="fp32") - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) - assert summary.input_info[0].dtype == np.float32 - assert summary.input_info[0].shape == (1, 3, 224, 224) + out_model_path = os.path.join(tempdir, "resnet18.onnx") + kwargs = {} + kwargs["bash"] = "--inputs /maxpool/MaxPool_output_0 /layer1/layer1.0/relu/Relu_output_0" + kwargs["api"] = {"inputs": ["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]} + kwargs["check_func"] = check_func + self.run_basic_test(FILENAME, out_model_path, **kwargs) def test_output_modification(self, request): - """Tests output modification.""" - summary = summarize_model(slim(FILENAME, outputs=["/Flatten_output_0"]), request.node.name) - print_model_info_as_table(summary) - assert "/Flatten_output_0" in summary.output_maps + def check_func(summary): + assert "/Flatten_output_0" in summary.output_maps with tempfile.TemporaryDirectory() as tempdir: - output_name = os.path.join(tempdir, "resnet18.onnx") - slim(FILENAME, output_name, outputs=["/Flatten_output_0"]) - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) - assert "/Flatten_output_0" in summary.output_maps + out_model_path = os.path.join(tempdir, "resnet18.onnx") + kwargs = {} + kwargs["bash"] = "--outputs /Flatten_output_0" + kwargs["api"] = {"outputs": ["/Flatten_output_0"]} + kwargs["check_func"] = check_func + self.run_basic_test(FILENAME, out_model_path, **kwargs) + + def test_dtype_conversion(self, request): + def check_func_fp16(summary): + assert summary.input_info[0].dtype == np.float16 - def test_input_modification(self, request): - """Tests input modification.""" - summary = summarize_model( - slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]), - request.node.name, - ) - print_model_info_as_table(summary) - assert "/maxpool/MaxPool_output_0" in summary.input_maps - assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps + def check_func_fp32(summary): + assert summary.input_info[0].dtype == np.float32 with tempfile.TemporaryDirectory() as tempdir: - output_name = os.path.join(tempdir, "resnet18.onnx") - slim(FILENAME, output_name, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]) - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) - assert "/maxpool/MaxPool_output_0" in summary.input_maps - assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps + out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx") + kwargs = {} + kwargs["bash"] = "--dtype fp16" + kwargs["api"] = {"dtype": "fp16"} + kwargs["check_func"] = check_func_fp16 + self.run_basic_test(FILENAME, out_fp16_model_path, **kwargs) + + out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx") + kwargs = {} + kwargs["bash"] = "--dtype fp32" + kwargs["api"] = {"dtype": "fp32"} + kwargs["check_func"] = check_func_fp32 + self.run_basic_test(out_fp16_model_path, out_fp32_model_path, **kwargs) + + def test_save_as_external_data(self, request): + with tempfile.TemporaryDirectory() as tempdir: + out_model_path = os.path.join(tempdir, "resnet18.onnx") + kwargs = {} + kwargs["bash"] = "--save-as-external-data" + kwargs["api"] = {"save_as_external_data": True} + self.run_basic_test(FILENAME, out_model_path, **kwargs) + assert os.path.getsize(out_model_path) < 1e5 + + def test_model_check(self, request): + with tempfile.TemporaryDirectory() as tempdir: + out_model_path = os.path.join(tempdir, "resnet18.onnx") + input_data = os.path.join(tempdir, "input.npy") + test_data = np.random.random((1, 3, 224, 224)).astype(np.float32) + np.save(input_data, test_data) + kwargs = {} + kwargs["bash"] = f"--model-check --model-check-inputs input:{input_data}" + kwargs["api"] = {"model_check": True, "model_check_inputs": [f"input:{input_data}"]} + self.run_basic_test(FILENAME, out_model_path, **kwargs) + + def test_inspect(self, request): + with tempfile.TemporaryDirectory(): + kwargs = {} + kwargs["bash"] = "--inspect" + kwargs["api"] = {"inspect": True} + self.run_basic_test(FILENAME, FILENAME, **kwargs) if __name__ == "__main__":