diff --git a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py index cc4f329..2f447e0 100644 --- a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py +++ b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py @@ -337,19 +337,19 @@ def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto" """ sub_graphs = graph.subgraphs(recursive=True) - graph_constants_list = [] - for sub_graph in sub_graphs: - graph_constants = {name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)} - graph_constants_list.append(graph_constants) + graph_constants_list = [ + {name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)} + for sub_graph in sub_graphs + ] if not graph_constants_list: intersection = None else: - intersection = { - k: graph_constants_list[0][k] - for k in graph_constants_list[0] - if all(k in d for d in graph_constants_list[1:]) - } + intersection = ( + {key: graph_constants_list[0][key] for key in graph_constants_list[0] + if all(key in d and graph_constants_list[0][key] == d[key] for d in graph_constants_list[1:])} + if graph_constants_list else None + ) onnx_graph = OnnxExporter.export_graph( graph, tensor_map=intersection, subgraph_tensor_map=intersection, do_type_check=do_type_check diff --git a/onnxslim/onnx_graphsurgeon/ir/graph.py b/onnxslim/onnx_graphsurgeon/ir/graph.py index 64ab7c2..58c4637 100644 --- a/onnxslim/onnx_graphsurgeon/ir/graph.py +++ b/onnxslim/onnx_graphsurgeon/ir/graph.py @@ -780,7 +780,7 @@ def should_exclude_node(node): for tensor in self.tensors().values(): if len(tensor.inputs) == 1: node = tensor.inputs[0] - if node.op == "Constant": + if node.op == "Constant" and tensor.outputs: if len(node.attrs) != 1: G_LOGGER.warning("Constant node must contain exactly one attribute") continue @@ -1227,7 +1227,7 @@ def should_eval_foldable(tensor): graph_tensors = self.tensors() for name, values in constant_values.items(): tensor = graph_tensors[name] - if isinstance(tensor, Constant): + if isinstance(tensor, Constant) or not tensor.outputs: # No need to fold tensors that are already constant. continue