diff --git a/onnxslim/core/optimization/__init__.py b/onnxslim/core/optimization/__init__.py index d413ffd..b065d3d 100644 --- a/onnxslim/core/optimization/__init__.py +++ b/onnxslim/core/optimization/__init__.py @@ -79,12 +79,10 @@ def find_matches(graph: Graph, fusion_patterns: dict): def get_previous_node_by_type(node, op_type, trajectory=None): """Recursively find and return the first preceding node of a specified type in the computation graph.""" - if trajectory is None: - trajectory = [] node_feeds = get_node_feeds(node) + if trajectory is None: + trajectory = [node_feed for node_feed in node_feeds] for node_feed in node_feeds: - trajectory.append(node_feed) if node_feed.op == op_type: return trajectory - else: - return get_previous_node_by_type(node_feed, op_type, trajectory) + return get_previous_node_by_type(node_feed, op_type, trajectory) diff --git a/onnxslim/core/optimization/subexpression_elimination.py b/onnxslim/core/optimization/subexpression_elimination.py index 23215ae..bc331b0 100644 --- a/onnxslim/core/optimization/subexpression_elimination.py +++ b/onnxslim/core/optimization/subexpression_elimination.py @@ -10,10 +10,7 @@ def find_and_remove_replaceable_nodes(nodes): """Find and remove duplicate or replaceable nodes in a given list of computational graph nodes.""" def get_node_key(node): - input_names = [] - for input_node in node.inputs: - if isinstance(input_node, Variable): - input_names.append(input_node.name) + input_names = [input_node.name for input_node in node.inputs if isinstance(input_node, Variable)] return "_".join(input_names) if input_names else None def replace_node_references(existing_node, to_be_removed_node): diff --git a/onnxslim/core/optimization/weight_tying.py b/onnxslim/core/optimization/weight_tying.py index 20c8f1a..09a902b 100644 --- a/onnxslim/core/optimization/weight_tying.py +++ b/onnxslim/core/optimization/weight_tying.py @@ -11,7 +11,7 @@ def tie_weights(graph): sub_graphs = graph.subgraphs(recursive=True) sub_graphs_constant_tensors = [ - [tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, gs.Constant)] + [tensor for _, tensor in sub_graph.tensors().items() if isinstance(tensor, gs.Constant)] for sub_graph in sub_graphs ] diff --git a/onnxslim/core/pattern/__init__.py b/onnxslim/core/pattern/__init__.py index 32bf0d5..cfcf1bb 100644 --- a/onnxslim/core/pattern/__init__.py +++ b/onnxslim/core/pattern/__init__.py @@ -59,8 +59,7 @@ def get_input_info(io_spec): pattern_with_plus = re.search(r"(\d+)(\+)", io_spec) if pattern_with_plus: return int(pattern_with_plus[1]), True - else: - raise ValueError(f"input_num and output_num must be integers {io_spec}") + raise ValueError(f"input_num and output_num must be integers {io_spec}") return int(io_spec), False @@ -154,16 +153,13 @@ def match_(node, pattern_node): return False pattern_nodes = [self.pattern_dict[name] if name != "?" else None for name in pattern_node.input_names] - all_match = True for node_feed, pattern_node in zip(node_feeds, pattern_nodes): if pattern_node is not None: node_match = match_(node_feed, pattern_node) if not node_match: return False setattr(self, pattern_node.name, node_feed) - - return all_match - + return True return False if match_(node, match_point): diff --git a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py index 0e98eea..c1dcbbf 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +++ b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py @@ -277,9 +277,7 @@ def process_attr(attr_str: str): if attr_str in ONNX_PYTHON_ATTR_MAPPING: attr_dict[attr.name] = process_attr(attr_str) else: - G_LOGGER.warning( - f"Attribute of type {attr_str} is currently unsupported. Skipping attribute." - ) + G_LOGGER.warning(f"Attribute of type {attr_str} is currently unsupported. Skipping attribute.") else: G_LOGGER.warning( f"Attribute type: {attr.type} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute." diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py b/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py index b8ad85f..414c47d 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +++ b/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py @@ -639,9 +639,7 @@ def add_to_tensor_map(tensor): if not tensor.is_empty(): if tensor.name in tensor_map and tensor_map[tensor.name] is not tensor: msg = f"Found distinct tensors that share the same name:\n[id: {id(tensor_map[tensor.name])}] {tensor_map[tensor.name]}\n[id: {id(tensor)}] {tensor}\n" - msg += ( - f"Note: Producer node(s) of first tensor:\n{tensor_map[tensor.name].inputs}\nProducer node(s) of second tensor:\n{tensor.inputs}" - ) + msg += f"Note: Producer node(s) of first tensor:\n{tensor_map[tensor.name].inputs}\nProducer node(s) of second tensor:\n{tensor.inputs}" if check_duplicates: G_LOGGER.critical(msg) diff --git a/tests/test_onnx_nets.py b/tests/test_onnx_nets.py index 824712c..1fdb578 100644 --- a/tests/test_onnx_nets.py +++ b/tests/test_onnx_nets.py @@ -6,7 +6,6 @@ import pytest import timm import torch -from torch.utils.data import RandomSampler import torchvision.models as models FUSE = True diff --git a/tests/test_yolo.py b/tests/test_yolo.py index 66b4ba4..ef7123c 100644 --- a/tests/test_yolo.py +++ b/tests/test_yolo.py @@ -1,9 +1,3 @@ -from itertools import product - -import pytest -import ultralytics - - import gc import json import os @@ -13,11 +7,13 @@ import warnings from copy import deepcopy from datetime import datetime +from itertools import product from pathlib import Path import numpy as np +import pytest import torch - +import ultralytics from ultralytics.cfg import TASK2DATA, get_cfg from ultralytics.data import build_dataloader from ultralytics.data.dataset import YOLODataset @@ -41,11 +37,25 @@ get_default_args, yaml_save, ) -from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version -from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download +from ultralytics.utils.checks import ( + check_imgsz, + check_is_path_safe, + check_requirements, + check_version, +) +from ultralytics.utils.downloads import ( + attempt_download_asset, + get_github_assets, + safe_download, +) from ultralytics.utils.files import file_size, spaces_in_path from ultralytics.utils.ops import Profile -from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device, smart_inference_mode +from ultralytics.utils.torch_utils import ( + TORCH_1_13, + get_latest_opset, + select_device, + smart_inference_mode, +) def export_formats(): @@ -1152,6 +1162,7 @@ def forward(self, x): import ultralytics.engine import ultralytics.engine.exporter + ultralytics.engine.exporter.Exporter = Exporter from ultralytics import YOLO from ultralytics.cfg import TASK2MODEL, TASKS