diff --git a/onnxslim/core/__init__.py b/onnxslim/core/__init__.py index bd7871f..34dcfda 100644 --- a/onnxslim/core/__init__.py +++ b/onnxslim/core/__init__.py @@ -171,7 +171,7 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto: for node in graph.nodes: if node.op == "Cast": inp_dtype = [input.dtype for input in node.inputs][0] - if inp_dtype in {np.float16, np.float32}: + if inp_dtype in [np.float16, np.float32]: delete_node(node) for tensor in graph.tensors().values(): diff --git a/onnxslim/core/pattern/elimination/reshape.py b/onnxslim/core/pattern/elimination/reshape.py index c2684cb..617f12f 100644 --- a/onnxslim/core/pattern/elimination/reshape.py +++ b/onnxslim/core/pattern/elimination/reshape.py @@ -46,7 +46,8 @@ def check_constant_mergeable(reshape_node): reshape_shape = reshape_node.inputs[1].values.tolist() if input_shape is not None and np.any(reshape_shape == 0): shape = [ - input_shape[i] if dim_size == 0 else reshape_shape[i] for i, dim_size in enumerate(reshape_shape) + input_shape[i] if dim_size == 0 else reshape_shape[i] + for i, dim_size in enumerate(reshape_shape) ] if not all(isinstance(item, int) for item in shape): return False diff --git a/onnxslim/core/pattern/fusion/__init__.py b/onnxslim/core/pattern/fusion/__init__.py index 411c211..3179cf1 100644 --- a/onnxslim/core/pattern/fusion/__init__.py +++ b/onnxslim/core/pattern/fusion/__init__.py @@ -1,6 +1,6 @@ +from .convadd import * from .convbn import * from .gelu import * from .gemm import * from .padconv import * from .reduce import * -from .convadd import * diff --git a/onnxslim/core/pattern/fusion/convadd.py b/onnxslim/core/pattern/fusion/convadd.py index 94234ca..2c5bad0 100644 --- a/onnxslim/core/pattern/fusion/convadd.py +++ b/onnxslim/core/pattern/fusion/convadd.py @@ -1,4 +1,3 @@ -import numpy as np import onnxslim.third_party.onnx_graphsurgeon as gs from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users @@ -29,7 +28,13 @@ def rewrite(self, opset=11): conv_weight = list(conv_node.inputs)[1] conv_node_users = get_node_users(conv_node) node = self.add_0 - if len(conv_node_users) == 1 and isinstance(node.inputs[1], gs.Constant) and isinstance(conv_weight, gs.Constant) and node.inputs[1].values.squeeze().ndim == 1 and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[0]: + if ( + len(conv_node_users) == 1 + and isinstance(node.inputs[1], gs.Constant) + and isinstance(conv_weight, gs.Constant) + and node.inputs[1].values.squeeze().ndim == 1 + and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[0] + ): add_node = node if len(conv_node.inputs) == 2: conv_bias = node.inputs[1].values.squeeze() diff --git a/onnxslim/utils.py b/onnxslim/utils.py index 2da545c..76840b6 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -183,7 +183,12 @@ def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time ) else: final_op_info.append( - ["Model Name", model_name, "Op Set: " + model_info_list[0]["op_set"] + " / IR Version: " + model_info_list[0]["ir_version"]] + [""] * (len(model_info_list) - 2) + [ + "Model Name", + model_name, + "Op Set: " + model_info_list[0]["op_set"] + " / IR Version: " + model_info_list[0]["ir_version"], + ] + + [""] * (len(model_info_list) - 2) ) final_op_info.extend( ( @@ -321,6 +326,7 @@ def get_opset(model: onnx.ModelProto) -> int: except Exception: return None + def get_ir_version(model: onnx.ModelProto) -> int: """Returns the ONNX ir version for a given model.""" try: @@ -328,6 +334,7 @@ def get_ir_version(model: onnx.ModelProto) -> int: except Exception: return None + def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> Dict: """Generates a summary of the ONNX model, including model size, operations, and tensor shapes.""" if isinstance(model, str):