Skip to content

Commit

Permalink
refactor summary
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Nov 21, 2024
1 parent da589e9 commit 149ac2c
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 116 deletions.
6 changes: 3 additions & 3 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_info(model, inspect=False):
if not inspect:
return model_name, model

model_info = summarize_model(model)
model_info = summarize_model(model, model_name)

return model_name, model_info

Expand All @@ -80,7 +80,7 @@ def get_info(model, inspect=False):
else:
model_name, model = get_info(model)
if output_model:
original_info = summarize_model(model)
original_info = summarize_model(model, model_name)

if inputs:
model = input_modification(model, inputs)
Expand Down Expand Up @@ -123,7 +123,7 @@ def get_info(model, inspect=False):
if not output_model:
return model

slimmed_info = summarize_model(model)
slimmed_info = summarize_model(model, output_model)
save(model, output_model, model_check, save_as_external_data, slimmed_info)

end_time = time.time()
Expand Down
236 changes: 123 additions & 113 deletions onnxslim/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
import logging
import os
import sys
Expand Down Expand Up @@ -171,68 +172,49 @@ def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time
model_info_list = [model_info_list]

final_op_info = []
if len(model_info_list) == 1:
final_op_info.extend(
(
["Model Name", model_name],
[SEPARATING_LINE],
["Op Set ", model_info_list[0]["op_set"]],
[SEPARATING_LINE],
["IR Version ", model_info_list[0]["ir_version"]],
)
)
else:
final_op_info.append(
[
"Model Name",
model_name,
"Op Set: " + model_info_list[0]["op_set"] + " / IR Version: " + model_info_list[0]["ir_version"],
]
+ [""] * (len(model_info_list) - 2)
)
final_op_info.extend(
(
["Model Name"] + [item.tag for item in model_info_list],
[SEPARATING_LINE] * (len(model_info_list) + 1),
["Model Info"]
+ [model_info_list[0].get("tag", "Original Model")]
+ [item.get("tag", "Slimmed Model") for item in model_info_list[1:]],
[SEPARATING_LINE] * (len(model_info_list) + 1),
["Model Info"] + [ "Op Set: " + item.op_set + " / IR Version: " + item.ir_version for item in model_info_list],
[SEPARATING_LINE] * (len(model_info_list) + 1)
)
)
all_inputs = [op_type for model_info in model_info_list for op_type in model_info.get("op_input_info", {})]
all_inputs = list(dict.fromkeys(all_inputs))

for inputs in all_inputs:
input_info_list = [f"IN: {inputs}"]
for model_info in model_info_list:
inputs_shape = model_info["op_input_info"].get(inputs, "")
if isinstance(inputs_shape, (list, tuple)):
inputs_shape = ": ".join([str(i) for i in inputs_shape])
input_info_list.append(inputs_shape)
final_op_info.append(input_info_list)

all_outputs = [op_type for model_info in model_info_list for op_type in model_info.get("op_output_info", {})]
all_outputs = list(dict.fromkeys(all_outputs))

for outputs in all_outputs:
output_info_list = [f"OUT: {outputs}"]
for model_info in model_info_list:
outputs_shape = model_info["op_output_info"].get(outputs, "")
if isinstance(outputs_shape, (list, tuple)):
outputs_shape = ": ".join([str(i) for i in outputs_shape])
output_info_list.append(outputs_shape)
final_op_info.append(output_info_list)
def get_io_info(model_info_list, tag=None):
if tag == "OUT":
ios = [op_type for model_info in model_info_list for op_type in model_info.output_info]
else:
ios = [op_type for model_info in model_info_list for op_type in model_info.input_info]
ios = list(dict.fromkeys([io.name for io in ios]))
io_info = []
for io in ios:
input_info_list = [f"{tag}: {io}"]
for model_info in model_info_list:
if tag == "OUT":
io_tensor = model_info.output_maps.get(io, None)
else:
io_tensor = model_info.input_maps.get(io, None)
inputs_shape = (io_tensor.dtype, io_tensor.shape) if io_tensor else ""
if isinstance(inputs_shape, (list, tuple)):
inputs_shape = ": ".join([str(i) for i in inputs_shape])
input_info_list.append(inputs_shape)
io_info.append(input_info_list)

