Skip to content

Commit

Permalink
support types in typing module
Browse files Browse the repository at this point in the history
>>> ca = Argument("key1", List[float])
>>> ca.check({"key1": [1, 2.0, 3]})
pass
>>> ca.check({"key1": [1, 2.0, "3"]})
throw ArgumentTypeError

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 23, 2023
1 parent aef1574 commit c9a3715
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 11 deletions.
23 changes: 12 additions & 11 deletions dargs/dargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
import re
from copy import deepcopy
from enum import Enum
from numbers import Real
from textwrap import indent
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union, get_origin

import typeguard

INDENT = " " # doc is indented by four spaces
RAW_ANCHOR = False # whether to use raw html anchors or RST ones
Expand Down Expand Up @@ -205,10 +205,10 @@ def I(self):
return Argument("_", dict, [self])

def _reorg_dtype(self):
if isinstance(self.dtype, type) or self.dtype is None:
if isinstance(self.dtype, type) or isinstance(get_origin(self.dtype), type) or self.dtype is None:
self.dtype = [self.dtype]
# remove duplicate
self.dtype = {dt if type(dt) is type else type(dt) for dt in self.dtype}
self.dtype = {dt if type(dt) is type or type(get_origin(dt)) is type else type(dt) for dt in self.dtype}
# check conner cases
if self.sub_fields or self.sub_variants:
self.dtype.add(list if self.repeat else dict)
Expand Down Expand Up @@ -414,16 +414,15 @@ def _check_exist(self, argdict: dict, path=None):
)

def _check_data(self, value: Any, path=None):
if not (
isinstance(value, self.dtype)
or (float in self.dtype and isinstance(value, Real))
):
try:
typeguard.check_type(value, self.dtype, collection_check_strategy=typeguard.CollectionCheckStrategy.ALL_ITEMS)
except typeguard.TypeCheckError as e:
raise ArgumentTypeError(
path,
f"key `{self.name}` gets wrong value type, "
f"requires <{'|'.join(dd.__name__ for dd in self.dtype)}> "
f"but gets <{type(value).__name__}>",
)
f"requires <{'|'.join(str(dd) if isinstance(get_origin(dd), type) else dd.__name__ for dd in self.dtype)}> "
f"but " + str(e),
) from e
if self.extra_check is not None and not self.extra_check(value):
raise ArgumentValueError(
path,
Expand Down Expand Up @@ -993,6 +992,8 @@ def default(self, obj) -> Dict[str, Union[str, bool, List]]:
"choice_alias": obj.choice_alias,
"doc": obj.doc,
}
elif isinstance(get_origin(obj), type):
return str(obj)
elif isinstance(obj, type):
return obj.__name__
return json.JSONEncoder.default(self, obj)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ classifiers = [
"License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)",
]
dependencies = [
"typeguard>=3",
]
requires-python = ">=3.7"
readme = "README.md"
Expand Down
6 changes: 6 additions & 0 deletions tests/test_checker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from .context import dargs
import unittest
from dargs import Argument, Variant
Expand Down Expand Up @@ -27,6 +28,11 @@ def test_name_type(self):
# special handel of int and float
ca = Argument("key1", float)
ca.check({"key1": 1})
# list[int]
ca = Argument("key1", List[float])
ca.check({"key1": [1, 2.0, 3]})
with self.assertRaises(ArgumentTypeError):
ca.check({"key1": [1, 2.0, "3"]})
# optional case
ca = Argument("key1", int, optional=True)
ca.check({})
Expand Down
6 changes: 6 additions & 0 deletions tests/test_docgen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .context import dargs
import unittest
import json
from typing import List
from dargs import Argument, Variant, ArgumentEncoder


Expand All @@ -22,6 +23,11 @@ def test_sub_fields(self):
[Argument("subsubsub1", int, doc="subsubsub doc." * 5)],
doc="subsub doc." * 5,
),
Argument(
"list_of_float",
List[float],
doc="Check if List[float] works.",
),
],
doc="sub doc." * 5,
),
Expand Down

0 comments on commit c9a3715

Please sign in to comment.