diff --git a/onnxslim/onnx_graphsurgeon/ir/graph.py b/onnxslim/onnx_graphsurgeon/ir/graph.py index efff398..88a7890 100644 --- a/onnxslim/onnx_graphsurgeon/ir/graph.py +++ b/onnxslim/onnx_graphsurgeon/ir/graph.py @@ -1287,7 +1287,7 @@ def fold_subgraphs(): while index < len(self.nodes): node = self.nodes[index] if node.op == "If" and isinstance(node.inputs[0], Constant): - G_LOGGER.debug("Flattening conditional: {:}".format(node)) + G_LOGGER.debug("Flattening conditional: {:}".format(node.name)) cond = get_scalar_value(node.inputs[0]) subgraph = node.attrs["then_branch"] if cond else node.attrs["else_branch"] # Need to add a suffix to subgraph tensors so they don't collide with outer graph tensors @@ -1297,7 +1297,8 @@ def fold_subgraphs(): # The subgraph outputs correspond to the If node outputs. Only the latter are visible # in the parent graph, so we rebind the producer nodes of the subgraph outputs to point # to the output tensors of the If instead. - for node_out, subgraph_out in zip(node.outputs, subgraph.outputs): + node_outputs = list(node.outputs) + for node_out, subgraph_out in zip(node_outputs, subgraph.outputs): node_out.inputs.clear() for producer in subgraph_out.inputs: for tensor_idx, out_tensor in enumerate(producer.outputs):