Skip to content

Commit

Permalink
fix ir version
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Nov 10, 2024
1 parent 97bacb7 commit 923c426
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto"
if "opset_imports" not in kwargs:
kwargs["opset_imports"] = update_import_domains(graph)

if "ir_version" not in kwargs and graph.ir_version is not None:
kwargs["ir_version"] = graph.ir_version

model = onnx.helper.make_model(onnx_graph, **kwargs)
model.producer_name = graph.producer_name
model.producer_version = graph.producer_version
Expand Down
12 changes: 12 additions & 0 deletions onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ def get_import_domains(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto]
"""Retrieves the import domains from an ONNX model or function."""
return model_or_func.opset_import

@staticmethod
def get_ir_version(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto]):
"""Retrieves the ir_version from an ONNX model or function."""
try:
return model_or_func.ir_version
except Exception:
return None

@staticmethod
def import_tensor(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto]) -> Tensor:
"""Converts an ONNX tensor into a corresponding internal Tensor representation."""
Expand Down Expand Up @@ -397,6 +405,7 @@ def import_graph(
tensor_map: "OrderedDict[str, Tensor]" = None,
opset=None,
import_domains: onnx.OperatorSetIdProto = None,
ir_version=None,
producer_name: str = None,
producer_version: str = None,
functions: List[Function] = None,
Expand Down Expand Up @@ -495,6 +504,7 @@ def get_tensor(
producer_version=producer_version,
opset=opset,
import_domains=import_domains,
ir_version=ir_version,
functions=functions,
)

Expand All @@ -510,6 +520,7 @@ def import_onnx(onnx_model: "onnx.ModelProto") -> Graph:
Graph: A corresponding onnx-graphsurgeon Graph.
"""
model_opset = OnnxImporter.get_opset(onnx_model)
model_ir_version = OnnxImporter.get_ir_version(onnx_model)
model_import_domains = OnnxImporter.get_import_domains(onnx_model)
functions: List[Function] = [
OnnxImporter.import_function(
Expand All @@ -534,6 +545,7 @@ def import_onnx(onnx_model: "onnx.ModelProto") -> Graph:
onnx_model.graph,
opset=model_opset,
import_domains=model_import_domains,
ir_version=model_ir_version,
producer_name=onnx_model.producer_name,
producer_version=onnx_model.producer_version,
functions=functions,
Expand Down
2 changes: 2 additions & 0 deletions onnxslim/third_party/onnx_graphsurgeon/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
doc_string=None,
opset=None,
import_domains=None,
ir_version=None,
producer_name: str = None,
producer_version: str = None,
functions: "Sequence[Function]" = None,
Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(
self.producer_name = misc.default_value(producer_name, "")
self.producer_version = misc.default_value(producer_version, "")
self.import_domains = import_domains
self.ir_version = ir_version
# For layer() function
self.name_idx = 0

Expand Down
11 changes: 10 additions & 1 deletion onnxslim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,13 @@ def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time
["Model Name", model_name],
[SEPARATING_LINE],
["Op Set ", model_info_list[0]["op_set"]],
[SEPARATING_LINE],
["IR Version ", model_info_list[0]["ir_version"]],
)
)
else:
final_op_info.append(
["Model Name", model_name, "Op Set: " + model_info_list[0]["op_set"]] + [""] * (len(model_info_list) - 2)
["Model Name", model_name, "Op Set: " + model_info_list[0]["op_set"] + " / IR Version: " + model_info_list[0]["ir_version"]] + [""] * (len(model_info_list) - 2)
)
final_op_info.extend(
(
Expand Down Expand Up @@ -319,6 +321,12 @@ def get_opset(model: onnx.ModelProto) -> int:
except Exception:
return None

def get_ir_version(model: onnx.ModelProto) -> int:
"""Returns the ONNX ir version for a given model."""
try:
return model.ir_version
except Exception:
return None

def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> Dict:
"""Generates a summary of the ONNX model, including model size, operations, and tensor shapes."""
Expand Down Expand Up @@ -388,6 +396,7 @@ def get_graph_node_info(graph: onnx.GraphProto) -> Dict[str, List[str]]:
get_graph_node_info(model.graph)

model_info["op_set"] = str(get_opset(model))
model_info["ir_version"] = str(get_ir_version(model))
model_info["op_info"] = op_info
model_info["op_type_counts"] = op_type_counts

Expand Down

0 comments on commit 923c426

Please sign in to comment.