simple optimizations #31

Open. Wants to merge 2 commits into base: main.
8 changes: 3 additions & 5 deletions onnxslim/core/optimization/__init__.py
@@ -79,12 +79,10 @@ def find_matches(graph: Graph, fusion_patterns: dict):
 
 def get_previous_node_by_type(node, op_type, trajectory=None):
     """Recursively find and return the first preceding node of a specified type in the computation graph."""
-    if trajectory is None:
-        trajectory = []
     node_feeds = get_node_feeds(node)
+    if trajectory is None:
+        trajectory = [node_feed for node_feed in node_feeds]
     for node_feed in node_feeds:
-        trajectory.append(node_feed)
         if node_feed.op == op_type:
             return trajectory
-        else:
-            return get_previous_node_by_type(node_feed, op_type, trajectory)
+        return get_previous_node_by_type(node_feed, op_type, trajectory)
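For context while reviewing this hunk, here is a minimal, self-contained sketch of the kind of producer walk this helper performs. It uses a toy Node class rather than the real onnxslim graph types (get_node_feeds and the gs.Node API are not reproduced), so it illustrates the traversal pattern only, not the library's exact behavior:

    # Toy stand-ins for illustration only; not the onnxslim API.
    class Node:
        def __init__(self, op, feeds=()):
            self.op = op                # operator type, e.g. "Conv"
            self.feeds = list(feeds)    # producer nodes feeding this node

    def previous_node_by_type(node, op_type, trajectory=None):
        """Walk producers depth-first, returning the visited path once op_type is found."""
        if trajectory is None:
            trajectory = []
        for feed in node.feeds:
            trajectory.append(feed)
            if feed.op == op_type:
                return trajectory
            return previous_node_by_type(feed, op_type, trajectory)
        return None  # ran out of producers without a match

    # Conv -> Relu -> Add: starting from Add, the walk records Relu, then Conv.
    conv = Node("Conv")
    relu = Node("Relu", [conv])
    add = Node("Add", [relu])
    print([n.op for n in previous_node_by_type(add, "Conv")])  # ['Relu', 'Conv']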
5 changes: 1 addition & 4 deletions onnxslim/core/optimization/subexpression_elimination.py
@@ -10,10 +10,7 @@ def find_and_remove_replaceable_nodes(nodes):
     """Find and remove duplicate or replaceable nodes in a given list of computational graph nodes."""
 
     def get_node_key(node):
-        input_names = []
-        for input_node in node.inputs:
-            if isinstance(input_node, Variable):
-                input_names.append(input_node.name)
+        input_names = [input_node.name for input_node in node.inputs if isinstance(input_node, Variable)]
         return "_".join(input_names) if input_names else None
 
     def replace_node_references(existing_node, to_be_removed_node):
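A quick, self-contained check of the get_node_key refactor above, using toy classes in place of the onnx_graphsurgeon Variable and Node types (illustrative names only, not the library API): the comprehension builds the same key as the original loop.

    from dataclasses import dataclass, field

    # Minimal stand-ins for illustration; the real code operates on gs.Variable / gs.Node.
    @dataclass
    class Variable:
        name: str

    @dataclass
    class Constant:  # a non-Variable input, which the key ignores
        name: str

    @dataclass
    class Node:
        inputs: list = field(default_factory=list)

    def key_with_loop(node):
        input_names = []
        for input_node in node.inputs:
            if isinstance(input_node, Variable):
                input_names.append(input_node.name)
        return "_".join(input_names) if input_names else None

    def key_with_comprehension(node):
        input_names = [i.name for i in node.inputs if isinstance(i, Variable)]
        return "_".join(input_names) if input_names else None

    node = Node(inputs=[Variable("x"), Constant("w"), Variable("y")])
    assert key_with_loop(node) == key_with_comprehension(node) == "x_y"
    assert key_with_loop(Node()) is None  # no Variable inputs -> no key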
2 changes: 1 addition & 1 deletion onnxslim/core/optimization/weight_tying.py
@@ -11,7 +11,7 @@ def tie_weights(graph):
 
     sub_graphs = graph.subgraphs(recursive=True)
     sub_graphs_constant_tensors = [
-        [tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, gs.Constant)]
+        [tensor for _, tensor in sub_graph.tensors().items() if isinstance(tensor, gs.Constant)]
         for sub_graph in sub_graphs
     ]
 
8 changes: 2 additions & 6 deletions onnxslim/core/pattern/__init__.py
@@ -59,8 +59,7 @@ def get_input_info(io_spec):
         pattern_with_plus = re.search(r"(\d+)(\+)", io_spec)
         if pattern_with_plus:
             return int(pattern_with_plus[1]), True
-        else:
-            raise ValueError(f"input_num and output_num must be integers {io_spec}")
+        raise ValueError(f"input_num and output_num must be integers {io_spec}")
 
     return int(io_spec), False
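For reference, a self-contained sketch of the io_spec convention this parser handles, using only the standard library (the function name below is illustrative, not the onnxslim API): a plain integer parses to (n, False), and a form like "3+" parses to (3, True), with the boolean flagging the open-ended case.

    import re

    def parse_io_spec(io_spec: str):
        """Return (count, is_open_ended) for specs such as "2" or "3+"."""
        if not io_spec.isdigit():
            pattern_with_plus = re.search(r"(\d+)(\+)", io_spec)
            if pattern_with_plus:
                return int(pattern_with_plus[1]), True
            raise ValueError(f"input_num and output_num must be integers {io_spec}")
        return int(io_spec), False

    print(parse_io_spec("2"))   # (2, False)
    print(parse_io_spec("3+"))  # (3, True)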

@@ -154,16 +153,13 @@ def match_(node, pattern_node):
                     return False
 
             pattern_nodes = [self.pattern_dict[name] if name != "?" else None for name in pattern_node.input_names]
-            all_match = True

Owner commented on the removed line above: "I recommend keeping this variable to make it clear."

             for node_feed, pattern_node in zip(node_feeds, pattern_nodes):
                 if pattern_node is not None:
                     node_match = match_(node_feed, pattern_node)
                     if not node_match:
                         return False
                     setattr(self, pattern_node.name, node_feed)
-
-            return all_match
-
+            return True
         return False
 
     if match_(node, match_point):
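On the review comment above, a minimal illustration of the control-flow question (generic predicate loop, not the onnxslim matcher): when every failing path returns False early, a trailing return True behaves the same as returning an all_match flag that is never reassigned, so the choice is purely about readability.

    def all_pass_with_flag(items, check):
        all_match = True
        for item in items:
            if not check(item):
                return False
        return all_match  # only reached when nothing failed

    def all_pass_plain(items, check):
        for item in items:
            if not check(item):
                return False
        return True

    def is_even(x):
        return x % 2 == 0

    assert all_pass_with_flag([2, 4, 6], is_even) is all_pass_plain([2, 4, 6], is_even) is True
    assert all_pass_with_flag([2, 3], is_even) is all_pass_plain([2, 3], is_even) is False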
@@ -277,9 +277,7 @@ def process_attr(attr_str: str):
         if attr_str in ONNX_PYTHON_ATTR_MAPPING:
             attr_dict[attr.name] = process_attr(attr_str)
         else:
-            G_LOGGER.warning(
-                f"Attribute of type {attr_str} is currently unsupported. Skipping attribute."
-            )
+            G_LOGGER.warning(f"Attribute of type {attr_str} is currently unsupported. Skipping attribute.")
     else:
         G_LOGGER.warning(
             f"Attribute type: {attr.type} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute."
4 changes: 1 addition & 3 deletions onnxslim/third_party/onnx_graphsurgeon/ir/graph.py
@@ -639,9 +639,7 @@ def add_to_tensor_map(tensor):
     if not tensor.is_empty():
         if tensor.name in tensor_map and tensor_map[tensor.name] is not tensor:
             msg = f"Found distinct tensors that share the same name:\n[id: {id(tensor_map[tensor.name])}] {tensor_map[tensor.name]}\n[id: {id(tensor)}] {tensor}\n"
-            msg += (
-                f"Note: Producer node(s) of first tensor:\n{tensor_map[tensor.name].inputs}\nProducer node(s) of second tensor:\n{tensor.inputs}"
-            )
+            msg += f"Note: Producer node(s) of first tensor:\n{tensor_map[tensor.name].inputs}\nProducer node(s) of second tensor:\n{tensor.inputs}"
 
             if check_duplicates:
                 G_LOGGER.critical(msg)
1 change: 0 additions & 1 deletion tests/test_onnx_nets.py
@@ -6,7 +6,6 @@
 import pytest
 import timm
 import torch
-from torch.utils.data import RandomSampler
 import torchvision.models as models
 
 FUSE = True
31 changes: 21 additions & 10 deletions tests/test_yolo.py
@@ -1,9 +1,3 @@
-from itertools import product
-
-import pytest
-import ultralytics
-
-
 import gc
 import json
 import os
@@ -13,11 +7,13 @@
 import warnings
 from copy import deepcopy
 from datetime import datetime
+from itertools import product
 from pathlib import Path
 
 import numpy as np
+import pytest
 import torch
 
+import ultralytics
 from ultralytics.cfg import TASK2DATA, get_cfg
 from ultralytics.data import build_dataloader
 from ultralytics.data.dataset import YOLODataset
@@ -41,11 +37,25 @@
     get_default_args,
     yaml_save,
 )
-from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requirements, check_version
-from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
+from ultralytics.utils.checks import (
+    check_imgsz,
+    check_is_path_safe,
+    check_requirements,
+    check_version,
+)
+from ultralytics.utils.downloads import (
+    attempt_download_asset,
+    get_github_assets,
+    safe_download,
+)
 from ultralytics.utils.files import file_size, spaces_in_path
 from ultralytics.utils.ops import Profile
-from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device, smart_inference_mode
+from ultralytics.utils.torch_utils import (
+    TORCH_1_13,
+    get_latest_opset,
+    select_device,
+    smart_inference_mode,
+)
 
 
 def export_formats():
@@ -1152,6 +1162,7 @@ def forward(self, x):
 
     import ultralytics.engine
     import ultralytics.engine.exporter
+
     ultralytics.engine.exporter.Exporter = Exporter
     from ultralytics import YOLO
     from ultralytics.cfg import TASK2MODEL, TASKS