Skip to content

Commit

Permalink
fix subgraph constant folding bug
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Jun 13, 2024
1 parent f455f25 commit ee38358
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,19 +337,19 @@ def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto"
"""
sub_graphs = graph.subgraphs(recursive=True)

graph_constants_list = []
for sub_graph in sub_graphs:
graph_constants = {name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)}
graph_constants_list.append(graph_constants)
graph_constants_list = [
{name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)}
for sub_graph in sub_graphs
]

if not graph_constants_list:
intersection = None
else:
intersection = {
k: graph_constants_list[0][k]
for k in graph_constants_list[0]
if all(k in d for d in graph_constants_list[1:])
}
intersection = (
{key: graph_constants_list[0][key] for key in graph_constants_list[0]
if all(key in d and graph_constants_list[0][key] == d[key] for d in graph_constants_list[1:])}
if graph_constants_list else None
)

onnx_graph = OnnxExporter.export_graph(
graph, tensor_map=intersection, subgraph_tensor_map=intersection, do_type_check=do_type_check
Expand Down
4 changes: 2 additions & 2 deletions onnxslim/onnx_graphsurgeon/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def should_exclude_node(node):
for tensor in self.tensors().values():
if len(tensor.inputs) == 1:
node = tensor.inputs[0]
if node.op == "Constant":
if node.op == "Constant" and tensor.outputs:
if len(node.attrs) != 1:
G_LOGGER.warning("Constant node must contain exactly one attribute")
continue
Expand Down Expand Up @@ -1227,7 +1227,7 @@ def should_eval_foldable(tensor):
graph_tensors = self.tensors()
for name, values in constant_values.items():
tensor = graph_tensors[name]
if isinstance(tensor, Constant):
if isinstance(tensor, Constant) or not tensor.outputs:
# No need to fold tensors that are already constant.
continue

Expand Down

0 comments on commit ee38358

Please sign in to comment.