Skip to content

Commit

Permalink
enable onnxslim test and fix bug (#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis authored Nov 24, 2024
1 parent b833ec8 commit 543f1ad
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 32 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ jobs:
pip install ultralytics
pytest tests/test_yolo.py -sv
- name: onnxslim api and binary test
run: |
pip install onnxconverter_common
pytest tests/test_onnxslim.py
- name: model zoo test
run: |
python -m pip install --upgrade pip wheel setuptools
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _extract_info(self, tensor):
else:
shape.append(None)

self.shape = tuple(shape)
self.shape = tuple(shape) if shape is not None else None
self.name = tensor.name


Expand All @@ -361,8 +361,8 @@ def _extract_info(self, operator):
class ModelInfo:
def __init__(self, model: Union[str, onnx.ModelProto], tag: str = "OnnxSlim"):
if isinstance(model, str):
model = onnx.load(model)
tag = Path(model).name
model = onnx.load(model)

self.tag: str = tag
self.model_size: int = -1
Expand Down
61 changes: 31 additions & 30 deletions tests/test_onnxslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class TestFunctional:
def test_basic(self, request):
"""Test the basic functionality of the slim function."""
with tempfile.TemporaryDirectory() as tempdir:
summary = summarize_model(slim(FILENAME))
print_model_info_as_table(request.node.name, summary)
summary = summarize_model(slim(FILENAME), request.node.name)
print_model_info_as_table(summary)
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name)
slim(FILENAME, output_name, model_check=True)
Expand All @@ -32,16 +32,16 @@ def test_basic(self, request):
class TestFeature:
def test_input_shape_modification(self, request):
"""Test the modification of input shapes."""
summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]))
print_model_info_as_table(request.node.name, summary)
assert summary["op_input_info"]["input"][1] == (1, 3, 224, 224)
summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]), request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].shape == (1, 3, 224, 224)

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"])
summary = summarize_model(output_name)
print_model_info_as_table(request.node.name, summary)
assert summary["op_input_info"]["input"][1] == (1, 3, 224, 224)
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].shape == (1, 3, 224, 224)

def test_fp162fp32_conversion(self, request):
"""Test the conversion of an ONNX model from FP16 to FP32 precision."""
Expand All @@ -50,46 +50,47 @@ def test_fp162fp32_conversion(self, request):
with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"], dtype="fp16")
summary = summarize_model(output_name)
print_model_info_as_table(request.node.name, summary)
assert summary["op_input_info"]["input"][0] == np.float16
assert summary["op_input_info"]["input"][1] == (1, 3, 224, 224)
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].dtype == np.float16
assert summary.input_info[0].shape == (1, 3, 224, 224)

slim(output_name, output_name, dtype="fp32")
summary = summarize_model(output_name)
print_model_info_as_table(request.node.name, summary)
assert summary["op_input_info"]["input"][0] == np.float32
assert summary["op_input_info"]["input"][1] == (1, 3, 224, 224)
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].dtype == np.float32
assert summary.input_info[0].shape == (1, 3, 224, 224)

def test_output_modification(self, request):
"""Tests output modification."""
summary = summarize_model(slim(FILENAME, outputs=["/Flatten_output_0"]))
print_model_info_as_table(request.node.name, summary)
assert "/Flatten_output_0" in summary["op_output_info"]
summary = summarize_model(slim(FILENAME, outputs=["/Flatten_output_0"]), request.node.name)
print_model_info_as_table(summary)
assert "/Flatten_output_0" in summary.output_maps

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, outputs=["/Flatten_output_0"])
summary = summarize_model(output_name)
print_model_info_as_table(request.node.name, summary)
assert "/Flatten_output_0" in summary["op_output_info"]
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert "/Flatten_output_0" in summary.output_maps

def test_input_modification(self, request):
"""Tests input modification."""
summary = summarize_model(
slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"])
slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]),
request.node.name,
)
print_model_info_as_table(request.node.name, summary)
assert "/maxpool/MaxPool_output_0" in summary["op_input_info"]
assert "/layer1/layer1.0/relu/Relu_output_0" in summary["op_input_info"]
print_model_info_as_table(summary)
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"])
summary = summarize_model(output_name)
print_model_info_as_table(request.node.name, summary)
assert "/maxpool/MaxPool_output_0" in summary["op_input_info"]
assert "/layer1/layer1.0/relu/Relu_output_0" in summary["op_input_info"]
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps


if __name__ == "__main__":
Expand Down

0 comments on commit 543f1ad

Please sign in to comment.