From 1685e1535fb07bc4e8ffea662b92bb7787a7bfa2 Mon Sep 17 00:00:00 2001 From: inisis Date: Wed, 27 Nov 2024 20:22:03 +0800 Subject: [PATCH] refactor tests --- tests/test_onnxslim.py | 142 ++++++++++++++++++++--------------------- 1 file changed, 69 insertions(+), 73 deletions(-) diff --git a/tests/test_onnxslim.py b/tests/test_onnxslim.py index 34afd1b..67bc405 100644 --- a/tests/test_onnxslim.py +++ b/tests/test_onnxslim.py @@ -3,113 +3,109 @@ import tempfile import pytest - +import numpy as np from onnxslim import slim -from onnxslim.utils import print_model_info_as_table, summarize_model +from onnxslim.utils import summarize_model MODELZOO_PATH = "/data/modelzoo" FILENAME = f"{MODELZOO_PATH}/resnet18/resnet18.onnx" class TestFunctional: - def __test_command_basic(self, request, in_model_path, out_model_path, arg_str=""): - summary = summarize_model(slim(in_model_path), request.node.name) - print_model_info_as_table(summary) - slim(in_model_path, out_model_path) - slim(in_model_path, out_model_path, model_check=True) - command = f'onnxslim "{in_model_path}" "{out_model_path}" {arg_str}' + def run_basic_test(self, in_model_path, out_model_path, **kwargs): + check_func = kwargs.get('check_func', None) + kwargs_api = kwargs.get('api', {}) + kwargs_bash = kwargs.get('bash', "") + summary = summarize_model(slim(in_model_path, **kwargs_api), in_model_path) + if check_func: check_func(summary) + + slim(in_model_path, out_model_path, **kwargs_api) + summary = summarize_model(out_model_path, out_model_path) + if check_func: check_func(summary) + + command = f'onnxslim "{in_model_path}" "{out_model_path}", {kwargs_bash}' + result = subprocess.run(command, shell=True, capture_output=True, text=True) output = result.stderr.strip() # Assert the expected return code print(output) assert result.returncode == 0 + summary = summarize_model(out_model_path, out_model_path) + if check_func: check_func(summary) + def test_basic(self, request): - """Test the basic functionality of the slim function.""" with tempfile.TemporaryDirectory() as tempdir: out_model_path = os.path.join(tempdir, "resnet18.onnx") - self.__test_command_basic(request, FILENAME, out_model_path, "") + self.run_basic_test(FILENAME, out_model_path) def test_input_shape_modification(self, request): - """Test the modification of input shapes.""" - with tempfile.TemporaryDirectory() as tempdir: - out_model_path = os.path.join(tempdir, "resnet18.onnx") - input_shape_arg_str = "--input-shapes input:1,3,224,224" - self.__test_command_basic(request, FILENAME, out_model_path, input_shape_arg_str) + def check_func(summary): + assert summary.input_info[0].shape == (1, 3, 224, 224) - def test_fp322fp16_conversion(self, request): - """Test the conversion of an ONNX model from FP32 to FP16 precision.""" with tempfile.TemporaryDirectory() as tempdir: - out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx") - out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx") - dtype_fp16_arg_str = "--dtype fp16" - dtype_fp32_arg_str = "--dtype fp32" - self.__test_command_basic(request, FILENAME, out_fp16_model_path, dtype_fp16_arg_str) - self.__test_command_basic(request, out_fp16_model_path, out_fp32_model_path, dtype_fp32_arg_str) - + out_model_path = os.path.join(tempdir, "resnet18.onnx") + kwargs = {} + kwargs['bash'] = "--input-shapes input:1,3,224,224" + kwargs['api'] = {'input_shapes': ["input:1,3,224,224"]} + kwargs['check_func'] = check_func + self.run_basic_test(FILENAME, out_model_path, **kwargs) -class TestFeature: - def test_input_shape_modification(self, request): - """Test the modification of input shapes.""" - summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]), request.node.name) - print_model_info_as_table(summary) - assert summary.input_info[0].shape == (1, 3, 224, 224) + def test_input_modification(self, request): + def check_func(summary): + assert "/maxpool/MaxPool_output_0" in summary.input_maps + assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps with tempfile.TemporaryDirectory() as tempdir: - output_name = os.path.join(tempdir, "resnet18.onnx") - slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"]) - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) - assert summary.input_info[0].shape == (1, 3, 224, 224) + out_model_path = os.path.join(tempdir, "resnet18.onnx") + kwargs = {} + kwargs['bash'] = "--inputs /maxpool/MaxPool_output_0 /layer1/layer1.0/relu/Relu_output_0" + kwargs['api'] = {"inputs": ["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]} + kwargs['check_func'] = check_func + self.run_basic_test(FILENAME, out_model_path, **kwargs) - def test_fp162fp32_conversion(self, request): - """Test the conversion of an ONNX model from FP16 to FP32 precision.""" - import numpy as np + def test_output_modification(self, request): + def check_func(summary): + assert "/Flatten_output_0" in summary.output_maps with tempfile.TemporaryDirectory() as tempdir: - output_name = os.path.join(tempdir, "resnet18.onnx") - slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"], dtype="fp16") - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) + out_model_path = os.path.join(tempdir, "resnet18.onnx") + kwargs = {} + kwargs['bash'] = "--outputs /Flatten_output_0" + kwargs['api'] = {"outputs": ["/Flatten_output_0"]} + kwargs['check_func'] = check_func + self.run_basic_test(FILENAME, out_model_path, **kwargs) + + def test_dtype_conversion(self, request): + def check_func_fp16(summary): assert summary.input_info[0].dtype == np.float16 - assert summary.input_info[0].shape == (1, 3, 224, 224) - slim(output_name, output_name, dtype="fp32") - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) + def check_func_fp32(summary): assert summary.input_info[0].dtype == np.float32 - assert summary.input_info[0].shape == (1, 3, 224, 224) - - def test_output_modification(self, request): - """Tests output modification.""" - summary = summarize_model(slim(FILENAME, outputs=["/Flatten_output_0"]), request.node.name) - print_model_info_as_table(summary) - assert "/Flatten_output_0" in summary.output_maps with tempfile.TemporaryDirectory() as tempdir: - output_name = os.path.join(tempdir, "resnet18.onnx") - slim(FILENAME, output_name, outputs=["/Flatten_output_0"]) - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) - assert "/Flatten_output_0" in summary.output_maps + out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx") + kwargs = {} + kwargs['bash'] = "--dtype fp16" + kwargs['api'] = {"dtype": "fp16"} + kwargs['check_func'] = check_func_fp16 + self.run_basic_test(FILENAME, out_fp16_model_path, **kwargs) - def test_input_modification(self, request): - """Tests input modification.""" - summary = summarize_model( - slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]), - request.node.name, - ) - print_model_info_as_table(summary) - assert "/maxpool/MaxPool_output_0" in summary.input_maps - assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps + out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx") + kwargs = {} + kwargs['bash'] = "--dtype fp32" + kwargs['api'] = {"dtype": "fp32"} + kwargs['check_func'] = check_func_fp32 + self.run_basic_test(out_fp16_model_path, out_fp32_model_path, **kwargs) + def test_save_as_external_data(self, request): with tempfile.TemporaryDirectory() as tempdir: - output_name = os.path.join(tempdir, "resnet18.onnx") - slim(FILENAME, output_name, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]) - summary = summarize_model(output_name, request.node.name) - print_model_info_as_table(summary) - assert "/maxpool/MaxPool_output_0" in summary.input_maps - assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps + out_model_path = os.path.join(tempdir, "resnet18.onnx") + kwargs = {} + kwargs['bash'] = "--save-as-external-data" + kwargs['api'] = {"save_as_external_data": True} + self.run_basic_test(FILENAME, out_model_path, **kwargs) + assert os.path.getsize(out_model_path) < 1e5 if __name__ == "__main__":