From 016b5d60fc6a17eed8b9f41271d9c383e80bb8df Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sat, 7 Sep 2024 18:54:51 +0200
Subject: [PATCH] Revert "Docstrings and code refactor (#22)"

This reverts commit 942d5c32d804f180d6833785acd5cef14f54771d.
---
 .../cse_demo.py                               |  4 +-
 onnxslim/argparser.py                         |  5 --
 onnxslim/cli/_main.py                         | 31 ++++---
 onnxslim/core/optimization/weight_tying.py    | 11 ++-
 onnxslim/core/pattern/__init__.py             |  8 --
 onnxslim/core/pattern/elimination/reshape.py  |  2 -
 onnxslim/core/pattern/elimination/slice.py    |  2 -
 .../core/pattern/elimination/unsqueeze.py     | 89 +++++++++----------
 onnxslim/core/pattern/fusion/convbn.py        |  2 -
 onnxslim/core/pattern/fusion/gelu.py          |  2 -
 onnxslim/core/pattern/fusion/gemm.py          |  2 -
 onnxslim/core/pattern/fusion/padconv.py       |  2 -
 onnxslim/core/pattern/fusion/reduce.py        |  2 -
 onnxslim/misc/tabulate.py                     |  4 +-
 .../exporters/base_exporter.py                |  4 +-
 .../exporters/onnx_exporter.py                | 13 ++-
 .../graph_pattern/graph_pattern.py            | 11 +--
 .../importers/base_importer.py                |  4 +-
 .../importers/onnx_importer.py                | 14 +--
 .../third_party/onnx_graphsurgeon/ir/graph.py | 82 +++++++++--------
 .../third_party/onnx_graphsurgeon/ir/node.py  | 16 ++--
 .../onnx_graphsurgeon/ir/tensor.py            | 18 ++--
 .../onnx_graphsurgeon/logger/logger.py        | 18 ++--
 .../onnx_graphsurgeon/util/misc.py            |  2 -
 onnxslim/third_party/symbolic_shape_infer.py  | 23 +++--
 onnxslim/utils.py                             |  3 +-
 setup.py                                      |  4 +-
 tests/test_modelzoo.py                        |  7 --
 tests/test_onnx_nets.py                       |  2 -
 tests/test_onnxslim.py                        |  8 +-
 tests/test_pattern_generator.py               |  6 +-
 tests/test_pattern_matcher.py                 | 20 ++---
 tests/utils.py                                |  6 +-
 33 files changed, 188 insertions(+), 239 deletions(-)

diff --git a/examples/common_subexpression_elimination/cse_demo.py b/examples/common_subexpression_elimination/cse_demo.py
index 3805f00..0e97a5f 100644
--- a/examples/common_subexpression_elimination/cse_demo.py
+++ b/examples/common_subexpression_elimination/cse_demo.py
@@ -4,11 +4,9 @@
 
 
 class Model(torch.nn.Module):
-    """A PyTorch model applying LayerNorm to input tensors for normalization in neural network layers."""
-
     def __init__(self):
         """Initializes the Model class with a single LayerNorm layer of embedding dimension 10."""
-        super().__init__()
+        super(Model, self).__init__()
         embedding_dim = 10
         self.layer_norm = nn.LayerNorm(embedding_dim)
 
diff --git a/onnxslim/argparser.py b/onnxslim/argparser.py
index 59d6a4b..b5a3a1a 100644
--- a/onnxslim/argparser.py
+++ b/onnxslim/argparser.py
@@ -110,10 +110,7 @@ class CheckerArguments:
 
 
 class ArgumentParser:
-    """Parses command-line arguments into specified dataclasses for ONNX model optimization and modification tasks."""
-
     def __init__(self, *argument_dataclasses: Type):
-        """Initializes the ArgumentParser with dataclass types for parsing ONNX model optimization arguments."""
         self.argument_dataclasses = argument_dataclasses
         self.parser = argparse.ArgumentParser(
             description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model",
@@ -122,7 +119,6 @@ def __init__(self, *argument_dataclasses: Type):
         self._add_arguments()
 
     def _add_arguments(self):
-        """Adds command-line arguments to the parser based on provided dataclass fields and their metadata."""
         for dataclass_type in self.argument_dataclasses:
             for field_name, field_def in dataclass_type.__dataclass_fields__.items():
                 arg_type = field_def.type
@@ -154,7 +150,6 @@ def _add_arguments(self):
         self.parser.add_argument("-v", "--version", action="version", version=onnxslim.__version__)
 
     def parse_args_into_dataclasses(self):
-        """Parses command-line arguments into specified dataclass instances for structured configuration."""
         args = self.parser.parse_args()
         args_dict = vars(args)
 
diff --git a/onnxslim/cli/_main.py b/onnxslim/cli/_main.py
index c39ecfd..4c9c3d7 100644
--- a/onnxslim/cli/_main.py
+++ b/onnxslim/cli/_main.py
@@ -4,7 +4,6 @@
 
 
 def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
-    """Slims an ONNX model by optimizing and modifying its structure, inputs, and outputs for improved performance."""
     import os
     import time
     from pathlib import Path
@@ -12,9 +11,9 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
     from onnxslim.core import (
         convert_data_format,
         freeze,
-        input_modification,
         input_shape_modification,
         optimize,
+        input_modification,
         output_modification,
         shape_infer,
     )
@@ -30,20 +29,20 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
         summarize_model,
     )
 
-    output_model = args[0] if args else kwargs.get("output_model", None)
-    model_check = kwargs.get("model_check", False)
-    input_shapes = kwargs.get("input_shapes", None)
-    inputs = kwargs.get("inputs", None)
-    outputs = kwargs.get("outputs", None)
-    no_shape_infer = kwargs.get("no_shape_infer", False)
-    no_constant_folding = kwargs.get("no_constant_folding", False)
-    dtype = kwargs.get("dtype", None)
-    skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
-    inspect = kwargs.get("inspect", False)
-    dump_to_disk = kwargs.get("dump_to_disk", False)
-    save_as_external_data = kwargs.get("save_as_external_data", False)
-    model_check_inputs = kwargs.get("model_check_inputs", None)
-    verbose = kwargs.get("verbose", False)
+    output_model = args[0] if len(args) > 0 else kwargs.get('output_model', None)
+    model_check = kwargs.get('model_check', False)
+    input_shapes = kwargs.get('input_shapes', None)
+    inputs = kwargs.get('inputs', None)
+    outputs = kwargs.get('outputs', None)
+    no_shape_infer = kwargs.get('no_shape_infer', False)
+    no_constant_folding = kwargs.get('no_constant_folding', False)
+    dtype = kwargs.get('dtype', None)
+    skip_fusion_patterns = kwargs.get('skip_fusion_patterns', None)
+    inspect = kwargs.get('inspect', False)
+    dump_to_disk = kwargs.get('dump_to_disk', False)
+    save_as_external_data = kwargs.get('save_as_external_data', False)
+    model_check_inputs = kwargs.get('model_check_inputs', None)
+    verbose = kwargs.get('verbose', False)
 
     logger = init_logging(verbose)
 
diff --git a/onnxslim/core/optimization/weight_tying.py b/onnxslim/core/optimization/weight_tying.py
index 24aed85..20c8f1a 100644
--- a/onnxslim/core/optimization/weight_tying.py
+++ b/onnxslim/core/optimization/weight_tying.py
@@ -31,7 +31,10 @@ def replace_constant_references(existing_constant, to_be_removed_constant):
     for i, constant_tensor in enumerate(constant_tensors):
         if keep_constants[i]:
             for j in range(i + 1, len(constant_tensors)):
