diff --git a/onnxslim/argparser.py b/onnxslim/argparser.py index 769b909..bddc25e 100644 --- a/onnxslim/argparser.py +++ b/onnxslim/argparser.py @@ -7,6 +7,12 @@ import onnxslim +from typing import Union, List, Optional, get_origin, get_args + +def _get_inner_type(arg_type): + if get_origin(arg_type) is Union: + return next((t for t in get_args(arg_type) if t is not type(None)), str) + return arg_type @dataclass class ModelArguments: @@ -126,12 +132,17 @@ def __init__(self, *argument_dataclasses: Type, **kwargs): def _add_arguments(self): for dataclass_type in self.argument_dataclasses: for field_name, field_def in dataclass_type.__dataclass_fields__.items(): - arg_type = field_def.type + arg_type = _get_inner_type(field_def.type) default_value = field_def.default if field_def.default is not field_def.default_factory else None help_text = field_def.metadata.get("help", "") - nargs = "+" if arg_type == Optional[List[str]] else None + nargs = "+" if arg_type == list else None choices = field_def.metadata.get("choices", None) + if choices and default_value is not None and default_value not in choices: + raise ValueError( + f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}." + ) + if arg_type == bool: self.parser.add_argument( f"--{field_name.replace('_', '-')}", @@ -142,7 +153,7 @@ def _add_arguments(self): else: self.parser.add_argument( f"--{field_name.replace('_', '-')}", - type=arg_type if arg_type != Optional[List[str]] else str, + type=arg_type, default=default_value, nargs=nargs, choices=choices,