Skip to content

Commit

Permalink
fix when node is output
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Dec 12, 2024
1 parent 8402247 commit e217e3d
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions onnxslim/third_party/onnx_graphsurgeon/ir/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dataclasses import dataclass
from typing import Dict, List, Union

from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
from onnxslim.third_party.onnx_graphsurgeon.util import misc

Expand Down Expand Up @@ -244,8 +244,28 @@ def replace_all_uses_with(self, node: Union["Node", "Tensor"], input_var_idx=0,
input_var = node.outputs[output_var_idx]
else:
input_var = node

output_var = None
for output in self.outputs:
for node_ in output.outputs:
index = node_.inputs.index(output)
node_.inputs.pop(index)
node_.inputs.insert(index, input_var)
if isinstance(output, Variable) and output.is_output:
output_var = output
break

if output_var:
feed = self.feeds[0]
if not isinstance(feed, (Variable, Constant)):
index = feed.outputs.index(self.inputs[input_var_idx])
feed.outputs.pop(index)
feed.outputs.insert(index, self.outputs[output_var_idx])
for user in list(self.inputs[input_var_idx].outputs):
# do not use index here, because index will only return the first index of the input
for i, input in enumerate(user.inputs):
if input == self.inputs[input_var_idx]:
user.inputs[i] = self.outputs[output_var_idx]
self.outputs.clear()
else:
for output in self.outputs:
for node_ in output.outputs:
index = node_.inputs.index(output)
node_.inputs.pop(index)
node_.inputs.insert(index, input_var)

0 comments on commit e217e3d

Please sign in to comment.