-                if keep_constants[j] and constant_tensor == constant_tensors[j]:
-                    keep_constants[j] = False
-                    replace_constant_references(constant_tensor, constant_tensors[j])
-                    logger.debug(f"Constant {constant_tensors[j].name} can be replaced by {constant_tensor.name}")
+                if keep_constants[j]:
+                    if constant_tensor == constant_tensors[j]:
+                        keep_constants[j] = False
+                        replace_constant_references(constant_tensor, constant_tensors[j])
+                        logger.debug(
+                            f"Constant {constant_tensors[j].name} can be replaced by {constant_tensor.name}"
+                        )
diff --git a/onnxslim/core/pattern/__init__.py b/onnxslim/core/pattern/__init__.py
index fa9991e..32bf0d5 100644
--- a/onnxslim/core/pattern/__init__.py
+++ b/onnxslim/core/pattern/__init__.py
@@ -44,8 +44,6 @@ def get_name(name):
 
 
 class NodeDescriptor:
-    """Represents a node in a computational graph, detailing its operation type, inputs, and outputs."""
-
     def __init__(self, node_spec):
         """Initialize NodeDescriptor with node_spec list requiring at least 4 elements."""
         if not isinstance(node_spec, list):
@@ -89,8 +87,6 @@ def __dict__(self):
 
 
 class Pattern:
-    """Parses and matches ONNX graph patterns into NodeDescriptor objects for model optimization tasks."""
-
     def __init__(self, pattern):
         """Initialize the Pattern class with a given pattern and parse its nodes."""
         self.pattern = pattern
@@ -113,8 +109,6 @@ def __repr__(self):
 
 
 class PatternMatcher:
-    """Matches computational graph nodes to predefined patterns for optimization and transformation tasks."""
-
     def __init__(self, pattern, priority):
         """Initialize the PatternMatcher with a given pattern and priority, and prepare node references and output
         names.
@@ -190,8 +184,6 @@ def parameter_check(self):
 
 
 class PatternGenerator:
-    """Generates pattern templates from an ONNX model by processing its graph structure and node connections."""
-
     def __init__(self, onnx_model):
         """Initialize the PatternGenerator class with an ONNX model and process its graph."""
         self.graph = gs.import_onnx(onnx_model)
diff --git a/onnxslim/core/pattern/elimination/reshape.py b/onnxslim/core/pattern/elimination/reshape.py
index 20bb1d6..90eeeae 100644
--- a/onnxslim/core/pattern/elimination/reshape.py
+++ b/onnxslim/core/pattern/elimination/reshape.py
@@ -6,8 +6,6 @@
 
 
 class ReshapePatternMatcher(PatternMatcher):
-    """Matches and optimizes nested reshape operations in computational graphs to eliminate redundancy."""
-
     def __init__(self, priority):
         """Initializes the ReshapePatternMatcher with a priority and a specific pattern for detecting nested reshape
         operations.
diff --git a/onnxslim/core/pattern/elimination/slice.py b/onnxslim/core/pattern/elimination/slice.py
index e303cad..14bec9c 100644
--- a/onnxslim/core/pattern/elimination/slice.py
+++ b/onnxslim/core/pattern/elimination/slice.py
@@ -6,8 +6,6 @@
 
 
 class SlicePatternMatcher(PatternMatcher):
-    """Matches and optimizes nested slice operations in ONNX graphs to improve computational efficiency."""
-
     def __init__(self, priority):
         """Initializes the SlicePatternMatcher with a specified priority using a predefined graph pattern."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/elimination/unsqueeze.py b/onnxslim/core/pattern/elimination/unsqueeze.py
index e69962d..c4fcf8d 100644
--- a/onnxslim/core/pattern/elimination/unsqueeze.py
+++ b/onnxslim/core/pattern/elimination/unsqueeze.py
@@ -6,8 +6,6 @@
 
 
 class UnsqueezePatternMatcher(PatternMatcher):
-    """Matches and optimizes nested unsqueeze patterns in ONNX graphs to improve computational efficiency."""
-
     def __init__(self, priority):
         """Initializes the UnsqueezePatternMatcher with a specified priority using a predefined graph pattern."""
         pattern = Pattern(
@@ -31,60 +29,53 @@ def rewrite(self, opset=11):
         match_case = {}
         node_unsqueeze_0 = self.unsqueeze_0
         users_node_unsqueeze_0 = get_node_users(node_unsqueeze_0)
         node_unsqueeze_1 = self.unsqueeze_1
-        if (
-            len(users_node_unsqueeze_0) == 1
-            and node_unsqueeze_0.inputs[0].shape
-            and node_unsqueeze_1.inputs[0].shape
-            and (
-                opset < 13
-                or (
-                    isinstance(node_unsqueeze_0.inputs[1], gs.Constant)
-                    and isinstance(node_unsqueeze_1.inputs[1], gs.Constant)
-                )
-            )
-        ):
+        if len(users_node_unsqueeze_0) == 1 and node_unsqueeze_0.inputs[0].shape and node_unsqueeze_1.inputs[0].shape:
+            if opset < 13 or (
+                isinstance(node_unsqueeze_0.inputs[1], gs.Constant)
+                and isinstance(node_unsqueeze_1.inputs[1], gs.Constant)
+            ):
 
-            def get_unsqueeze_axes(unsqueeze_node, opset):
-                dim = len(unsqueeze_node.inputs[0].shape)
-                if opset < 13:
-                    axes = unsqueeze_node.attrs["axes"]
-                else:
-                    axes = unsqueeze_node.inputs[1].values
-                return [axis + dim + len(axes) if axis < 0 else axis for axis in axes]
+                def get_unsqueeze_axes(unsqueeze_node, opset):
+                    dim = len(unsqueeze_node.inputs[0].shape)
+                    if opset < 13:
+                        axes = unsqueeze_node.attrs["axes"]
+                    else:
+                        axes = unsqueeze_node.inputs[1].values
+                    return [axis + dim + len(axes) if axis < 0 else axis for axis in axes]
 
-            axes_node_unsqueeze_0 = get_unsqueeze_axes(node_unsqueeze_0, opset)
-            axes_node_unsqueeze_1 = get_unsqueeze_axes(node_unsqueeze_1, opset)
+                axes_node_unsqueeze_0 = get_unsqueeze_axes(node_unsqueeze_0, opset)
+                axes_node_unsqueeze_1 = get_unsqueeze_axes(node_unsqueeze_1, opset)
 
-            axes_node_unsqueeze_0 = [
-                axis + sum(bool(axis_ <= axis) for axis_ in axes_node_unsqueeze_1) for axis in axes_node_unsqueeze_0
-            ]
+                axes_node_unsqueeze_0 = [
+                    axis + sum(1 for axis_ in axes_node_unsqueeze_1 if axis_ <= axis) for axis in axes_node_unsqueeze_0
+                ]
 
-            inputs = [node_unsqueeze_0.inputs[0]]
-            outputs = list(node_unsqueeze_1.outputs)
-            node_unsqueeze_0.inputs.clear()
-            node_unsqueeze_0.outputs.clear()
-            node_unsqueeze_1.inputs.clear()
-            node_unsqueeze_1.outputs.clear()
+                inputs = [node_unsqueeze_0.inputs[0]]
+                outputs = list(node_unsqueeze_1.outputs)
+                node_unsqueeze_0.inputs.clear()
+                node_unsqueeze_0.outputs.clear()
+                node_unsqueeze_1.inputs.clear()
+                node_unsqueeze_1.outputs.clear()
 
-            if opset < 13:
-                attrs = {"axes": axes_node_unsqueeze_0 + axes_node_unsqueeze_1}
-            else:
-                attrs = None
-            inputs.append(
-                gs.Constant(
-                    name=f"{node_unsqueeze_0.name}_axes",
-                    values=np.array(axes_node_unsqueeze_0 + axes_node_unsqueeze_1, dtype=np.int64),
+                if opset < 13:
+                    attrs = {"axes": axes_node_unsqueeze_0 + axes_node_unsqueeze_1}
+                else:
+                    attrs = None
+                inputs.append(
+                    gs.Constant(
+                        name=f"{node_unsqueeze_0.name}_axes",
+                        values=np.array(axes_node_unsqueeze_0 + axes_node_unsqueeze_1, dtype=np.int64),
+                    )
                 )
-            )
 
-            match_case[node_unsqueeze_0.name] = {
-                "op": "Unsqueeze",
-                "inputs": inputs,
-                "outputs": outputs,
-                "name": node_unsqueeze_0.name,
-                "attrs": attrs,
-                "domain": None,
-            }
+                match_case[node_unsqueeze_0.name] = {
+                    "op": "Unsqueeze",
+                    "inputs": inputs,
+                    "outputs": outputs,
+                    "name": node_unsqueeze_0.name,
+                    "attrs": attrs,
+                    "domain": None,
+                }
 
         return match_case
diff --git a/onnxslim/core/pattern/fusion/convbn.py b/onnxslim/core/pattern/fusion/convbn.py
index d6bba3a..f4aa0e0 100644
--- a/onnxslim/core/pattern/fusion/convbn.py
+++ b/onnxslim/core/pattern/fusion/convbn.py
@@ -6,8 +6,6 @@
 
 
 class ConvBatchNormMatcher(PatternMatcher):
-    """Fuses Conv and BatchNormalization layers in an ONNX graph to optimize model performance and inference speed."""
-
     def __init__(self, priority):
         """Initializes the ConvBatchNormMatcher for fusing Conv and BatchNormalization layers in an ONNX graph."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/fusion/gelu.py b/onnxslim/core/pattern/fusion/gelu.py
