Skip to content

Commit

Permalink
change __test_command_basic function
Browse files Browse the repository at this point in the history
  • Loading branch information
whyb committed Nov 27, 2024
1 parent b7f4605 commit 7e53d40
Showing 1 changed file with 26 additions and 22 deletions.
48 changes: 26 additions & 22 deletions tests/test_onnxslim.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,41 @@
FILENAME = f"{MODELZOO_PATH}/resnet18/resnet18.onnx"


class TestFunctional:
def __test_command_basic(self, request, in_model_path=FILENAME, out_model_name="resnet18.onnx", arg_str=""):
with tempfile.TemporaryDirectory() as tempdir:
summary = summarize_model(slim(in_model_path), request.node.name)
print_model_info_as_table(summary)
out_model_path = os.path.join(tempdir, out_model_name)
slim(in_model_path, out_model_path)
slim(in_model_path, out_model_path, model_check=True)
command = f'onnxslim "{in_model_path}" "{out_model_path}" {arg_str}'
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0
return in_model_path, out_model_path
class TestFunctional:
def __test_command_basic(self, request, in_model_path, out_model_path, arg_str=""):
summary = summarize_model(slim(in_model_path), request.node.name)
print_model_info_as_table(summary)
slim(in_model_path, out_model_path)
slim(in_model_path, out_model_path, model_check=True)
command = f'onnxslim "{in_model_path}" "{out_model_path}" {arg_str}'
result = subprocess.run(command, shell=True, capture_output=True, text=True)
output = result.stderr.strip()
# Assert the expected return code
print(output)
assert result.returncode == 0

def test_basic(self, request):
"""Test the basic functionality of the slim function."""
self.__test_command_basic(request)
with tempfile.TemporaryDirectory() as tempdir:
out_model_path = os.path.join(tempdir, "resnet18.onnx")
self.__test_command_basic(request, FILENAME, out_model_path, "")

def test_input_shape_modification(self, request):
"""Test the modification of input shapes."""
input_shape_arg_str = "--input-shapes input:1,3,224,224"
self.__test_command_basic(request, FILENAME, "resnet18.onnx", input_shape_arg_str)
with tempfile.TemporaryDirectory() as tempdir:
out_model_path = os.path.join(tempdir, "resnet18.onnx")
input_shape_arg_str = "--input-shapes input:1,3,224,224"
self.__test_command_basic(request, FILENAME, out_model_path, input_shape_arg_str)

def test_fp322fp16_conversion(self, request):
"""Test the conversion of an ONNX model from FP32 to FP16 precision."""
dtype_fp16_arg_str = "--dtype fp16"
_, out_model_path = self.__test_command_basic(request, FILENAME, "resnet18_fp16.onnx", dtype_fp16_arg_str)
dtype_fp32_arg_str = "--dtype fp32"
self.__test_command_basic(request, out_model_path, "resnet18_fp32.onnx", dtype_fp32_arg_str)
with tempfile.TemporaryDirectory() as tempdir:
out_fp16_model_path = os.path.join(tempdir, "resnet18_fp16.onnx")
out_fp32_model_path = os.path.join(tempdir, "resnet18_fp32.onnx")
dtype_fp16_arg_str = "--dtype fp16"
dtype_fp32_arg_str = "--dtype fp32"
self.__test_command_basic(request, FILENAME, out_fp16_model_path, dtype_fp16_arg_str)
self.__test_command_basic(request, out_fp16_model_path, out_fp32_model_path, dtype_fp32_arg_str)


class TestFeature:
Expand Down

0 comments on commit 7e53d40

Please sign in to comment.