diff --git a/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py b/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py index 3c021aa..94671a4 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +++ b/onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py @@ -363,6 +363,9 @@ def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto" if "opset_imports" not in kwargs: kwargs["opset_imports"] = update_import_domains(graph) + if "ir_version" not in kwargs and graph.ir_version is not None: + kwargs["ir_version"] = graph.ir_version + model = onnx.helper.make_model(onnx_graph, **kwargs) model.producer_name = graph.producer_name model.producer_version = graph.producer_version diff --git a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py index c1dcbbf..26e3d8d 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +++ b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py @@ -210,6 +210,14 @@ def get_import_domains(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto] """Retrieves the import domains from an ONNX model or function.""" return model_or_func.opset_import + @staticmethod + def get_ir_version(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto]): + """Retrieves the ir_version from an ONNX model or function.""" + try: + return model_or_func.ir_version + except Exception: + return None + @staticmethod def import_tensor(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto]) -> Tensor: """Converts an ONNX tensor into a corresponding internal Tensor representation.""" @@ -397,6 +405,7 @@ def import_graph( tensor_map: "OrderedDict[str, Tensor]" = None, opset=None, import_domains: onnx.OperatorSetIdProto = None, + ir_version=None, producer_name: str = None, producer_version: str = None, functions: List[Function] = None, @@ -495,6 +504,7 @@ def get_tensor( producer_version=producer_version, opset=opset, import_domains=import_domains, + ir_version=ir_version, functions=functions, ) @@ -510,6 +520,7 @@ def import_onnx(onnx_model: "onnx.ModelProto") -> Graph: Graph: A corresponding onnx-graphsurgeon Graph. """ model_opset = OnnxImporter.get_opset(onnx_model) + model_ir_version = OnnxImporter.get_ir_version(onnx_model) model_import_domains = OnnxImporter.get_import_domains(onnx_model) functions: List[Function] = [ OnnxImporter.import_function( @@ -534,6 +545,7 @@ def import_onnx(onnx_model: "onnx.ModelProto") -> Graph: onnx_model.graph, opset=model_opset, import_domains=model_import_domains, + ir_version=model_ir_version, producer_name=onnx_model.producer_name, producer_version=onnx_model.producer_version, functions=functions, diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py b/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py index 414c47d..4d5bcf3 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +++ b/onnxslim/third_party/onnx_graphsurgeon/ir/graph.py @@ -106,6 +106,7 @@ def __init__( doc_string=None, opset=None, import_domains=None, + ir_version=None, producer_name: str = None, producer_version: str = None, functions: "Sequence[Function]" = None, @@ -133,6 +134,7 @@ def __init__( self.producer_name = misc.default_value(producer_name, "") self.producer_version = misc.default_value(producer_version, "") self.import_domains = import_domains + self.ir_version = ir_version # For layer() function self.name_idx = 0 diff --git a/onnxslim/utils.py b/onnxslim/utils.py index 561114b..2da545c 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -177,11 +177,13 @@ def format_model_info(model_name: str, model_info_list: List[Dict], elapsed_time ["Model Name", model_name], [SEPARATING_LINE], ["Op Set ", model_info_list[0]["op_set"]], + [SEPARATING_LINE], + ["IR Version ", model_info_list[0]["ir_version"]], ) ) else: final_op_info.append( - ["Model Name", model_name, "Op Set: " + model_info_list[0]["op_set"]] + [""] * (len(model_info_list) - 2) + ["Model Name", model_name, "Op Set: " + model_info_list[0]["op_set"] + " / IR Version: " + model_info_list[0]["ir_version"]] + [""] * (len(model_info_list) - 2) ) final_op_info.extend( ( @@ -319,6 +321,12 @@ def get_opset(model: onnx.ModelProto) -> int: except Exception: return None +def get_ir_version(model: onnx.ModelProto) -> int: + """Returns the ONNX ir version for a given model.""" + try: + return model.ir_version + except Exception: + return None def summarize_model(model: Union[str, onnx.ModelProto], tag=None) -> Dict: """Generates a summary of the ONNX model, including model size, operations, and tensor shapes.""" @@ -388,6 +396,7 @@ def get_graph_node_info(graph: onnx.GraphProto) -> Dict[str, List[str]]: get_graph_node_info(model.graph) model_info["op_set"] = str(get_opset(model)) + model_info["ir_version"] = str(get_ir_version(model)) model_info["op_info"] = op_info model_info["op_type_counts"] = op_type_counts