Skip to content

Commit

Permalink
fix subgraph bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Jul 2, 2024
1 parent 9cde589 commit 6264e09
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions onnxslim/onnx_graphsurgeon/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ def fold_subgraphs():
while index < len(self.nodes):
node = self.nodes[index]
if node.op == "If" and isinstance(node.inputs[0], Constant):
G_LOGGER.debug("Flattening conditional: {:}".format(node))
G_LOGGER.debug("Flattening conditional: {:}".format(node.name))
cond = get_scalar_value(node.inputs[0])
subgraph = node.attrs["then_branch"] if cond else node.attrs["else_branch"]
# Need to add a suffix to subgraph tensors so they don't collide with outer graph tensors
Expand All @@ -1297,7 +1297,8 @@ def fold_subgraphs():
# The subgraph outputs correspond to the If node outputs. Only the latter are visible
# in the parent graph, so we rebind the producer nodes of the subgraph outputs to point
# to the output tensors of the If instead.
for node_out, subgraph_out in zip(node.outputs, subgraph.outputs):
node_outputs = list(node.outputs)
for node_out, subgraph_out in zip(node_outputs, subgraph.outputs):
node_out.inputs.clear()
for producer in subgraph_out.inputs:
for tensor_idx, out_tensor in enumerate(producer.outputs):
Expand Down

0 comments on commit 6264e09

Please sign in to comment.