diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/node.py b/onnxslim/third_party/onnx_graphsurgeon/ir/node.py index e87fe76..b1cef6c 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/ir/node.py +++ b/onnxslim/third_party/onnx_graphsurgeon/ir/node.py @@ -19,7 +19,7 @@ from dataclasses import dataclass from typing import Dict, List, Union -from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor +from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER from onnxslim.third_party.onnx_graphsurgeon.util import misc @@ -244,8 +244,28 @@ def replace_all_uses_with(self, node: Union["Node", "Tensor"], input_var_idx=0, input_var = node.outputs[output_var_idx] else: input_var = node + + output_var = None for output in self.outputs: - for node_ in output.outputs: - index = node_.inputs.index(output) - node_.inputs.pop(index) - node_.inputs.insert(index, input_var) + if isinstance(output, Variable) and output.is_output: + output_var = output + break + + if output_var: + feed = self.feeds[0] + if not isinstance(feed, (Variable, Constant)): + index = feed.outputs.index(self.inputs[input_var_idx]) + feed.outputs.pop(index) + feed.outputs.insert(index, self.outputs[output_var_idx]) + for user in list(self.inputs[input_var_idx].outputs): + # do not use index here, because index will only return the first index of the input + for i, input in enumerate(user.inputs): + if input == self.inputs[input_var_idx]: + user.inputs[i] = self.outputs[output_var_idx] + self.outputs.clear() + else: + for output in self.outputs: + for node_ in output.outputs: + index = node_.inputs.index(output) + node_.inputs.pop(index) + node_.inputs.insert(index, input_var)