diff --git a/onnxslim/utils.py b/onnxslim/utils.py index 38b22e3..ea14c16 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -166,7 +166,7 @@ def onnxruntime_inference(model: onnx.ModelProto, input_data: dict) -> Dict[str, return onnx_output, model -def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time: float = None): +def format_model_info(model_info_list: List[Dict], elapsed_time: float = None): assert model_info_list, "model_info_list must contain more than one model info" if not isinstance(model_info_list, (list, tuple)): model_info_list = [model_info_list] @@ -239,9 +239,12 @@ def get_io_info(model_info_list, tag=None): return final_op_info -def print_model_info_as_table(model_name: str, model_info_list: List[Dict], elapsed_time: float = None): +def print_model_info_as_table(model_info_list: List[Dict], elapsed_time: float = None): """Prints the model information as a formatted table for the given model name and list of model details.""" - final_op_info = format_model_info(model_name, model_info_list, elapsed_time) + if not isinstance(model_info_list, (list, tuple)): + model_info_list = [model_info_list] + + final_op_info = format_model_info(model_info_list, elapsed_time) lines = tabulate( final_op_info, headers=[], @@ -260,13 +263,11 @@ def print_model_info_as_table(model_name: str, model_info_list: List[Dict], elap print(output) -def dump_model_info_to_disk(model_name: str, model_info: Dict): +def dump_model_info_to_disk(model_info: Dict): """Writes model information to a CSV file for a given model name and dictionary of model info.""" import csv - import os - filename_without_extension, _ = os.path.splitext(os.path.basename(model_name)) - csv_file_path = f"{filename_without_extension}_model_info.csv" + csv_file_path = f"{model_info.tag}_model_info.csv" with open(csv_file_path, "a", newline="") as csvfile: # Use 'a' for append mode fieldnames = ["NodeName", "OpType", "OutputDtype", "OutputShape"] writer = csv.DictWriter(csvfile, fieldnames=fieldnames) @@ -422,7 +423,7 @@ def output_maps(self): return self.output_dict -def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> Dict: +def summarize_model(model: Union[str, onnx.ModelProto], tag="OnnxModel") -> Dict: """Generates a summary of the ONNX model, including model size, operations, and tensor shapes.""" logger.debug("Start summarizing model.") model_info = ModelInfo(model, tag) diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py index 858b604..3b6cd32 100644 --- a/tests/test_modelzoo.py +++ b/tests/test_modelzoo.py @@ -49,16 +49,16 @@ def test_tiny_en_decoder(self, request): def test_transformer_encoder(self, request): name = request.node.originalname[len("test_") :] filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" - summary = summarize_model(slim(filename)) - print_model_info_as_table(request.node.name, summary) + summary = summarize_model(slim(filename), tag=request.node.name) + print_model_info_as_table(summary) assert summary.op_type_counts["Mul"] == 57 assert summary.op_type_counts["Div"] == 53 def test_uiex(self, request): name = request.node.originalname[len("test_") :] filename = f"{MODELZOO_PATH}/{name}/{name}.onnx" - summary = summarize_model(slim(filename)) - print_model_info_as_table(request.node.name, summary) + summary = summarize_model(slim(filename), tag=request.node.name) + print_model_info_as_table(summary) assert summary.op_type_counts["Range"] == 0 assert summary.op_type_counts["Floor"] == 0 assert summary.op_type_counts["Concat"] == 54