index 2157941..5efc87e 100644
--- a/onnxslim/core/pattern/fusion/gelu.py
+++ b/onnxslim/core/pattern/fusion/gelu.py
@@ -2,8 +2,6 @@
 
 
 class GeluPatternMatcher(PatternMatcher):
-    """Matches and fuses GELU patterns in computational graphs for optimization purposes."""
-
     def __init__(self, priority):
         """Initializes a `GeluPatternMatcher` to identify and fuse GELU patterns in a computational graph."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/fusion/gemm.py b/onnxslim/core/pattern/fusion/gemm.py
index f6eec15..9a80912 100644
--- a/onnxslim/core/pattern/fusion/gemm.py
+++ b/onnxslim/core/pattern/fusion/gemm.py
@@ -5,8 +5,6 @@
 
 
 class MatMulAddPatternMatcher(PatternMatcher):
-    """Matches and fuses MatMul and Add operations in ONNX graphs to optimize computational efficiency."""
-
     def __init__(self, priority):
         """Initializes a matcher for fusing MatMul and Add operations in ONNX graph optimization."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/fusion/padconv.py b/onnxslim/core/pattern/fusion/padconv.py
index 344a768..ef304cc 100644
--- a/onnxslim/core/pattern/fusion/padconv.py
+++ b/onnxslim/core/pattern/fusion/padconv.py
@@ -4,8 +4,6 @@
 
 
 class PadConvMatcher(PatternMatcher):
