Skip to content

Commit

Permalink
remove modelname arg
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Nov 22, 2024
1 parent 5e7c242 commit 38e6af6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
17 changes: 9 additions & 8 deletions onnxslim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def onnxruntime_inference(model: onnx.ModelProto, input_data: dict) -> Dict[str,
return onnx_output, model


def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time: float = None):
def format_model_info(model_info_list: List[Dict], elapsed_time: float = None):
assert model_info_list, "model_info_list must contain more than one model info"
if not isinstance(model_info_list, (list, tuple)):
model_info_list = [model_info_list]
Expand Down Expand Up @@ -239,9 +239,12 @@ def get_io_info(model_info_list, tag=None):
return final_op_info


def print_model_info_as_table(model_name: str, model_info_list: List[Dict], elapsed_time: float = None):
def print_model_info_as_table(model_info_list: List[Dict], elapsed_time: float = None):
"""Prints the model information as a formatted table for the given model name and list of model details."""
final_op_info = format_model_info(model_name, model_info_list, elapsed_time)
if not isinstance(model_info_list, (list, tuple)):
model_info_list = [model_info_list]

final_op_info = format_model_info(model_info_list, elapsed_time)
lines = tabulate(
final_op_info,
headers=[],
Expand All @@ -260,13 +263,11 @@ def print_model_info_as_table(model_name: str, model_info_list: List[Dict], elap
print(output)


def dump_model_info_to_disk(model_name: str, model_info: Dict):
def dump_model_info_to_disk(model_info: Dict):
"""Writes model information to a CSV file for a given model name and dictionary of model info."""
import csv
import os

filename_without_extension, _ = os.path.splitext(os.path.basename(model_name))
csv_file_path = f"{filename_without_extension}_model_info.csv"
csv_file_path = f"{model_info.tag}_model_info.csv"
with open(csv_file_path, "a", newline="") as csvfile: # Use 'a' for append mode
fieldnames = ["NodeName", "OpType", "OutputDtype", "OutputShape"]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
Expand Down Expand Up @@ -422,7 +423,7 @@ def output_maps(self):
return self.output_dict


def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> Dict:
def summarize_model(model: Union[str, onnx.ModelProto], tag="OnnxModel") -> Dict:
"""Generates a summary of the ONNX model, including model size, operations, and tensor shapes."""
logger.debug("Start summarizing model.")
model_info = ModelInfo(model, tag)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ def test_tiny_en_decoder(self, request):
def test_transformer_encoder(self, request):
name = request.node.originalname[len("test_") :]
filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
summary = summarize_model(slim(filename))
print_model_info_as_table(request.node.name, summary)
summary = summarize_model(slim(filename), tag=request.node.name)
print_model_info_as_table(summary)
assert summary.op_type_counts["Mul"] == 57
assert summary.op_type_counts["Div"] == 53

def test_uiex(self, request):
name = request.node.originalname[len("test_") :]
filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
summary = summarize_model(slim(filename))
print_model_info_as_table(request.node.name, summary)
summary = summarize_model(slim(filename), tag=request.node.name)
print_model_info_as_table(summary)
assert summary.op_type_counts["Range"] == 0
assert summary.op_type_counts["Floor"] == 0
assert summary.op_type_counts["Concat"] == 54
Expand Down

0 comments on commit 38e6af6

Please sign in to comment.