diff --git a/onnxslim/core/optimization/dead_node_elimination.py b/onnxslim/core/optimization/dead_node_elimination.py index 61bdea4..6985124 100644 --- a/onnxslim/core/optimization/dead_node_elimination.py +++ b/onnxslim/core/optimization/dead_node_elimination.py @@ -10,17 +10,18 @@ logger = logging.getLogger("onnxslim") -def dead_node_elimination(graph): +def dead_node_elimination(graph, is_subgraph=False): """Perform in-place constant folding optimizations on the given computational graph by eliminating redundant nodes. """ for subgraph in graph.subgraphs(): - dead_node_elimination(subgraph) + dead_node_elimination(subgraph, is_subgraph=True) for node in graph.nodes: if node.op in {"Identity", "Dropout"}: - delete_node(node) - logger.debug(f"removing {node.op} op: {node.name}") + if not is_subgraph: + delete_node(node) + logger.debug(f"removing {node.op} op: {node.name}") elif node.op == "Pad": if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant): pad_value = node.inputs[1].values.tolist() diff --git a/onnxslim/core/utils.py b/onnxslim/core/utils.py index d4d53a3..bcbd798 100644 --- a/onnxslim/core/utils.py +++ b/onnxslim/core/utils.py @@ -1,4 +1,4 @@ -from onnxslim.core.pattern import get_node_users +from onnxslim.core.pattern import get_node_users, get_node_feeds from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Variable @@ -15,10 +15,12 @@ def delete_node(node, input_var_idx=0, output_var_idx=0): break if output_var: - input_node = node.i() - input_node.outputs.remove(node.inputs[input_var_idx]) - input_node.outputs.append(node.outputs[output_var_idx]) - node.outputs.clear() + feeds = get_node_feeds(node) + feed = feeds[0] + if not isinstance(feed, Variable): + feed.outputs.remove(node.inputs[input_var_idx]) + feed.outputs.append(node.outputs[output_var_idx]) + node.outputs.clear() else: for next_node in next_nodes: index = next_node.inputs.index(node_variable)