return io_info

final_op_info.extend(get_io_info(model_info_list, "IN"))
final_op_info.extend(get_io_info(model_info_list, "OUT"))

final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))

all_ops = {op_type for model_info in model_info_list for op_type in model_info.get("op_type_counts", {})}
all_ops = {op_type for model_info in model_info_list for op_type in model_info.op_type_counts}
sorted_ops = sorted(all_ops)
for op in sorted_ops:
op_info_list = [op]
float_number = model_info_list[0]["op_type_counts"].get(op, 0)
float_number = model_info_list[0].op_type_counts.get(op, 0)
op_info_list.append(float_number)
for model_info in model_info_list[1:]:
slimmed_number = model_info["op_type_counts"].get(op, 0)
slimmed_number = model_info.op_type_counts.get(op, 0)
if float_number > slimmed_number:
slimmed_number = GREEN + str(slimmed_number) + WHITE
op_info_list.append(slimmed_number)
Expand All @@ -241,7 +223,7 @@ def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time
final_op_info.extend(
(
[SEPARATING_LINE] * (len(model_info_list) + 1),
["Model Size"] + [format_bytes(model_info["model_size"]) for model_info in model_info_list],
["Model Size"] + [format_bytes(model_info.model_size) for model_info in model_info_list],
)
)
if elapsed_time:
Expand Down Expand Up @@ -334,26 +316,16 @@ def get_ir_version(model: onnx.ModelProto) -> int:
except Exception:
return None


def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> Dict:
"""Generates a summary of the ONNX model, including model size, operations, and tensor shapes."""
if isinstance(model, str):
model = onnx.load(model)

logger.debug("Start summarizing model.")
model_info = {}
if tag is not None:
model_info["tag"] = tag

model_size = model.ByteSize()
model_info["model_size"] = model_size

op_info = {}
op_type_counts = defaultdict(int)

def get_tensor_dtype_shape(tensor):
class TensorInfo(object):
def __init__(self, tensor):
self.dtype: np.dtype = np.float32
self.shape: Tuple[Union[str, int]] = None

self._extract_info(tensor)

def _extract_info(self, tensor):
"""Extract the data type and shape of an ONNX tensor."""
type_str = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get(tensor.type.tensor_type.elem_type, "Unknown")
self.dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get(tensor.type.tensor_type.elem_type, "Unknown")
shape = None
if tensor.type.tensor_type.HasField("shape"):
shape = []
Expand All @@ -365,51 +337,89 @@ def get_tensor_dtype_shape(tensor):
else:
shape.append(None)

return (type_str, shape)

def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]:
op_shape_info = {}
for input in inputs:
type_str, shape = get_tensor_dtype_shape(input)
if shape:
op_shape_info[input.name] = (type_str, tuple(shape))
else:
op_shape_info[input.name] = (type_str, None)

return op_shape_info

value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info}

def get_graph_node_info(graph: onnx.GraphProto) -> Dict[str, List[str]]:
for node in graph.node:
op_type = node.op_type
op_type_counts[op_type] += 1
for output in node.output:
shapes = []
if output in value_info_dict:
tensor = value_info_dict[output]
type_str, shape = get_tensor_dtype_shape(tensor)
shapes.append([type_str, shape])

op_info[node.name] = [node.op_type, shapes]

for attr in node.attribute:
ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()}
if attr.type in ATTR_TYPE_MAPPING:
attr_str = ATTR_TYPE_MAPPING[attr.type]
if attr_str == "GRAPH":
get_graph_node_info(attr.g)

get_graph_node_info(model.graph)

model_info["op_set"] = str(get_opset(model))
model_info["ir_version"] = str(get_ir_version(model))
model_info["op_info"] = op_info
model_info["op_type_counts"] = op_type_counts

