diff --git a/onnxslim/cli/_main.py b/onnxslim/cli/_main.py index 6c79d30..ef5420f 100644 --- a/onnxslim/cli/_main.py +++ b/onnxslim/cli/_main.py @@ -62,7 +62,7 @@ def get_info(model, inspect=False): if not inspect: return model_name, model - model_info = summarize_model(model) + model_info = summarize_model(model, model_name) return model_name, model_info @@ -80,7 +80,7 @@ def get_info(model, inspect=False): else: model_name, model = get_info(model) if output_model: - original_info = summarize_model(model) + original_info = summarize_model(model, model_name) if inputs: model = input_modification(model, inputs) @@ -123,7 +123,7 @@ def get_info(model, inspect=False): if not output_model: return model - slimmed_info = summarize_model(model) + slimmed_info = summarize_model(model, output_model) save(model, output_model, model_check, save_as_external_data, slimmed_info) end_time = time.time() diff --git a/onnxslim/utils.py b/onnxslim/utils.py index 824246c..cf905dc 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -1,3 +1,4 @@ +from pathlib import Path import logging import os import sys @@ -171,68 +172,49 @@ def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time model_info_list = [model_info_list] final_op_info = [] - if len(model_info_list) == 1: - final_op_info.extend( - ( - ["Model Name", model_name], - [SEPARATING_LINE], - ["Op Set ", model_info_list[0]["op_set"]], - [SEPARATING_LINE], - ["IR Version ", model_info_list[0]["ir_version"]], - ) - ) - else: - final_op_info.append( - [ - "Model Name", - model_name, - "Op Set: " + model_info_list[0]["op_set"] + " / IR Version: " + model_info_list[0]["ir_version"], - ] - + [""] * (len(model_info_list) - 2) - ) final_op_info.extend( ( + ["Model Name"] + [item.tag for item in model_info_list], [SEPARATING_LINE] * (len(model_info_list) + 1), - ["Model Info"] - + [model_info_list[0].get("tag", "Original Model")] - + [item.get("tag", "Slimmed Model") for item in model_info_list[1:]], - [SEPARATING_LINE] * (len(model_info_list) + 1), + ["Model Info"] + [ "Op Set: " + item.op_set + " / IR Version: " + item.ir_version for item in model_info_list], + [SEPARATING_LINE] * (len(model_info_list) + 1) ) ) - all_inputs = [op_type for model_info in model_info_list for op_type in model_info.get("op_input_info", {})] - all_inputs = list(dict.fromkeys(all_inputs)) - - for inputs in all_inputs: - input_info_list = [f"IN: {inputs}"] - for model_info in model_info_list: - inputs_shape = model_info["op_input_info"].get(inputs, "") - if isinstance(inputs_shape, (list, tuple)): - inputs_shape = ": ".join([str(i) for i in inputs_shape]) - input_info_list.append(inputs_shape) - final_op_info.append(input_info_list) - - all_outputs = [op_type for model_info in model_info_list for op_type in model_info.get("op_output_info", {})] - all_outputs = list(dict.fromkeys(all_outputs)) - - for outputs in all_outputs: - output_info_list = [f"OUT: {outputs}"] - for model_info in model_info_list: - outputs_shape = model_info["op_output_info"].get(outputs, "") - if isinstance(outputs_shape, (list, tuple)): - outputs_shape = ": ".join([str(i) for i in outputs_shape]) - output_info_list.append(outputs_shape) - final_op_info.append(output_info_list) + def get_io_info(model_info_list, tag=None): + if tag == "OUT": + ios = [op_type for model_info in model_info_list for op_type in model_info.output_info] + else: + ios = [op_type for model_info in model_info_list for op_type in model_info.input_info] + ios = list(dict.fromkeys([io.name for io in ios])) + io_info = [] + for io in ios: + input_info_list = [f"{tag}: {io}"] + for model_info in model_info_list: + if tag == "OUT": + io_tensor = model_info.output_maps.get(io, None) + else: + io_tensor = model_info.input_maps.get(io, None) + inputs_shape = (io_tensor.dtype, io_tensor.shape) if io_tensor else "" + if isinstance(inputs_shape, (list, tuple)): + inputs_shape = ": ".join([str(i) for i in inputs_shape]) + input_info_list.append(inputs_shape) + io_info.append(input_info_list) + + return io_info + + final_op_info.extend(get_io_info(model_info_list, "IN")) + final_op_info.extend(get_io_info(model_info_list, "OUT")) final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1)) - all_ops = {op_type for model_info in model_info_list for op_type in model_info.get("op_type_counts", {})} + all_ops = {op_type for model_info in model_info_list for op_type in model_info.op_type_counts} sorted_ops = sorted(all_ops) for op in sorted_ops: op_info_list = [op] - float_number = model_info_list[0]["op_type_counts"].get(op, 0) + float_number = model_info_list[0].op_type_counts.get(op, 0) op_info_list.append(float_number) for model_info in model_info_list[1:]: - slimmed_number = model_info["op_type_counts"].get(op, 0) + slimmed_number = model_info.op_type_counts.get(op, 0) if float_number > slimmed_number: slimmed_number = GREEN + str(slimmed_number) + WHITE op_info_list.append(slimmed_number) @@ -241,7 +223,7 @@ def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time final_op_info.extend( ( [SEPARATING_LINE] * (len(model_info_list) + 1), - ["Model Size"] + [format_bytes(model_info["model_size"]) for model_info in model_info_list], + ["Model Size"] + [format_bytes(model_info.model_size) for model_info in model_info_list], ) ) if elapsed_time: @@ -334,26 +316,16 @@ def get_ir_version(model: onnx.ModelProto) -> int: except Exception: return None - -def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> Dict: - """Generates a summary of the ONNX model, including model size, operations, and tensor shapes.""" - if isinstance(model, str): - model = onnx.load(model) - - logger.debug("Start summarizing model.") - model_info = {} - if tag is not None: - model_info["tag"] = tag - - model_size = model.ByteSize() - model_info["model_size"] = model_size - - op_info = {} - op_type_counts = defaultdict(int) - - def get_tensor_dtype_shape(tensor): +class TensorInfo(object): + def __init__(self, tensor): + self.dtype: np.dtype = np.float32 + self.shape: Tuple[Union[str, int]] = None + + self._extract_info(tensor) + + def _extract_info(self, tensor): """Extract the data type and shape of an ONNX tensor.""" - type_str = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get(tensor.type.tensor_type.elem_type, "Unknown") + self.dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get(tensor.type.tensor_type.elem_type, "Unknown") shape = None if tensor.type.tensor_type.HasField("shape"): shape = [] @@ -365,51 +337,89 @@ def get_tensor_dtype_shape(tensor): else: shape.append(None) - return (type_str, shape) - - def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]: - op_shape_info = {} - for input in inputs: - type_str, shape = get_tensor_dtype_shape(input) - if shape: - op_shape_info[input.name] = (type_str, tuple(shape)) - else: - op_shape_info[input.name] = (type_str, None) - - return op_shape_info - - value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info} - - def get_graph_node_info(graph: onnx.GraphProto) -> Dict[str, List[str]]: - for node in graph.node: - op_type = node.op_type - op_type_counts[op_type] += 1 - for output in node.output: - shapes = [] - if output in value_info_dict: - tensor = value_info_dict[output] - type_str, shape = get_tensor_dtype_shape(tensor) - shapes.append([type_str, shape]) - - op_info[node.name] = [node.op_type, shapes] - - for attr in node.attribute: - ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()} - if attr.type in ATTR_TYPE_MAPPING: - attr_str = ATTR_TYPE_MAPPING[attr.type] - if attr_str == "GRAPH": - get_graph_node_info(attr.g) - - get_graph_node_info(model.graph) - - model_info["op_set"] = str(get_opset(model)) - model_info["ir_version"] = str(get_ir_version(model)) - model_info["op_info"] = op_info - model_info["op_type_counts"] = op_type_counts - - model_info["op_input_info"] = get_shape(model.graph.input) - model_info["op_output_info"] = get_shape(model.graph.output) + self.shape = tuple(shape) + self.name = tensor.name + +class OperatorInfo(object): + def __init__(self, operator, outputs=None): + self.name: str = None + self.op: str = None + + self._extract_info(operator) + self.outputs = outputs + + def _extract_info(self, operator): + self.name: str = operator.name + self.op: str = operator.op_type + +class ModelInfo(object): + def __init__(self, model: Union[str, onnx.ModelProto], tag: str="OnnxSlim"): + if isinstance(model, str): + model = onnx.load(model) + tag = Path(model).name + + self.tag: str = tag + self.model_size: int = -1 + self.op_set: str = None + self.ir_version: str = None + self.op_type_counts: Dict[str, int] = defaultdict(int) + self.op_info: Dict[str, Dict] = {} + self.input_info: List[str, Tuple[str, Tuple]] = [] + self.output_info: List[str, Tuple[str, Tuple]] = [] + + self._summarize_model(model) + + def _summarize_model(self, model): + self.op_set = str(get_opset(model)) + self.ir_version = str(get_ir_version(model)) + self.model_size = model.ByteSize() + + for input in model.graph.input: + self.input_info.append(TensorInfo(input)) + + for output in model.graph.output: + self.output_info.append(TensorInfo(output)) + + value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info} + + def get_graph_node_info(graph: onnx.GraphProto) -> Dict[str, List[str]]: + for node in graph.node: + op_type = node.op_type + self.op_type_counts[op_type] += 1 + output_tensor_info = [] + for output in node.output: + if output in value_info_dict: + tensor = value_info_dict[output] + tensor_info = TensorInfo(tensor) + output_tensor_info.append(tensor_info) + + self.op_info[node.name] = OperatorInfo(node, output_tensor_info) + + for attr in node.attribute: + ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()} + if attr.type in ATTR_TYPE_MAPPING: + attr_str = ATTR_TYPE_MAPPING[attr.type] + if attr_str == "GRAPH": + get_graph_node_info(attr.g) + + get_graph_node_info(model.graph) + + @property + def input_maps(self): + self.input_dict = {input_info.name: input_info for input_info in self.input_info} + + return self.input_dict + + @property + def output_maps(self): + self.output_dict = {output_info.name: output_info for output_info in self.output_info} + + return self.output_dict +def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> Dict: + """Generates a summary of the ONNX model, including model size, operations, and tensor shapes.""" + logger.debug("Start summarizing model.") + model_info = ModelInfo(model, tag) logger.debug("Finish summarizing model.") return model_info