Skip to content

Commit

Permalink
add replace_all_uses_with api (#57)
Browse files Browse the repository at this point in the history
* add replace_all_uses_with api

* Auto-format by https://ultralytics.com/actions

* refactor code

* Auto-format by https://ultralytics.com/actions

* fix typo

* fix bugs

* fix bug

* update modelzoo test

* fix when node is output

---------

Co-authored-by: UltralyticsAssistant <[email protected]>
  • Loading branch information
inisis and UltralyticsAssistant authored Dec 12, 2024
1 parent 4b9731c commit fc17033
Show file tree
Hide file tree
Showing 16 changed files with 96 additions and 99 deletions.
5 changes: 2 additions & 3 deletions onnxslim/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.optimization import optimize_model
from onnxslim.core.utils import delete_node
from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Variable
from onnxslim.third_party.symbolic_shape_infer import SymbolicShapeInference
from onnxslim.utils import save

Expand Down Expand Up @@ -173,7 +172,7 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
if node.op == "Cast":
inp_dtype = [input.dtype for input in node.inputs][0]
if inp_dtype in [np.float16, np.float32]:
delete_node(node)
node.replace_all_uses_with(node.inputs[0])
else:
outp_dtype = [output.dtype for output in node.outputs][0]
if outp_dtype == np.float16:
Expand Down
3 changes: 1 addition & 2 deletions onnxslim/core/optimization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import onnx

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import get_node_feeds
from onnxslim.core.pattern.registry import get_fusion_patterns
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph

Expand Down Expand Up @@ -81,7 +80,7 @@ def get_previous_node_by_type(node, op_type, trajectory=None):
"""Recursively find and return the first preceding node of a specified type in the computation graph."""
if trajectory is None:
trajectory = []
node_feeds = get_node_feeds(node)
node_feeds = node.feeds
for node_feed in node_feeds:
trajectory.append(node_feed)
if node_feed.op == op_type:
Expand Down
31 changes: 15 additions & 16 deletions onnxslim/core/optimization/dead_node_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.utils import delete_node
from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Variable

Expand All @@ -20,28 +19,28 @@ def dead_node_elimination(graph, is_subgraph=False):
for node in graph.nodes:
if node.op in {"Identity", "Dropout"}:
if not is_subgraph:
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif node.op == "Pad":
if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant):
pad_value = node.inputs[1].values.tolist()
pad_value = pad_value if isinstance(pad_value, list) else [pad_value]
if all(value == 0 for value in pad_value):
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif node.op == "Cast":
inp_dtype = [dtype_to_onnx(input.dtype) for input in node.inputs][0]
if inp_dtype == node.attrs["to"]:
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif node.op == "Reshape":
if (node.inputs[0].shape and len(node.inputs[0].shape) == 1) and (
node.outputs[0].shape and len(node.outputs[0].shape) == 1
):
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif node.inputs[0].shape and node.outputs[0].shape and node.inputs[0].shape == node.outputs[0].shape:
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
else:
node_output_shape = node.outputs[0].shape
Expand All @@ -61,7 +60,7 @@ def dead_node_elimination(graph, is_subgraph=False):
idx, constant_variable = get_constant_variable(node, return_idx=True)
if np.all(constant_variable.values == 1):
var_idx = 0 if idx == 1 else 1
delete_node(node, var_idx)
node.replace_all_uses_with(node.feeds[var_idx])
logger.debug(f"removing {node.op} op: {node.name}")
elif node.op == "Add":
if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
Expand All @@ -71,25 +70,25 @@ def dead_node_elimination(graph, is_subgraph=False):
value = constant_variable.values
var_idx = 0 if idx == 1 else 1
if value.ndim == 0 and value == 0:
delete_node(node, var_idx)
node.replace_all_uses_with(node.feeds[var_idx])
logger.debug(f"removing {node.op} op: {node.name}")
elif np.all(value == 0) and (node.inputs[var_idx].shape == node.outputs[0].shape):
delete_node(node, var_idx)
node.replace_all_uses_with(node.feeds[var_idx])
logger.debug(f"removing {node.op} op: {node.name}")
elif node.op == "Expand":
# tests/test_onnx_nets.py::TestTimmClass::test_timm[lambda_resnet26rpt_256]
if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant):
constant_variable = node.inputs[1]
value = constant_variable.values
if node.inputs[0].shape == node.outputs[0].shape:
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif value.ndim == 0 and value == 1:
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif node.op == "Concat":
if len(node.inputs) == 1:
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
else:
for input in node.inputs:
Expand All @@ -100,20 +99,20 @@ def dead_node_elimination(graph, is_subgraph=False):
constant_variable = node.inputs[1]
value = constant_variable.values
if value.ndim == 0 and value == 0:
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif np.all(value == 0) and (node.inputs[0].shape == node.outputs[0].shape):
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif node.op == "Div":
if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable):
constant_variable = node.inputs[1]
value = constant_variable.values
if value.ndim == 0 and value == 1:
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")
elif np.all(value == 1) and (node.inputs[0].shape == node.outputs[0].shape):
delete_node(node)
node.replace_all_uses_with(node.feeds[0])
logger.debug(f"removing {node.op} op: {node.name}")


Expand Down
3 changes: 1 addition & 2 deletions onnxslim/core/optimization/subexpression_elimination.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging

from onnxslim.core.pattern import get_node_users
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Variable

logger = logging.getLogger("onnxslim")
Expand All @@ -17,7 +16,7 @@ def get_node_key(node):
return "_".join(input_names) if input_names else None

def replace_node_references(existing_node, to_be_removed_node):
users = get_node_users(to_be_removed_node)
users = to_be_removed_node.users
for user in users:
for inp in user.inputs:
if inp in to_be_removed_node.outputs:
Expand Down
29 changes: 3 additions & 26 deletions onnxslim/core/pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,6 @@
logger = logging.getLogger("onnxslim")


