Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
UltralyticsAssistant committed Nov 16, 2024
1 parent 3bb90f1 commit 46efa87
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
7 changes: 3 additions & 4 deletions onnxslim/argparser.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import sys
import argparse
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
import dataclasses
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from dataclasses import dataclass, field
from typing import List, Optional, Type
from typing import List, Optional, Type, Union, get_args, get_origin

import onnxslim

from typing import Union, List, Optional, get_origin, get_args

def _get_inner_type(arg_type):
if get_origin(arg_type) is Union:
return next((t for t in get_args(arg_type) if t is not type(None)), str)
return arg_type


@dataclass
class ModelArguments:
"""
Expand Down
6 changes: 4 additions & 2 deletions onnxslim/cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,16 @@ def slim(model: Union[str, onnx.ModelProto], *args, **kwargs):
def main():
"""Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function."""
from onnxslim.argparser import (
OnnxSlimArgumentParser,
CheckerArguments,
ModelArguments,
ModificationArguments,
OnnxSlimArgumentParser,
OptimizationArguments,
)

argument_parser = OnnxSlimArgumentParser(ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments)
argument_parser = OnnxSlimArgumentParser(
ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments
)
model_args, optimization_args, modification_args, checker_args = argument_parser.parse_args_into_dataclasses()

if checker_args.inspect and model_args.output_model:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,5 +551,5 @@ def import_onnx(onnx_model: "onnx.ModelProto") -> Graph:
producer_name=onnx_model.producer_name,
producer_version=onnx_model.producer_version,
functions=functions,
metadata_props=onnx_model.metadata_props
metadata_props=onnx_model.metadata_props,
)

0 comments on commit 46efa87

Please sign in to comment.