Skip to content

Commit

Permalink
improve argparser
Browse files Browse the repository at this point in the history
  • Loading branch information
whyb committed Nov 16, 2024
1 parent c492ffd commit 3bb90f1
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions onnxslim/argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

import onnxslim

from typing import Union, List, Optional, get_origin, get_args

def _get_inner_type(arg_type):
if get_origin(arg_type) is Union:
return next((t for t in get_args(arg_type) if t is not type(None)), str)
return arg_type

@dataclass
class ModelArguments:
Expand Down Expand Up @@ -126,12 +132,17 @@ def __init__(self, *argument_dataclasses: Type, **kwargs):
def _add_arguments(self):
for dataclass_type in self.argument_dataclasses:
for field_name, field_def in dataclass_type.__dataclass_fields__.items():
arg_type = field_def.type
arg_type = _get_inner_type(field_def.type)
default_value = field_def.default if field_def.default is not field_def.default_factory else None
help_text = field_def.metadata.get("help", "")
nargs = "+" if arg_type == Optional[List[str]] else None
nargs = "+" if arg_type == list else None
choices = field_def.metadata.get("choices", None)

if choices and default_value is not None and default_value not in choices:
raise ValueError(
f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}."
)

if arg_type == bool:
self.parser.add_argument(
f"--{field_name.replace('_', '-')}",
Expand All @@ -142,7 +153,7 @@ def _add_arguments(self):
else:
self.parser.add_argument(
f"--{field_name.replace('_', '-')}",
type=arg_type if arg_type != Optional[List[str]] else str,
type=arg_type,
default=default_value,
nargs=nargs,
choices=choices,
Expand Down

0 comments on commit 3bb90f1

Please sign in to comment.