def get_node_users(node):
"""Retrieve the list of nodes that use the outputs of the given node."""
users = []
for output in node.outputs: # output is a Variable
if output.is_output:
users.append(output)
users.extend(iter(output.outputs))
return users


def get_node_feeds(node):
"""Retrieve the list of nodes that provide inputs to the given node."""
feeds = []
for input in node.inputs:
if len(input.inputs) == 0 and not isinstance(input, Constant):
feeds.append(input)
elif isinstance(input, Constant):
feeds.append(input)
else:
feeds.extend(input if feed.op == "Split" else feed for feed in input.inputs)
return feeds


def get_name(name):
"""Sanitizes the input string by replacing illegal characters with underscores and prefixing with an underscore if
numeric.
Expand Down Expand Up @@ -142,7 +119,7 @@ def match_(node, pattern_node):
if node.op == pattern_node.op:
setattr(self, pattern_node.name, node)

node_feeds = get_node_feeds(node)
node_feeds = node.feeds
if pattern_node.coarse_input_num:
if len(node_feeds) < len(pattern_node.input_names):
return False
Expand Down Expand Up @@ -207,8 +184,8 @@ def generate(self):
for node in nodes:
if node.op != "Constant":
name = get_name(node.name)
feeds = get_node_feeds(node)
users = get_node_users(node)
feeds = node.feeds
users = node.users
template.append(
" ".join(
[node.op, name, str(len(feeds)), str(len(users))]
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/core/pattern/elimination/reshape.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


Expand Down Expand Up @@ -33,7 +33,7 @@ def rewrite(self, opset=11):
node = self.reshape_1
first_reshape_node = node.i(0)
first_reshape_node_inputs = list(first_reshape_node.inputs)
first_reshape_node_users = get_node_users(first_reshape_node)
first_reshape_node_users = first_reshape_node.users
if len(first_reshape_node_users) == 1:
second_reshape_node = node

Expand Down
4 changes: 2 additions & 2 deletions onnxslim/core/pattern/elimination/slice.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


Expand Down Expand Up @@ -29,7 +29,7 @@ def rewrite(self, opset=11):
first_slice_node = self.slice_0
first_slice_node_inputs = list(first_slice_node.inputs)
if all(isinstance(input, gs.Constant) for input in first_slice_node_inputs[1:]):
first_slice_node_users = get_node_users(first_slice_node)
first_slice_node_users = first_slice_node.users
if all(
user.op == "Slice" and all(isinstance(input, gs.Constant) for input in list(user.inputs)[1:])
for user in first_slice_node_users
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/core/pattern/elimination/unsqueeze.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


Expand All @@ -27,7 +27,7 @@ def rewrite(self, opset=11):
"""Rewrites an elimination pattern for unsqueeze nodes by optimizing nested slice operations."""
match_case = {}
node_unsqueeze_0 = self.unsqueeze_0
users_node_unsqueeze_0 = get_node_users(node_unsqueeze_0)
users_node_unsqueeze_0 = node_unsqueeze_0.users
node_unsqueeze_1 = self.unsqueeze_1
if len(users_node_unsqueeze_0) == 1 and node_unsqueeze_0.inputs[0].shape and node_unsqueeze_1.inputs[0].shape:
if opset < 13 or (
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/core/pattern/fusion/convadd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


Expand All @@ -25,7 +25,7 @@ def rewrite(self, opset=11):
match_case = {}
conv_node = self.conv_0
conv_weight = list(conv_node.inputs)[1]
conv_node_users = get_node_users(conv_node)
conv_node_users = conv_node.users
node = self.add_0
if (
len(conv_node_users) == 1
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/core/pattern/fusion/convbn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


Expand All @@ -27,7 +27,7 @@ def rewrite(self, opset=11):
"""Rewrites the weights and biases of a BatchNormalization layer fused with a convolution layer."""
match_case = {}
conv_transpose_node = self.conv_0
conv_transpose_node_users = get_node_users(conv_transpose_node)
conv_transpose_node_users = conv_transpose_node.users
node = self.bn_0
if len(conv_transpose_node_users) == 1:
conv_transpose_weight = conv_transpose_node.inputs[1].values
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/core/pattern/fusion/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.optimization.dead_node_elimination import get_constant_variable
from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


Expand Down Expand Up @@ -35,7 +35,7 @@ def rewrite(self, opset=11):
input_variable = (
matmul_node.inputs[0] if isinstance(matmul_node.inputs[1], gs.Constant) else matmul_node.inputs[1]
)
users = get_node_users(matmul_node)
users = matmul_node.users
if len(users) == 1 and matmul_bias_variable and len(matmul_bias_variable.shape) == 2:
if (
input_variable.shape
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/core/pattern/fusion/padconv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


Expand Down Expand Up @@ -33,7 +33,7 @@ def rewrite(self, opset=11):
match_case = {}
conv_node = self.conv_0
pad_node = self.pad_0
pad_node_users = get_node_users(pad_node)
pad_node_users = pad_node.users

pad_inputs = len(pad_node.inputs)
if pad_inputs < 3 or (
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/core/pattern/fusion/reduce.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


Expand All @@ -25,7 +25,7 @@ def rewrite(self, opset=11):
match_case = {}
node = self.unsqueeze_0
reduce_node = self.reduce_0
reduce_node_node_users = get_node_users(reduce_node)
reduce_node_node_users = reduce_node.users
if len(reduce_node_node_users) == 1:
unsqueeze_node = node

Expand Down
32 changes: 0 additions & 32 deletions onnxslim/core/utils.py

This file was deleted.

Loading

0 comments on commit fc17033

Please sign in to comment.