model_info["op_input_info"] = get_shape(model.graph.input)
model_info["op_output_info"] = get_shape(model.graph.output)
self.shape = tuple(shape)
self.name = tensor.name

class OperatorInfo(object):
    """Lightweight record of one ONNX graph node: its name, op type, and output tensors."""

    def __init__(self, operator, outputs=None):
        # Initialize to None first so the attributes always exist, then
        # populate them from the node proto.
        self.name: str = None
        self.op: str = None
        self._extract_info(operator)
        self.outputs = outputs

    def _extract_info(self, operator):
        """Copy the identifying fields off the ONNX NodeProto."""
        self.name = operator.name
        self.op = operator.op_type

class ModelInfo(object):
    """Aggregated summary of an ONNX model: size, opset/IR version, op-type
    counts, per-node info, and graph input/output tensor descriptions.

    Args:
        model: Path to an ``.onnx`` file or an already-loaded ``onnx.ModelProto``.
        tag: Display label for this model. When ``model`` is a file path the
            file name is used as the label instead.
    """

    def __init__(self, model: Union[str, onnx.ModelProto], tag: str = "OnnxSlim"):
        if isinstance(model, str):
            # Capture the file name BEFORE loading: after onnx.load() the
            # variable holds a ModelProto, which Path() cannot accept.
            tag = Path(model).name
            model = onnx.load(model)

        self.tag: str = tag
        self.model_size: int = -1
        self.op_set: str = None
        self.ir_version: str = None
        self.op_type_counts: Dict[str, int] = defaultdict(int)
        self.op_info: Dict[str, OperatorInfo] = {}
        self.input_info: List[TensorInfo] = []
        self.output_info: List[TensorInfo] = []

        self._summarize_model(model)

    def _summarize_model(self, model):
        """Walk the model graph (including subgraphs) and collect the summary data."""
        self.op_set = str(get_opset(model))
        self.ir_version = str(get_ir_version(model))
        self.model_size = model.ByteSize()

        for graph_input in model.graph.input:
            self.input_info.append(TensorInfo(graph_input))

        for graph_output in model.graph.output:
            self.output_info.append(TensorInfo(graph_output))

        # Shape/dtype info for intermediate tensors lives in graph.value_info.
        value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info}

        def get_graph_node_info(graph: onnx.GraphProto) -> None:
            for node in graph.node:
                self.op_type_counts[node.op_type] += 1
                output_tensor_info = []
                for output in node.output:
                    if output in value_info_dict:
                        output_tensor_info.append(TensorInfo(value_info_dict[output]))

                self.op_info[node.name] = OperatorInfo(node, output_tensor_info)

                # Recurse into subgraph attributes (e.g. If/Loop/Scan bodies)
                # so their nodes are counted too.
                for attr in node.attribute:
                    if attr.type == onnx.AttributeProto.GRAPH:
                        get_graph_node_info(attr.g)

        get_graph_node_info(model.graph)

    @property
    def input_maps(self):
        """Mapping of graph input name -> TensorInfo (rebuilt on each access)."""
        self.input_dict = {input_info.name: input_info for input_info in self.input_info}
        return self.input_dict

    @property
    def output_maps(self):
        """Mapping of graph output name -> TensorInfo (rebuilt on each access)."""
        self.output_dict = {output_info.name: output_info for output_info in self.output_info}
        return self.output_dict

def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> "ModelInfo":
    """Generate a summary of an ONNX model (size, op counts, tensor shapes).

    Args:
        model: Path to an ``.onnx`` file or a loaded ``onnx.ModelProto``.
        tag: Optional display label. When omitted, ModelInfo's own default
            (or the file name, for path inputs) is used.

    Returns:
        A ModelInfo object describing the model.
    """
    logger.debug("Start summarizing model.")
    # Only forward `tag` when the caller supplied one, so ModelInfo's default
    # label is not clobbered with None (which would then be printed in reports).
    model_info = ModelInfo(model) if tag is None else ModelInfo(model, tag)
    logger.debug("Finish summarizing model.")
    return model_info

Expand Down

0 comments on commit 149ac2c

Please sign in to comment.