-    """Matches and optimizes Pad-Conv patterns in ONNX graphs by ensuring padding parameters are constants."""
-
     def __init__(self, priority):
         """Initializes the PadConvMatcher with a specified priority and defines its matching pattern."""
         pattern = Pattern(
diff --git a/onnxslim/core/pattern/fusion/reduce.py b/onnxslim/core/pattern/fusion/reduce.py
index 9b4d606..29f31d0 100644
--- a/onnxslim/core/pattern/fusion/reduce.py
+++ b/onnxslim/core/pattern/fusion/reduce.py
@@ -3,8 +3,6 @@
 
 
 class ReducePatternMatcher(PatternMatcher):
-    """Optimizes ONNX graph patterns with ReduceSum and Unsqueeze operations for improved model performance."""
-
     def __init__(self, priority):
         """Initializes the ReducePatternMatcher with a specified pattern matching priority level."""
         pattern = Pattern(
diff --git a/onnxslim/misc/tabulate.py b/onnxslim/misc/tabulate.py
index 514f385..55aa9fc 100644
--- a/onnxslim/misc/tabulate.py
+++ b/onnxslim/misc/tabulate.py
@@ -222,7 +222,7 @@ def make_header_line(is_header, colwidths, colaligns):
     alignment = {"left": "<", "right": ">", "center": "^", "decimal": ">"}
     # use the column widths generated by tabulate for the asciidoc column width specifiers
     asciidoc_alignments = zip(colwidths, [alignment[colalign] for colalign in colaligns])
-    asciidoc_column_specifiers = [f"{width:d}{align}" for width, align in asciidoc_alignments]
+    asciidoc_column_specifiers = ["{:d}{}".format(width, align) for width, align in asciidoc_alignments]
     header_list = ['cols="' + (",".join(asciidoc_column_specifiers)) + '"']
 
     # generate the list of options (currently only "header")
@@ -2484,7 +2484,7 @@ def _wrap_chunks(self, chunks):
         """
         lines = []
         if self.width <= 0:
-            raise ValueError(f"invalid width {self.width!r} (must be > 0)")
+            raise ValueError("invalid width %r (must be > 0)" % self.width)
         if self.max_lines is not None:
             indent = self.subsequent_indent if self.max_lines > 1 else self.initial_indent
             if self._len(indent) + self._len(self.placeholder.lstrip()) > self.width:
diff --git a/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py b/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py
index 7b75664..17b7563 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py
@@ -18,9 +18,7 @@
 from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
 
 
-class BaseExporter:
-    """BaseExporter provides a static method to export ONNX graphs to a specified destination format."""
-
+class BaseExporter(object):
     @staticmethod
     def export_graph(graph: Graph):
         """
diff --git a/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py b/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py
index 3591a71..20c4ffb 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py
@@ -52,7 +52,12 @@ def check_duplicate_node_names(nodes: Sequence[Node], level=G_LOGGER.WARNING):
         if not node.name:
             continue
         if node.name in name_map:
-            msg = f"Found distinct Nodes that share the same name:\n[id: {id(name_map[node.name])}]:\n {name_map[node.name]}---\n[id: {id(node)}]:\n {node}\n"
+            msg = "Found distinct Nodes that share the same name:\n[id: {:}]:\n {:}---\n[id: {:}]:\n {:}\n".format(
+                id(name_map[node.name]),
+                name_map[node.name],
+                id(node),
+                node,
+            )
             G_LOGGER.log(msg, level)
         else:
             name_map[node.name] = node
@@ -105,8 +110,6 @@ def np_float32_to_bf16_as_uint16(arr):
 
 
 class OnnxExporter(BaseExporter):
-    """Exports internal graph structures to ONNX format for model interoperability."""
-
     @staticmethod
     def export_tensor_proto(tensor: Constant) -> onnx.TensorProto:
         # Do *not* load LazyValues into an intermediate numpy array - instead, use
@@ -143,7 +146,9 @@ def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueIn
         """Creates an ONNX ValueInfoProto from a Tensor, optionally checking for dtype information."""
         if do_type_check and tensor.dtype is None:
             G_LOGGER.critical(
-                f"Graph input and output tensors must include dtype information. Please set the dtype attribute for: {tensor}"
+                "Graph input and output tensors must include dtype information. Please set the dtype attribute for: {:}".format(
+                    tensor
+                )
             )
 
         if tensor.dtype is None:
diff --git a/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py b/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py
index 414b53c..32e11ec 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py
@@ -104,7 +104,7 @@ def __init__(self) -> None:
         self.op = None  # op (str)
         self.check_func = None  # callback function for single node
         # pattern node name -> GraphPattern nodes(single or subpattern)
-        self.nodes: Dict[str, GraphPattern] = {}
+        self.nodes: Dict[str, "GraphPattern"] = {}
         # pattern node name -> input tensors
         self.node_inputs: Dict[str, List[int]] = {}
         # pattern node name -> output tensors
@@ -119,7 +119,6 @@ def __init__(self) -> None:
         """Assigns a unique tensor ID, tracks its input node if provided, and initializes output node tracking."""
 
     def _add_tensor(self, input_node=None) -> int:
-        """Assigns a unique tensor ID, tracks its input node if provided, and initializes output node tracking."""
         tensor_id = self.num_tensors
         self.tensor_inputs[tensor_id] = []
         if input_node is not None:
@@ -240,13 +239,15 @@ def _single_node_match(self, onnx_node: Node) -> bool:
         with G_LOGGER.indent():
             if self.op != onnx_node.op:
                 G_LOGGER.info(
-                    f"No match because: Op did not match. Node op was: {onnx_node.op} but pattern op was: {self.op}."
+                    "No match because: Op did not match. Node op was: {:} but pattern op was: {:}.".format(
+                        onnx_node.op, self.op
+                    )
                 )
                 return False
             if self.check_func is not None and not self.check_func(onnx_node):
                 G_LOGGER.info("No match because: check_func returned false.")
                 return False
-            G_LOGGER.info(f"Single node is matched: {self.op}, {onnx_node.name}")
+            G_LOGGER.info("Single node is matched: {:}, {:}".format(self.op, onnx_node.name))
         return True
 
     def _get_tensor_index_for_node(self, node: str, tensor_id: int, is_node_input: bool):
@@ -329,7 +330,7 @@ def _match_node(
     ) -> bool:
         """Matches ONNX nodes to the graph pattern starting from a specific node and tensor context."""
         with G_LOGGER.indent():
-            G_LOGGER.info(f"Checking node: {onnx_node.name} against pattern node: {node_name}.")
+            G_LOGGER.info("Checking node: {:} against pattern node: {:}.".format(onnx_node.name, node_name))
         tensor_index_for_node = self._get_tensor_index_for_node(node_name, from_tensor, is_node_input=from_inbound)
         subgraph_mapping = self.nodes[node_name].match(
             onnx_node,
diff --git a/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py b/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py
index 2b9b59c..50279ac 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py
@@ -18,9 +18,7 @@
 from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
 
 
-class BaseImporter:
-    """BaseImporter provides functionality to import and convert source graphs into onnx-graphsurgeon Graph objects."""
-
+class BaseImporter(object):
     @staticmethod
     def import_graph(graph) -> Graph:
         """
diff --git a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py
index 3d634ac..9cbb15e 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py
@@ -189,8 +189,6 @@ def get_onnx_tensor_type(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProt
 
 
 class OnnxImporter(BaseImporter):
-    """Imports ONNX models, functions, and tensors into internal representations for further processing."""
-
     @staticmethod
     def get_opset(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto]):
         """Return the ONNX opset version for the given ONNX model or function, or None if the information is
@@ -279,10 +277,14 @@ def process_attr(attr_str: str):
             if attr_str in ONNX_PYTHON_ATTR_MAPPING:
                 attr_dict[attr.name] = process_attr(attr_str)
             else:
-                G_LOGGER.warning(f"Attribute of type {attr_str} is currently unsupported. Skipping attribute.")
+                G_LOGGER.warning(
+                    "Attribute of type {:} is currently unsupported. Skipping attribute.".format(attr_str)
+                )
         else:
             G_LOGGER.warning(
-                f"Attribute type: {attr.type} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute."
+                "Attribute type: {:} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute.".format(
+                    attr.type
+                )
             )
     return attr_dict
 
@@ -313,7 +315,9 @@ def get_tensor(name: str, check_outer_graph=True):
                 return Variable.empty()
 
             G_LOGGER.verbose(
-                f"Tensor: {name} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor."
+                "Tensor: {:} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor.".format(
+                    name
+                )
             )
             subgraph_tensor_map[name] = Variable(name)
             return subgraph_tensor_map[name]
diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py b/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py
index 68ca230..355a031 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py
@@ -28,9 +28,7 @@
 from onnxslim.third_party.onnx_graphsurgeon.util import misc
 
 
-class NodeIDAdder:
-    """Assigns unique IDs to graph nodes on entry and removes them on exit for context management."""
-
+class NodeIDAdder(object):
     def __init__(self, graph):
         """Initializes NodeIDAdder with a specified graph."""
         self.graph = graph
@@ -47,7 +45,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             del node.id
 
 
-class Graph:
+class Graph(object):
     """Represents a graph containing nodes and tensors."""
 
     DEFAULT_OPSET = 11
@@ -85,8 +83,8 @@ def register_func(func):
             """
             if hasattr(Graph, func.__name__):
                 G_LOGGER.warning(
-                    f"Registered function: {func.__name__} is hidden by a Graph attribute or function with the same name. "
-                    "This function will never be called!"
+                    "Registered function: {:} is hidden by a Graph attribute or function with the same name. "
+                    "This function will never be called!".format(func.__name__)
                 )
 
             # Default behavior is to register functions for all opsets.
@@ -145,7 +143,7 @@ def __init__(
         self._merge_subgraph_functions()
 
         # Printing graphs can be very expensive
-        G_LOGGER.ultra_verbose(lambda: f"Created Graph: {self}")
+        G_LOGGER.ultra_verbose(lambda: "Created Graph: {:}".format(self))
 
     def __getattr__(self, name):
         """Dynamically handles attribute access, falling back to superclass attribute retrieval if not found."""
@@ -250,8 +248,8 @@ def _get_node_id(self, node):
             return node.id
         except AttributeError:
             G_LOGGER.critical(
-                f"Encountered a node not in the graph:\n{node}.\n\n"
-                "To fix this, please append the node to this graph's `nodes` attribute."
+                "Encountered a node not in the graph:\n{:}.\n\n"
+                "To fix this, please append the node to this graph's `nodes` attribute.".format(node)
             )
 
     # A tensor is local if it is produced in this graph, or is explicitly a graph input.
@@ -292,7 +290,7 @@ def _get_used_node_ids(self):
         """Returns a dictionary of tensors that are used by node IDs in the current subgraph."""
         local_tensors = self._local_tensors()
 
-        class IgnoreDupAndForeign:
+        class IgnoreDupAndForeign(object):
             def __init__(self, initial_tensors=None):
                 """Initialize IgnoreDupAndForeign with an optional list of initial tensors."""
                 tensors = misc.default_value(initial_tensors, [])
@@ -423,7 +421,7 @@ def cleanup_subgraphs():
                 recurse_functions=False,  # No infinite recursion
             )
 
-        G_LOGGER.verbose(f"Cleaning up {self.name}")
+        G_LOGGER.verbose("Cleaning up {:}".format(self.name))
 
         with self.node_ids():
             # Graph input producers must be removed first so used_node_ids is correct.
@@ -437,7 +435,7 @@ def cleanup_subgraphs():
                 if inp in used_tensors or not remove_unused_graph_inputs:
                     inputs.append(inp)
                 else:
-                    G_LOGGER.ultra_verbose(f"Removing unused input: {inp}")
+                    G_LOGGER.ultra_verbose("Removing unused input: {:}".format(inp))
             self.inputs = inputs
 
             nodes = []
@@ -448,7 +446,7 @@ def cleanup_subgraphs():
                 else:
                     node.inputs.clear()
                     node.outputs.clear()
-                    G_LOGGER.ultra_verbose(f"Removing unused node: {node}")
+                    G_LOGGER.ultra_verbose("Removing unused node: {:}".format(node))
 
             # Remove any hanging tensors - tensors without outputs
             if remove_unused_node_outputs:
@@ -516,11 +514,11 @@ def toposort(
             for subgraph in self.subgraphs():
                 subgraph.toposort(recurse_subgraphs=True, recurse_functions=False, mode="nodes")
 
-        G_LOGGER.debug(f"Topologically sorting {self.name}")
+        G_LOGGER.debug("Topologically sorting {:}".format(self.name))
 
         # Keeps track of a node and its level in the graph hierarchy.
        # 0 corresponds to an input node, N corresponds to a node with N layers of inputs.
-        class HierarchyDescriptor:
+        class HierarchyDescriptor(object):
            def __init__(self, node_or_func, level=None):
                """Initializes a HierarchyDescriptor with a node or function and an optional level in the graph
                hierarchy.
@@ -640,8 +638,18 @@ def add_to_tensor_map(tensor):
             """Add a tensor to the tensor_map if it is not empty and ensure no duplicate tensor names exist."""
             if not tensor.is_empty():
                 if tensor.name in tensor_map and tensor_map[tensor.name] is not tensor:
-                    msg = f"Found distinct tensors that share the same name:\n[id: {id(tensor_map[tensor.name])}] {tensor_map[tensor.name]}\n[id: {id(tensor)}] {tensor}\n"
-                    msg += f"Note: Producer node(s) of first tensor:\n{tensor_map[tensor.name].inputs}\nProducer node(s) of second tensor:\n{tensor.inputs}"
+                    msg = "Found distinct tensors that share the same name:\n[id: {:}] {:}\n[id: {:}] {:}\n".format(
+                        id(tensor_map[tensor.name]),
+                        tensor_map[tensor.name],
+                        id(tensor),
+                        tensor,
+                    )
+                    msg += (
+                        "Note: Producer node(s) of first tensor:\n{:}\nProducer node(s) of second tensor:\n{:}".format(
+                            tensor_map[tensor.name].inputs,
+                            tensor.inputs,
+                        )
+                    )
 
                     if check_duplicates:
                         G_LOGGER.critical(msg)
@@ -748,10 +756,10 @@ def should_exclude_node(node):
 
         PARTITIONING_MODES = [None, "basic", "recursive"]
         if partitioning not in PARTITIONING_MODES:
-            G_LOGGER.critical(f"Argument for parameter 'partitioning' must be one of: {PARTITIONING_MODES}")
+            G_LOGGER.critical("Argument for parameter 'partitioning' must be one of: {:}".format(PARTITIONING_MODES))
         ORT_PROVIDERS = ["CPUExecutionProvider"]
 
-        G_LOGGER.debug(f"Folding constants in {self.name}")
+        G_LOGGER.debug("Folding constants in {:}".format(self.name))
 
         # We apply constant folding in 5 passes:
         # Pass 1 lowers 'Constant' nodes into Constant tensors.
@@ -884,7 +892,7 @@ def run_cast_elision(node):
 
         if fold_shapes:
             # Perform shape tensor cast elision prior to most other folding
-            G_LOGGER.debug(f"Performing shape tensor cast elision in {self.name}")
+            G_LOGGER.debug("Performing shape tensor cast elision in {:}".format(self.name))
             try:
                 with self.node_ids():
                     for node in self.nodes:
@@ -1069,13 +1077,13 @@ def fold_shape_slice(tensor):
 
                     shape_of = shape_fold_func(tensor)
                     if shape_of is not None:
-                        G_LOGGER.ultra_verbose(f"Folding shape tensor: {tensor.name} to: {shape_of}")
+                        G_LOGGER.ultra_verbose("Folding shape tensor: {:} to: {:}".format(tensor.name, shape_of))
                         graph_constants[tensor.name] = tensor.to_constant(shape_of)
                         graph_constants[tensor.name].inputs.clear()
                 except Exception as err:
                     if not error_ok:
                         raise err
-                    G_LOGGER.warning(f"'{shape_fold_func.__name__}' routine failed with:\n{err}")
+                    G_LOGGER.warning("'{:}' routine failed with:\n{:}".format(shape_fold_func.__name__, err))
             else:
                 graph_constants = update_foldable_outputs(graph_constants)
 
@@ -1104,7 +1112,7 @@ def get_out_node_ids():
                 part = subgraph.copy()
                 out_node = part.nodes[index]
                 part.outputs = out_node.outputs
-                part.name = f"Folding: {[out.name for out in part.outputs]}"
+                part.name = "Folding: {:}".format([out.name for out in part.outputs])
                 part.cleanup(remove_unused_graph_inputs=True)
                 names = [out.name for out in part.outputs]
 
@@ -1118,7 +1126,7 @@ def get_out_node_ids():
                     )
                    values = sess.run(names, {})
                 except Exception as err:
-                    G_LOGGER.warning(f"Inference failed for subgraph: {part.name}. Note: Error was:\n{err}")
+                    G_LOGGER.warning("Inference failed for subgraph: {:}. Note: Error was:\n{:}".format(part.name, err))
                     if partitioning == "recursive":
                         G_LOGGER.verbose("Attempting to recursively partition subgraph")
                         # Partition failed, peel off last node.
@@ -1160,7 +1168,7 @@ def should_eval_foldable(tensor):
             return non_const and (is_graph_output or has_non_foldable_outputs) and not exceeds_size_threshold
 
         graph_clone.outputs = [t for t in graph_constants.values() if should_eval_foldable(t)]
-        G_LOGGER.debug(f"Folding tensors: {graph_clone.outputs}")
+        G_LOGGER.debug("Folding tensors: {:}".format(graph_clone.outputs))
         graph_clone.cleanup(remove_unused_graph_inputs=True, recurse_functions=False)
 
         # Using ._values avoids a deep copy of the values.
@@ -1206,16 +1214,16 @@ def should_eval_foldable(tensor):
             except Exception as err:
                 G_LOGGER.warning(
                     "Inference failed. You may want to try enabling partitioning to see better results. "
-                    f"Note: Error was:\n{err}"
+                    "Note: Error was:\n{:}".format(err)
                 )
-                G_LOGGER.verbose(f"Note: Graph was:\n{graph_clone}")
+                G_LOGGER.verbose("Note: Graph was:\n{:}".format(graph_clone))
                 if not error_ok:
                     raise
         elif not constant_values:
             G_LOGGER.debug(
-                f"Could not find any nodes in this graph ({self.name}) that can be folded. "
+                "Could not find any nodes in this graph ({:}) that can be folded. "
                 "This could mean that constant folding has already been run on this graph. "
-                "Skipping."
+                "Skipping.".format(self.name)
             )
 
         # Finally, replace the Variables in the original graph with constants.
@@ -1230,7 +1238,9 @@ def should_eval_foldable(tensor):
 
                 if size_threshold is not None and values.nbytes > size_threshold:
                     G_LOGGER.debug(
-                        f"Will not fold: '{name}' since its size in bytes ({values.nbytes}) exceeds the size threshold ({size_threshold})"
+                        "Will not fold: '{:}' since its size in bytes ({:}) exceeds the size threshold ({:})".format(
+                            name, values.nbytes, size_threshold
+                        )
                     )
                     continue
                 elif size_threshold is None and values.nbytes > (1 << 20):
@@ -1241,12 +1251,12 @@ def should_eval_foldable(tensor):
 
         if large_tensors:
             large_tensors_mib = {
-                tensor_name: f"{value // (1 << 20)} MiB" for tensor_name, value in large_tensors.items()
+                tensor_name: "{:} MiB".format(value // (1 << 20)) for tensor_name, value in large_tensors.items()
             }
             G_LOGGER.warning(
                 "It looks like this model contains foldable nodes that produce large outputs.\n"
                 "In order to avoid bloating the model, you may want to set a constant-folding size threshold.\n"
-                f"Note: Large tensors and their corresponding sizes were: {large_tensors_mib}",
+                "Note: Large tensors and their corresponding sizes were: {:}".format(large_tensors_mib),
                 mode=LogMode.ONCE,
             )
 
@@ -1273,12 +1283,12 @@ def fold_subgraphs():
         while index < len(self.nodes):
             node = self.nodes[index]
             if node.op == "If" and isinstance(node.inputs[0], Constant):
-                G_LOGGER.debug(f"Flattening conditional: {node.name}")
+                G_LOGGER.debug("Flattening conditional: {:}".format(node.name))
                 cond = get_scalar_value(node.inputs[0])
                 subgraph = node.attrs["then_branch"] if cond else node.attrs["else_branch"]
 
                 # Need to add a suffix to subgraph tensors so they don't collide with outer graph tensors
                 for tensor in subgraph._local_tensors().values():
-                    tensor.name += f"_subg_{index}_{subgraph.name}"
+                    tensor.name += "_subg_{:}_{:}".format(index, subgraph.name)
 
                 # The subgraph outputs correspond to the If node outputs. Only the latter are visible
                 # in the parent graph, so we rebind the producer nodes of the subgraph outputs to point
@@ -1387,9 +1397,9 @@ def process_io(io, existing_names):
                 new_io.append(Constant(name=name, values=arr))
             else:
                 G_LOGGER.critical(
-                    f"Unrecognized type passed to Graph.layer: {elem}.\n"
+                    "Unrecognized type passed to Graph.layer: {:}.\n"
                     "\tHint: Did you forget to unpack a list with `*`?\n"
-                    "\tPlease use Tensors, strings, or NumPy arrays."
+                    "\tPlease use Tensors, strings, or NumPy arrays.".format(elem)
                 )
             if new_io[-1].name:
                 existing_names.add(new_io[-1].name)
diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/node.py b/onnxslim/third_party/onnx_graphsurgeon/ir/node.py
index 4fc7419..88437ed 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/ir/node.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/ir/node.py
@@ -24,9 +24,7 @@
 from onnxslim.third_party.onnx_graphsurgeon.util import misc
 
 
-class Node:
-    """Represents an operation node in a computational graph, managing inputs, outputs, and attributes."""
-
+class Node(object):
     @dataclass
     class AttributeRef:
         """
@@ -179,24 +177,24 @@ def copy(
 
     def __str__(self):
         """Return a string representation of the object showing its name and operation."""
-        ret = f"{self.name} ({self.op})"
+        ret = "{:} ({:})".format(self.name, self.op)
 
         def add_io(name, io):
             """Add the input or output operations and their names to the string representation of the object."""
             nonlocal ret
-            ret += f"\n\t{name}: ["
+            ret += "\n\t{:}: [".format(name)
             for elem in io:
-                ret += f"\n\t\t{elem}"
+                ret += "\n\t\t{:}".format(elem)
             ret += "\n\t]"
 
         add_io("Inputs", self.inputs)
         add_io("Outputs", self.outputs)
 
         if self.attrs:
-            ret += f"\nAttributes: {self.attrs}"
+            ret += "\nAttributes: {:}".format(self.attrs)
 
         if self.domain:
-            ret += f"\nDomain: {self.domain}"
+            ret += "\nDomain: {:}".format(self.domain)
 
         return ret
 
@@ -206,7 +204,7 @@ def __repr__(self):
 
     def __eq__(self, other):
         """Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs."""
-        G_LOGGER.verbose(f"Comparing node: {self.name} with {other.name}")
+        G_LOGGER.verbose("Comparing node: {:} with {:}".format(self.name, other.name))
         attrs_match = self.name == other.name and self.op == other.op and self.attrs == other.attrs
         if not attrs_match:
             return False
diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py b/onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py
index 51b3477..d68270a 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py
@@ -23,7 +23,7 @@
 from onnxslim.third_party.onnx_graphsurgeon.util import misc
 
 
-class Tensor:
+class Tensor(object):
     """Abstract base class for tensors in a graph."""
 
     DYNAMIC = -1
@@ -155,7 +155,7 @@ def o(self, consumer_idx=0, tensor_idx=0):
 
     def __str__(self):
         """Returns a string representation of the object including its type, name, shape, and data type."""
-        return f"{type(self).__name__} ({self.name}): (shape={self.shape}, dtype={self.dtype})"
+        return "{:} ({:}): (shape={:}, dtype={:})".format(type(self).__name__, self.name, self.shape, self.dtype)
 
     def __repr__(self):  # Hack to make logging output pretty.
         """Returns a string representation of the object for logging output."""
@@ -191,8 +191,6 @@ def is_output(self, is_output: bool = False):
 
 
 class Variable(Tensor):
-    """Represents a tensor with unknown values until inference-time, supporting dynamic shapes and data types."""
-
     @staticmethod
     def empty():
         """Create and return an empty Variable tensor with an empty name."""
@@ -260,7 +258,7 @@ def __eq__(self, other):
         return name_match and inputs_match and outputs_match and dtype_match and shape_match and type_match
 
 
-class LazyValues:
+class LazyValues(object):
     """A special object that represents constant tensor values that should be lazily loaded."""
 
     def __init__(self, tensor):
@@ -306,7 +304,7 @@ def load(self):
 
     def __str__(self):
         """Returns a formatted string representation of the LazyValues object indicating its shape and dtype."""
-        return f"LazyValues (shape={self.shape}, dtype={self.dtype})"
+        return "LazyValues (shape={:}, dtype={:})".format(self.shape, self.dtype)
 
     def __repr__(self):  # Hack to make logging output pretty.
         """Returns an unambiguous string representation of the LazyValues object for logging purposes."""
@@ -371,12 +369,10 @@ def load(self):
 
     def __str__(self):
         """Return a string representation of the SparseValues object with its shape and data type."""
-        return f"SparseValues (shape={self.shape}, dtype={self.dtype})"
+        return "SparseValues (shape={:}, dtype={:})".format(self.shape, self.dtype)
 
 
 class Constant(Tensor):
-    """Represents a tensor with known constant values, supporting lazy loading and export data type specification."""
-
     def __init__(
         self,
         name: str,
@@ -411,7 +407,7 @@ def __init__(
             G_LOGGER.critical(
                 "Provided `values` argument is not a NumPy array, a LazyValues instance or a"
                 "SparseValues instance. Please provide a NumPy array or LazyValues instance "
-                f"to construct a Constant. Note: Provided `values` parameter was: {values}"
+                "to construct a Constant. Note: Provided `values` parameter was: {:}".format(values)
             )
         self._values = values
         self.data_location = data_location
@@ -474,7 +470,7 @@ def export_dtype(self, export_dtype):
     def __repr__(self):  # Hack to make logging output pretty.
         """Return a string representation of the object, including its values, for improved logging readability."""
         ret = self.__str__()
-        ret += f"\n{self._values}"
+        ret += "\n{:}".format(self._values)
         return ret
 
     def __eq__(self, other):
diff --git a/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py b/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py
index f4f735b..03d0d62 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/logger/logger.py
@@ -27,9 +27,7 @@
 
 
 # Context manager to apply indentation to messages
-class LoggerIndent:
-    """Context manager for temporarily setting indentation levels in logger messages."""
-
+class LoggerIndent(object):
     def __init__(self, logger, indent):
         """Initialize the LoggerIndent context manager with the specified logger and indentation level."""
         self.logger = logger
@@ -47,9 +45,7 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 
 # Context manager to suppress messages
-class LoggerSuppress:
-    """Suppress logger messages below a specified severity level within a context."""
-
+class LoggerSuppress(object):
     def __init__(self, logger, severity):
         """Initialize a LoggerSuppress object with a logger and severity level."""
         self.logger = logger
@@ -67,15 +63,11 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 
 class LogMode(enum.IntEnum):
-    """Enumerates logging modes for controlling message frequency in the Ultralytics library."""
-
     EACH = 0  # Log the message each time
     ONCE = 1  # Log the message only once. The same message will not be logged again.
 
 
-class Logger:
-    """Manages logging with configurable severity, indentation, and formatting for debugging and monitoring."""
-
+class Logger(object):
     ULTRA_VERBOSE = -10
     VERBOSE = 0
     DEBUG = 10
@@ -181,7 +173,7 @@ def get_line_info():
             # If the file is not located in trt_smeagol, use its basename instead.
             if os.pardir in filename:
                 filename = os.path.basename(filename)
-            return f"[{filename}:{sys._getframe(stack_depth).f_lineno}] "
+            return "[{:}:{:}] ".format(filename, sys._getframe(stack_depth).f_lineno)
 
         prefix = ""
         if self.letter:
@@ -215,7 +207,7 @@ def apply_color(message):
 
             prefix = get_prefix()
             message = apply_indentation(message)
-            return apply_color(f"{prefix}{message}")
+            return apply_color("{:}{:}".format(prefix, message))
 
         def should_log(message):
             """Determines if a message should be logged based on the severity level and logging mode."""
diff --git a/onnxslim/third_party/onnx_graphsurgeon/util/misc.py b/onnxslim/third_party/onnx_graphsurgeon/util/misc.py
index af03b45..1ac98cd 100644
--- a/onnxslim/third_party/onnx_graphsurgeon/util/misc.py
+++ b/onnxslim/third_party/onnx_graphsurgeon/util/misc.py
@@ -158,8 +158,6 @@ def convert_to_onnx_attr_type(any_type):
 # So, in the example above, we can make n.inputs a synchronized list whose field_name is set to "outputs".
 # See test_ir.TestNodeIO for functional tests
 class SynchronizedList(list):
-    """Synchronizes list operations with a specified attribute of elements to maintain bidirectional consistency."""
-
     def __init__(self, parent_obj, field_name, initial):
         """Initialize a SynchronizedList with a parent object, a field name, and an initial set of elements."""
         self.parent_obj = parent_obj
diff --git a/onnxslim/third_party/symbolic_shape_infer.py b/onnxslim/third_party/symbolic_shape_infer.py
index 6aace36..aff28c7 100644
--- a/onnxslim/third_party/symbolic_shape_infer.py
+++ b/onnxslim/third_party/symbolic_shape_infer.py
@@ -138,8 +138,6 @@ def sympy_reduce_product(x):
 
 
 class SymbolicShapeInference:
-    """Performs symbolic shape inference on ONNX models to deduce tensor shapes using symbolic computation."""
-
     def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
         """Initializes the SymbolicShapeInference class with configuration parameters for symbolic shape inference."""
         self.dispatcher_ = {
@@ -2399,7 +2397,6 @@ def _infer_PackedMultiHeadAttention(self, node):  # noqa: N802
         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
 
     def _infer_MultiScaleDeformableAttnTRT(self, node):
-        """Infers output shape and type for MultiScaleDeformableAttnTRT node using input shapes."""
         shape_value = self._try_get_shape(node, 0)
         sampling_locations = self._try_get_shape(node, 3)
         output_shape = shape_value
@@ -2838,8 +2835,8 @@ def get_prereq(node):
             ):
                 sorted_known_vi.update(node.output)
                 sorted_nodes.append(node)
-            if old_sorted_nodes_len == len(sorted_nodes) and any(
-                o.name not in sorted_known_vi for o in self.out_mp_.graph.output
+            if old_sorted_nodes_len == len(sorted_nodes) and not all(
+                o.name in sorted_known_vi for o in self.out_mp_.graph.output
             ):
                 raise Exception("Invalid model with cyclic graph")
 
@@ -2939,7 +2936,11 @@ def get_prereq(node):
                     out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
                     if self.verbose_ > 2:
                         logger.debug(
-                            f"  {node.output[i_o]}: {str(out_shape)} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
+                            "  {}: {} {}".format(
+                                node.output[i_o],
+                                str(out_shape),
+                                onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type),
+                            )
                         )
                         if node.output[i_o] in self.sympy_data_:
                             logger.debug("  Sympy Data: " + str(self.sympy_data_[node.output[i_o]]))
@@ -3042,11 +3043,17 @@ def get_prereq(node):
                             if self.verbose_ > 0:
                                 if is_unknown_op:
                                     logger.debug(
-                                        f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape"
+                                        "Possible unknown op: {} node: {}, guessing {} shape".format(
+                                            node.op_type, node.name, vi.name
+                                        )
                                     )
                                 if self.verbose_ > 2:
                                     logger.debug(
-                                        f"  {node.output[i_o]}: {str(new_shape)} {vi.type.tensor_type.elem_type}"
+                                        "  {}: {} {}".format(
+                                            node.output[i_o],
+                                            str(new_shape),
+                                            vi.type.tensor_type.elem_type,
+                                        )
                                     )
                             self.run_ = True
                             continue  # continue the inference after guess, no need to stop as no merge is needed
diff --git a/onnxslim/utils.py b/onnxslim/utils.py
index b02ff60..ffcdead 100644
--- a/onnxslim/utils.py
+++ b/onnxslim/utils.py
@@ -56,7 +56,7 @@ def format_bytes(size: Union[int, Tuple[int, ...]]) -> str:
             size_in_bytes /= 1024
             unit_index += 1
 
-        formatted_size = f"{size_in_bytes:.2f} {units[unit_index]}"
+        formatted_size = "{:.2f} {}".format(size_in_bytes, units[unit_index])
         formatted_sizes.append(formatted_size)
 
     if len(formatted_sizes) == 1:
@@ -579,7 +579,6 @@ def check_onnx_compatibility():
 
 
 def get_max_tensor(model, topk=5):
-    """Identify and print the top-k largest constant tensors in an ONNX model based on their size."""
     graph = gs.import_onnx(model)
 
     tensor_map = graph.tensors()
diff --git a/setup.py b/setup.py
index 43f0454..b760e5f 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from setuptools import find_packages, setup
 
-with open("VERSION") as f:
+with open("VERSION", "r") as f:
     version = f.read().strip()
 
 with open("onnxslim/version.py", "w") as f:
@@ -10,7 +10,7 @@
     name="onnxslim",
     version=version,
     description="OnnxSlim: A Toolkit to Help Optimize Large Onnx Model",
-    long_description=open("README.md", encoding="utf-8").read(),
+    long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
     url="https://github.com/inisis/OnnxSlim",
     author="inisis",
diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py
index c479f15..f5d5933 100644
--- a/tests/test_modelzoo.py
+++ b/tests/test_modelzoo.py
@@ -12,10 +12,7 @@
 
 
 class TestModelZoo:
-    """Tests ONNX models from the model zoo using slimming techniques for validation."""
-
     def test_silero_vad(self, request):
-        """Test the Silero VAD model by slimming its ONNX file and running inference with dummy input data."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
 
@@ -30,7 +27,6 @@ def test_silero_vad(self, request):
         ort_sess.run(None, {"input": input, "sr": sr, "state": state})
 
     def test_decoder_with_past_model(self, request):
-        """Test the ONNX model decoder with past states using a slimmed model and validate inference execution."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
 
@@ -44,7 +40,6 @@ def test_decoder_with_past_model(self, request):
         ort_sess.run(None, {"input_ids": input_ids, "encoder_hidden_states": encoder_hidden_states})
 
     def test_tiny_en_decoder(self, request):
-        """Tests the functionality of a slimmed tiny English encoder-decoder model using ONNX Runtime for inference."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
 
@@ -52,7 +47,6 @@ def test_tiny_en_decoder(self, request):
         slim(filename, os.path.join(tempdir, f"{name}_slim.onnx"), model_check=True)
 
     def test_transformer_encoder(self, request):
-        """Tests the transformer encoder model from the model zoo by verifying the operation count after slimming."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
         summary = summarize_model(slim(filename))
@@ -61,7 +55,6 @@ def test_transformer_encoder(self, request):
         assert summary["op_type_counts"]["Div"] == 53
 
     def test_uiex(self, request):
-        """Summarize the UIEX model and verify absence of 'Range' and 'Floor' operators."""
         name = request.node.originalname[len("test_") :]
         filename = f"{MODELZOO_PATH}/{name}/{name}.onnx"
         summary = summarize_model(slim(filename))
diff --git a/tests/test_onnx_nets.py b/tests/test_onnx_nets.py
index 907ae10..9daf160 100644
--- a/tests/test_onnx_nets.py
+++ b/tests/test_onnx_nets.py
@@ -24,7 +24,6 @@ class TestTorchVisionClass:
             models.googlenet,
         ),
     )
-    """Tests TorchVision models by exporting them to ONNX format and verifying the process with random input tensors."""
     def test_torchvision(self, request, model, shape=(1, 3, 224, 224)):
         """Test various TorchVision models with random input tensors of a specified shape."""
         model = model(pretrained=PRETRAINED)
@@ -48,7 +47,6 @@ def test_torchvision(self, request, model, shape=(1, 3, 224, 224)):
 
 
 class TestTimmClass:
-    """Tests TIMM models for successful ONNX export and slimming using random input tensors."""
     @pytest.fixture(params=timm.list_models())
     def model_name(self, request):
         """Yields names of models available in TIMM (https://github.com/rwightman/pytorch-image-models) for pytest fixture parameterization."""
diff --git a/tests/test_onnxslim.py b/tests/test_onnxslim.py
index 3d977cd..8af9738 100644
--- a/tests/test_onnxslim.py
+++ b/tests/test_onnxslim.py
@@ -12,8 +12,6 @@
 
 
 class TestFunctional:
-    """Tests the functionality of the 'slim' function for optimizing ONNX models using temporary directories."""
-
     def test_basic(self, request):
         """Test the basic functionality of the slim function."""
         with tempfile.TemporaryDirectory() as tempdir:
@@ -32,8 +30,6 @@ def test_basic(self, request):
 
 
 class TestFeature:
-    """Tests ONNX model modifications like input shape, precision conversion, and input/output adjustments."""
-
     def test_input_shape_modification(self, request):
         """Test the modification of input shapes."""
         summary = summarize_model(slim(FILENAME, input_shapes=["input:1,3,224,224"]))
@@ -80,9 +76,7 @@ def test_output_modification(self, request):
 
     def test_input_modification(self, request):
         """Tests input modification."""
-        summary = summarize_model(
-            slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"])
-        )
+        summary = summarize_model(slim(FILENAME, inputs=["/maxpool/MaxPool_output_0", "/layer1/layer1.0/relu/Relu_output_0"]))
         print_model_info_as_table(request.node.name, summary)
         assert "/maxpool/MaxPool_output_0" in summary["op_input_info"]
         assert "/layer1/layer1.0/relu/Relu_output_0" in summary["op_input_info"]
diff --git a/tests/test_pattern_generator.py b/tests/test_pattern_generator.py
index a3c9f63..ae0aca9 100644
--- a/tests/test_pattern_generator.py
+++ b/tests/test_pattern_generator.py
@@ -10,14 +10,12 @@
 
 
 class TestPatternGenerator:
-    """Generates and tests ONNX fusion patterns for neural network models using the GELU activation function."""
-
     def test_gelu(self, request):
         """Test the GELU activation function within the PatternModel class."""
 
         class PatternModel(nn.Module):
             def __init__(self):
-                super().__init__()
+                super(PatternModel, self).__init__()
                 self.gelu = nn.GELU()
 
             def forward(self, x):
@@ -28,7 +26,7 @@ def forward(self, x):
         class Model(nn.Module):
             def __init__(self):
                 """Initializes the Model class with ReLU and PatternModel components."""
-                super().__init__()
+                super(Model, self).__init__()
                 self.relu0 = nn.ReLU()
                 self.pattern = PatternModel()
                 self.relu1 = nn.ReLU()
diff --git a/tests/test_pattern_matcher.py b/tests/test_pattern_matcher.py
index dd22fa2..70dc2bf 100644
--- a/tests/test_pattern_matcher.py
+++ b/tests/test_pattern_matcher.py
@@ -9,14 +9,12 @@
 
 
 class TestPatternMatcher:
-    """Tests various neural network operations by exporting PyTorch models to ONNX and analyzing them with onnxslim."""
-
     def test_gelu(self, request):
         """Test the GELU activation function in a neural network model using an instance of nn.Module."""
 
         class Model(nn.Module):
             def __init__(self):
-                super().__init__()
+                super(Model, self).__init__()
                 self.relu0 = nn.ReLU()
                 self.gelu = nn.GELU()
                 self.relu1 = nn.ReLU()
@@ -46,7 +44,7 @@ def test_pad_conv(self, request):
 
         class Model(nn.Module):
             def __init__(self):
-                super().__init__()
+                super(Model, self).__init__()
                 self.pad_0 = nn.ConstantPad2d(3, 0)
                 self.conv_0 = nn.Conv2d(1, 1, 3)
 
@@ -82,7 +80,7 @@ def test_conv_bn(self, request):
 
         class Model(nn.Module):
             def __init__(self):
-                super().__init__()
+                super(Model, self).__init__()
                 self.conv = nn.Conv2d(1, 1, 3)
                 self.bn = nn.BatchNorm2d(1)
 
@@ -111,7 +109,7 @@ def test_consecutive_slice(self, request):
 
         class Model(nn.Module):
             def __init__(self):
-                super().__init__()
+                super(Model, self).__init__()
                 self.conv = nn.Conv2d(1, 1, 3)
                 self.bn = nn.BatchNorm2d(1)
 
@@ -136,7 +134,7 @@ def test_consecutive_reshape(self, request):
 
         class Model(nn.Module):
             def __init__(self):
-                super().__init__()
+                super(Model, self).__init__()
 
             def forward(self, x):
                 """Reshape tensor sequentially to (2, 6) and then to (12, 1)."""
@@ -159,7 +157,7 @@ def test_matmul_add(self, request):
 
         class Model(nn.Module):
             def __init__(self):
-                super().__init__()
+                super(Model, self).__init__()
                 self.data = torch.randn(4, 3)
 
             def forward(self, x):
@@ -187,7 +185,7 @@ def test_reduce(self, request):
 
         class Model(nn.Module):
             def __init__(self):
-                super().__init__()
+                super(Model, self).__init__()
 
             def forward(self, x):
                 """Performs a reduction summing over the last dimension of the input tensor and then unsqueezes the
@@ -217,11 +215,9 @@ def forward(self, x):
         ),
     )
     def test_consecutive_unsqueeze(self, request, opset):
-        """Tests consecutive unsqueeze operations in a model by exporting to ONNX and summarizing the slimmed model."""
-
         class Model(nn.Module):
             def __init__(self):
-                super().__init__()
+                super(Model, self).__init__()
 
             def forward(self, x):
                 x = x.unsqueeze(-1)
diff --git a/tests/utils.py b/tests/utils.py
index 86fdffe..f885345 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -23,7 +23,7 @@
     from tqdm import tqdm
 except ImportError:
     # fake tqdm if it's not installed
-    class tqdm:  # type: ignore[no-redef]
+    class tqdm(object):  # type: ignore[no-redef]
         def __init__(
             self,
             total=None,
@@ -44,9 +44,9 @@ def update(self, n):
             self.n += n
 
             if self.total is None:
-                sys.stderr.write(f"\r{self.n:.1f} bytes")
+                sys.stderr.write("\r{0:.1f} bytes".format(self.n))
             else:
-                sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
+                sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
             sys.stderr.flush()
 
     def close(self):