diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 25d52c5..21ce684 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,11 @@ jobs: pip install ultralytics pytest tests/test_yolo.py -sv + - name: onnxslim api and binary test + run: | + pip install onnxconverter_common + pytest tests/test_onnxslim.py + - name: model zoo test run: | python -m pip install --upgrade pip wheel setuptools diff --git a/onnxslim/utils.py b/onnxslim/utils.py index f07487f..4b8c3e4 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -341,7 +341,7 @@ def _extract_info(self, tensor): else: shape.append(None) - self.shape = tuple(shape) + self.shape = tuple(shape) if shape is not None else None self.name = tensor.name @@ -361,8 +361,8 @@ def _extract_info(self, operator): class ModelInfo: def __init__(self, model: Union[str, onnx.ModelProto], tag: str = "OnnxSlim"): if isinstance(model, str): - model = onnx.load(model) tag = Path(model).name + model = onnx.load(model) self.tag: str = tag self.model_size: int = -1 diff --git a/tests/test_onnxslim.py b/tests/test_onnxslim.py index b36fee2..515f65e 100644 --- a/tests/test_onnxslim.py +++ b/tests/test_onnxslim.py @@ -15,8 +15,8 @@ class TestFunctional: def test_basic(self, request): """Test the basic functionality of the slim function.""" with tempfile.TemporaryDirectory() as tempdir: - summary = summarize_model(slim(FILENAME)) - print_model_info_as_table(request.node.name, summary) + summary = summarize_model(slim(FILENAME), request.node.name) + print_model_info_as_table(summary) output_name = os.path.join(tempdir, "resnet18.onnx") slim(FILENAME, output_name) slim(FILENAME, output_name, model_check=True) @@ -32,16 +32,16 @@ def test_basic(self, request): class TestFeature: def test_input_shape_modification(self, request): """Test the modification of input shapes.""" - summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"])) - print_model_info_as_table(request.node.name, summary) - assert summary["op_input_info"]["input"][1] == (1, 3, 224, 224) + summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]), request.node.name) + print_model_info_as_table(summary) + assert summary.input_info[0].shape == (1, 3, 224, 224) with tempfile.TemporaryDirectory() as tempdir: output_name = os.path.join(tempdir, "resnet18.onnx") slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"]) - summary = summarize_model(output_name) - print_model_info_as_table(request.node.name, summary) - assert summary["op_input_info"]["input"][1] == (1, 3, 224, 224) + summary = summarize_model(output_name, request.node.name) + print_model_info_as_table(summary) + assert summary.input_info[0].shape == (1, 3, 224, 224) def test_fp162fp32_conversion(self, request): """Test the conversion of an ONNX model from FP16 to FP32 precision.""" @@ -50,46 +50,47 @@ def test_fp162fp32_conversion(self, request): with tempfile.TemporaryDirectory() as tempdir: output_name = os.path.join(tempdir, "resnet18.onnx") slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"], dtype="fp16") - summary = summarize_model(output_name) - print_model_info_as_table(request.node.name, summary) - assert summary["op_input_info"]["input"][0] == np.float16 - assert summary["op_input_info"]["input"][1] == (1, 3, 224, 224) + summary = summarize_model(output_name, request.node.name) + print_model_info_as_table(summary) + assert summary.input_info[0].dtype == np.float16 + assert summary.input_info[0].shape == (1, 3, 224, 224) slim(output_name, output_name, dtype="fp32") - summary = summarize_model(output_name) - print_model_info_as_table(request.node.name, summary) - assert summary["op_input_info"]["input"][0] == np.float32 - assert summary["op_input_info"]["input"][1] == (1, 3, 224, 224) + summary = summarize_model(output_name, request.node.name) + print_model_info_as_table(summary) + assert summary.input_info[0].dtype == np.float32 + assert summary.input_info[0].shape == (1, 3, 224, 224) def test_output_modification(self, request): """Tests output modification.""" - summary = summarize_model(slim(FILENAME, outputs=["/Flatten_output_0"])) - print_model_info_as_table(request.node.name, summary) - assert "/Flatten_output_0" in summary["op_output_info"] + summary = summarize_model(slim(FILENAME, outputs=["/Flatten_output_0"]), request.node.name) + print_model_info_as_table(summary) + assert "/Flatten_output_0" in summary.output_maps with tempfile.TemporaryDirectory() as tempdir: output_name = os.path.join(tempdir, "resnet18.onnx") slim(FILENAME, output_name, outputs=["/Flatten_output_0"]) - summary = summarize_model(output_name) - print_model_info_as_table(request.node.name, summary) - assert "/Flatten_output_0" in summary["op_output_info"] + summary = summarize_model(output_name, request.node.name) + print_model_info_as_table(summary) + assert "/Flatten_output_0" in summary.output_maps def test_input_modification(self, request): """Tests input modification.""" summary = summarize_model( - slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]) + slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]), + request.node.name, ) - print_model_info_as_table(request.node.name, summary) - assert "/maxpool/MaxPool_output_0" in summary["op_input_info"] - assert "/layer1/layer1.0/relu/Relu_output_0" in summary["op_input_info"] + print_model_info_as_table(summary) + assert "/maxpool/MaxPool_output_0" in summary.input_maps + assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps with tempfile.TemporaryDirectory() as tempdir: output_name = os.path.join(tempdir, "resnet18.onnx") slim(FILENAME, output_name, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]) - summary = summarize_model(output_name) - print_model_info_as_table(request.node.name, summary) - assert "/maxpool/MaxPool_output_0" in summary["op_input_info"] - assert "/layer1/layer1.0/relu/Relu_output_0" in summary["op_input_info"] + summary = summarize_model(output_name, request.node.name) + print_model_info_as_table(summary) + assert "/maxpool/MaxPool_output_0" in summary.input_maps + assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps if __name__ == "__main__":