diff --git a/onnxslim/argparser.py b/onnxslim/argparser.py index bddc25e..be0c49e 100644 --- a/onnxslim/argparser.py +++ b/onnxslim/argparser.py @@ -1,19 +1,18 @@ -import sys import argparse -from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter import dataclasses +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from dataclasses import dataclass, field -from typing import List, Optional, Type +from typing import List, Optional, Type, Union, get_args, get_origin import onnxslim -from typing import Union, List, Optional, get_origin, get_args def _get_inner_type(arg_type): if get_origin(arg_type) is Union: return next((t for t in get_args(arg_type) if t is not type(None)), str) return arg_type + @dataclass class ModelArguments: """ diff --git a/onnxslim/cli/_main.py b/onnxslim/cli/_main.py index f02460f..79fbe3a 100644 --- a/onnxslim/cli/_main.py +++ b/onnxslim/cli/_main.py @@ -123,14 +123,16 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs): def main(): """Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function.""" from onnxslim.argparser import ( - OnnxSlimArgumentParser, CheckerArguments, ModelArguments, ModificationArguments, + OnnxSlimArgumentParser, OptimizationArguments, ) - argument_parser = OnnxSlimArgumentParser(ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments) + argument_parser = OnnxSlimArgumentParser( + ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments + ) model_args, optimization_args, modification_args, checker_args = argument_parser.parse_args_into_dataclasses() if checker_args.inspect and model_args.output_model: diff --git a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py index b05c055..723f270 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +++ b/onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py @@ -551,5 +551,5 @@ def import_onnx(onnx_model: "onnx.ModelProto") -> Graph: producer_name=onnx_model.producer_name, producer_version=onnx_model.producer_version, functions=functions, - metadata_props=onnx_model.metadata_props + metadata_props=onnx_model.metadata_props, )