Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Nov 27, 2024
1 parent 03fadff commit 1685e15
Showing 1 changed file with 69 additions and 73 deletions.
142 changes: 69 additions & 73 deletions tests/test_onnxslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,113 +3,109 @@
import tempfile

import pytest

import numpy as np
from onnxslim import slim
from onnxslim.utils import print_model_info_as_table, summarize_model
from onnxslim.utils import summarize_model

MODELZOO_PATH = "/data/modelzoo"
FILENAME = f"{MODELZOO_PATH}/resnet18/resnet18.onnx"


class TestFunctional:
def __test_command_basic(self, request, in_model_path, out_model_path, arg_str=""):
summary = summarize_model(slim(in_model_path), request.node.name)
print_model_info_as_table(summary)
slim(in_model_path, out_model_path)
slim(in_model_path, out_model_path, model_check=True)
command = f'onnxslim "{in_model_path}" "{out_model_path}" {arg_str}'
def run_basic_test(self, in_model_path, out_model_path, **kwargs):
check_func = kwargs.get('check_func', None)
kwargs_api = kwargs.get('api', {})
kwargs_bash = kwargs.get('bash', "")
summary = summarize_model(slim(in_model_path, **kwargs_api), in_model_path)
if check_func: check_func(summary)

slim(in_model_path, out_model_path, **kwargs_api)
summary = summarize_model(out_model_path, out_model_path)
if check_func: check_func(summary)

command = f'onnxslim "{in_model_path}" "{out_model_path}", {kwargs_bash}'

result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0

summary = summarize_model(out_model_path, out_model_path)
if check_func: check_func(summary)

def test_basic(self, request):
"""Test the basic functionality of the slim function."""
with tempfile.TemporaryDirectory() as tempdir:
out_model_path = os.path.join(tempdir, "resnet18.onnx")
self.__test_command_basic(request, FILENAME, out_model_path, "")
self.run_basic_test(FILENAME, out_model_path)

def test_input_shape_modification(self, request):
"""Test the modification of input shapes."""
with tempfile.TemporaryDirectory() as tempdir:
out_model_path = os.path.join(tempdir, "resnet18.onnx")
input_shape_arg_str = "--input-shapes input:1,3,224,224"
self.__test_command_basic(request, FILENAME, out_model_path, input_shape_arg_str)
def check_func(summary):
assert summary.input_info[0].shape == (1, 3, 224, 224)

def test_fp322fp16_conversion(self, request):
"""Test the conversion of an ONNX model from FP32 to FP16 precision."""
with tempfile.TemporaryDirectory() as tempdir:
out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx")
out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx")
dtype_fp16_arg_str = "--dtype fp16"
dtype_fp32_arg_str = "--dtype fp32"
self.__test_command_basic(request, FILENAME, out_fp16_model_path, dtype_fp16_arg_str)
self.__test_command_basic(request, out_fp16_model_path, out_fp32_model_path, dtype_fp32_arg_str)

out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs['bash'] = "--input-shapes input:1,3,224,224"
kwargs['api'] = {'input_shapes': ["input:1,3,224,224"]}
kwargs['check_func'] = check_func
self.run_basic_test(FILENAME, out_model_path, **kwargs)

class TestFeature:
def test_input_shape_modification(self, request):
"""Test the modification of input shapes."""
summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]), request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].shape == (1, 3, 224, 224)
def test_input_modification(self, request):
def check_func(summary):
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"])
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert summary.input_info[0].shape == (1, 3, 224, 224)
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs['bash'] = "--inputs /maxpool/MaxPool_output_0 /layer1/layer1.0/relu/Relu_output_0"
kwargs['api'] = {"inputs": ["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]}
kwargs['check_func'] = check_func
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_fp162fp32_conversion(self, request):
"""Test the conversion of an ONNX model from FP16 to FP32 precision."""
import numpy as np
def test_output_modification(self, request):
def check_func(summary):
assert "/Flatten_output_0" in summary.output_maps

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, input_shapes=["input:1,3,224,224"], dtype="fp16")
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs['bash'] = "--outputs /Flatten_output_0"
kwargs['api'] = {"outputs": ["/Flatten_output_0"]}
kwargs['check_func'] = check_func
self.run_basic_test(FILENAME, out_model_path, **kwargs)

def test_dtype_conversion(self, request):
def check_func_fp16(summary):
assert summary.input_info[0].dtype == np.float16
assert summary.input_info[0].shape == (1, 3, 224, 224)

slim(output_name, output_name, dtype="fp32")
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
def check_func_fp32(summary):
assert summary.input_info[0].dtype == np.float32
assert summary.input_info[0].shape == (1, 3, 224, 224)

def test_output_modification(self, request):
"""Tests output modification."""
summary = summarize_model(slim(FILENAME, outputs=["/Flatten_output_0"]), request.node.name)
print_model_info_as_table(summary)
assert "/Flatten_output_0" in summary.output_maps

with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, outputs=["/Flatten_output_0"])
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert "/Flatten_output_0" in summary.output_maps
out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx")
kwargs = {}
kwargs['bash'] = "--dtype fp16"
kwargs['api'] = {"dtype": "fp16"}
kwargs['check_func'] = check_func_fp16
self.run_basic_test(FILENAME, out_fp16_model_path, **kwargs)

def test_input_modification(self, request):
"""Tests input modification."""
summary = summarize_model(
slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]),
request.node.name,
)
print_model_info_as_table(summary)
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps
out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx")
kwargs = {}
kwargs['bash'] = "--dtype fp32"
kwargs['api'] = {"dtype": "fp32"}
kwargs['check_func'] = check_func_fp32
self.run_basic_test(out_fp16_model_path, out_fp32_model_path, **kwargs)

def test_save_as_external_data(self, request):
with tempfile.TemporaryDirectory() as tempdir:
output_name = os.path.join(tempdir, "resnet18.onnx")
slim(FILENAME, output_name, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"])
summary = summarize_model(output_name, request.node.name)
print_model_info_as_table(summary)
assert "/maxpool/MaxPool_output_0" in summary.input_maps
assert "/layer1/layer1.0/relu/Relu_output_0" in summary.input_maps
out_model_path = os.path.join(tempdir, "resnet18.onnx")
kwargs = {}
kwargs['bash'] = "--save-as-external-data"
kwargs['api'] = {"save_as_external_data": True}
self.run_basic_test(FILENAME, out_model_path, **kwargs)
assert os.path.getsize(out_model_path) < 1e5


if __name__ == "__main__":
Expand Down

0 comments on commit 1685e15

Please sign in to comment.