Skip to content

Commit

Permalink
add metadata support
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Nov 12, 2024
1 parent b1ca25f commit 8c82f56
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto"
kwargs["ir_version"] = graph.ir_version

model = onnx.helper.make_model(onnx_graph, **kwargs)
model.metadata_props.extend(graph.metadata_props)
model.producer_name = graph.producer_name
model.producer_version = graph.producer_version
return model
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def import_graph(
producer_name: str = None,
producer_version: str = None,
functions: List[Function] = None,
metadata_props=None,
) -> Graph:
"""
Imports a Graph from an ONNX Graph.
Expand Down Expand Up @@ -506,6 +507,7 @@ def get_tensor(
import_domains=import_domains,
ir_version=ir_version,
functions=functions,
metadata_props=metadata_props,
)


Expand Down Expand Up @@ -549,4 +551,5 @@ def import_onnx(onnx_model: "onnx.ModelProto") -> Graph:
producer_name=onnx_model.producer_name,
producer_version=onnx_model.producer_version,
functions=functions,
metadata_props=onnx_model.metadata_props
)
2 changes: 2 additions & 0 deletions onnxslim/third_party/onnx_graphsurgeon/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
producer_name: str = None,
producer_version: str = None,
functions: "Sequence[Function]" = None,
metadata_props=None,
):
"""
Args:
Expand All @@ -133,6 +134,7 @@ def __init__(
self.opset = misc.default_value(opset, Graph.DEFAULT_OPSET)
self.producer_name = misc.default_value(producer_name, "")
self.producer_version = misc.default_value(producer_version, "")
self.metadata_props = metadata_props
self.import_domains = import_domains
self.ir_version = ir_version
# For layer() function
Expand Down

0 comments on commit 8c82f56

Please sign in to comment.