diff --git a/onnxslim/core/slim.py b/onnxslim/core/slim.py index bc84f07..7789a87 100644 --- a/onnxslim/core/slim.py +++ b/onnxslim/core/slim.py @@ -1,8 +1,5 @@ -import logging import os -import sys import tempfile -from typing import Dict, List import numpy as np import onnx @@ -12,13 +9,7 @@ from onnxslim.core.optimizer import delete_node, optimize_model from onnxslim.core.symbolic_shape_infer import SymbolicShapeInference from onnxslim.onnx_graphsurgeon.ir.tensor import Constant -from onnxslim.utils import ( - dump_model_info_to_disk, - gen_onnxruntime_input_data, - logger, - onnxruntime_inference, - print_model_info_as_table, -) +from onnxslim.utils import logger, save DEBUG = bool(os.getenv("ONNXSLIM_DEBUG")) AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE"))) @@ -91,8 +82,8 @@ def shape_infer(model: onnx.ModelProto): try: logger.debug("try onnxruntime shape infer.") model = SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE) - except Exception: - logger.debug("onnxruntime shape infer failed, try onnx shape infer.") + except Exception as err: + logger.debug(f"onnxruntime shape infer failed, try onnx shape infer. {err}") if model.ByteSize() >= checker.MAXIMUM_PROTOBUF: tmp_dir = tempfile.TemporaryDirectory() tmp_path = os.path.join(tmp_dir.name, "tmp.onnx") diff --git a/tests/test_onnx_nets.py b/tests/test_onnx_nets.py index f733b72..c085ae7 100644 --- a/tests/test_onnx_nets.py +++ b/tests/test_onnx_nets.py @@ -1,6 +1,7 @@ import os import subprocess import warnings +import shutil import pytest import timm @@ -27,10 +28,11 @@ def test_torchvision(self, request, model, shape=(1, 3, 224, 224)): """Test various TorchVision models with random input tensors of a specified shape.""" model = model(pretrained=PRETRAINED) x = torch.rand(shape) - os.makedirs(f"tmp/{request.node.name}", exist_ok=True) + directory = "tmp/" + request.node.name + os.makedirs(directory, exist_ok=True) - filename = f"tmp/{request.node.name}/{request.node.name}.onnx" - slim_filename = f"tmp/{request.node.name}/{request.node.name}_slim.onnx" + filename = f"{directory}/{request.node.name}.onnx" + slim_filename = f"{directory}/{request.node.name}_slim.onnx" torch.onnx.export(model, x, filename) @@ -41,7 +43,7 @@ def test_torchvision(self, request, model, shape=(1, 3, 224, 224)): print(output) assert result.returncode == 0 - os.remove(filename) + shutil.rmtree(directory, ignore_errors=True) class TestTimmClass: @@ -55,12 +57,12 @@ def test_timm(self, request, model_name): model = timm.create_model(model_name, pretrained=PRETRAINED) input_size = model.default_cfg.get("input_size") x = torch.randn((1,) + input_size) - + directory = "tmp/" + request.node.name try: - os.makedirs(f"tmp/{request.node.name}", exist_ok=True) + os.makedirs(directory, exist_ok=True) - filename = f"tmp/{request.node.name}/{request.node.name}.onnx" - slim_filename = f"tmp/{request.node.name}/{request.node.name}_slim.onnx" + filename = f"{directory}/{request.node.name}.onnx" + slim_filename = f"{directory}/{request.node.name}_slim.onnx" torch.onnx.export(model, x, filename) except Exception as e: print(f"An unexpected error occurred: {str(e)}") @@ -75,7 +77,7 @@ def test_timm(self, request, model_name): print(output) assert result.returncode == 0 - os.remove(filename) + shutil.rmtree(directory, ignore_errors=True) if __name__ == "__main__":