diff --git a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py index ba3c1e1..813612e 100644 --- a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py +++ b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py @@ -287,6 +287,7 @@ def export_graph( for tensor in tensor_map.values() if isinstance(tensor, Constant) and not isinstance(tensor._values, SparseValues) ] + sparse_initializer = [ OnnxExporter.export_sparse_tensor_proto(tensor) for tensor in tensor_map.values() @@ -356,7 +357,7 @@ def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto" ) onnx_graph = OnnxExporter.export_graph( - graph, tensor_map=intersection, subgraph_tensor_map=intersection, do_type_check=do_type_check + graph, tensor_map=graph.tensors(), subgraph_tensor_map=intersection, do_type_check=do_type_check ) onnx_functions = [OnnxExporter.export_function(func) for func in graph.functions] kwargs["functions"] = onnx_functions diff --git a/onnxslim/onnx_graphsurgeon/util/misc.py b/onnxslim/onnx_graphsurgeon/util/misc.py index 99de668..3a06648 100644 --- a/onnxslim/onnx_graphsurgeon/util/misc.py +++ b/onnxslim/onnx_graphsurgeon/util/misc.py @@ -63,6 +63,8 @@ def combine_dicts(dict0, dict1): Values in the second will overwrite values in the first. """ + if dict1 is None: + return dict0 combined = OrderedDict() combined.update(dict0) combined.update(dict1)