diff --git a/dargs/dargs.py b/dargs/dargs.py index 98698e7..38bfdf3 100644 --- a/dargs/dargs.py +++ b/dargs/dargs.py @@ -23,10 +23,10 @@ import re from copy import deepcopy from enum import Enum -from numbers import Real from textwrap import indent -from typing import Any, Callable, Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union, get_origin +import typeguard INDENT = " " # doc is indented by four spaces RAW_ANCHOR = False # whether to use raw html anchors or RST ones @@ -176,7 +176,7 @@ def __eq__(self, other: "Argument") -> bool: ) def __repr__(self) -> str: - return f"" + return f"" def __getitem__(self, key: str) -> "Argument": key = key.lstrip("/") @@ -205,10 +205,17 @@ def I(self): return Argument("_", dict, [self]) def _reorg_dtype(self): - if isinstance(self.dtype, type) or self.dtype is None: + if ( + isinstance(self.dtype, type) + or isinstance(get_origin(self.dtype), type) + or self.dtype is None + ): self.dtype = [self.dtype] # remove duplicate - self.dtype = {dt if type(dt) is type else type(dt) for dt in self.dtype} + self.dtype = { + dt if type(dt) is type or type(get_origin(dt)) is type else type(dt) + for dt in self.dtype + } # check conner cases if self.sub_fields or self.sub_variants: self.dtype.add(list if self.repeat else dict) @@ -414,16 +421,19 @@ def _check_exist(self, argdict: dict, path=None): ) def _check_data(self, value: Any, path=None): - if not ( - isinstance(value, self.dtype) - or (float in self.dtype and isinstance(value, Real)) - ): + try: + typeguard.check_type( + value, + self.dtype, + collection_check_strategy=typeguard.CollectionCheckStrategy.ALL_ITEMS, + ) + except typeguard.TypeCheckError as e: raise ArgumentTypeError( path, f"key `{self.name}` gets wrong value type, " - f"requires <{'|'.join(dd.__name__ for dd in self.dtype)}> " - f"but gets <{type(value).__name__}>", - ) + f"requires <{'|'.join(self._get_type_name(dd) for dd in self.dtype)}> " + f"but " + str(e), + ) from e if self.extra_check is not None and not self.extra_check(value): raise ArgumentValueError( path, @@ -586,7 +596,9 @@ def gen_doc(self, path: Optional[List[str]] = None, **kwargs) -> str: return "\n".join(filter(None, doc_list)) def gen_doc_head(self, path: Optional[List[str]] = None, **kwargs) -> str: - typesig = "| type: " + " | ".join([f"``{dt.__name__}``" for dt in self.dtype]) + typesig = "| type: " + " | ".join( + [f"``{self._get_type_name(dt)}``" for dt in self.dtype] + ) if self.optional: typesig += ", optional" if self.default == "": @@ -632,6 +644,10 @@ def gen_doc_body(self, path: Optional[List[str]] = None, **kwargs) -> str: body = "\n".join(body_list) return body + def _get_type_name(self, dd) -> str: + """Get type name for doc/message generation.""" + return str(dd) if isinstance(get_origin(dd), type) else dd.__name__ + class Variant: """Define multiple choices of possible argument sets. @@ -993,6 +1009,8 @@ def default(self, obj) -> Dict[str, Union[str, bool, List]]: "choice_alias": obj.choice_alias, "doc": obj.doc, } + elif isinstance(get_origin(obj), type): + return get_origin(obj).__name__ elif isinstance(obj, type): return obj.__name__ return json.JSONEncoder.default(self, obj) diff --git a/dargs/sphinx.py b/dargs/sphinx.py index 21397b6..80a80e3 100644 --- a/dargs/sphinx.py +++ b/dargs/sphinx.py @@ -192,5 +192,5 @@ def _test_arguments() -> List[Argument]: return [ Argument(name="test1", dtype=int, doc="Argument 1"), Argument(name="test2", dtype=[float, None], doc="Argument 2"), - Argument(name="test3", dtype=list, doc="Argument 3"), + Argument(name="test3", dtype=List[str], doc="Argument 3"), ] diff --git a/pyproject.toml b/pyproject.toml index 223590a..bd31a2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ classifiers = [ "License :: OSI Approved :: GNU Lesser General Public License v3 (LGPLv3)", ] dependencies = [ + "typeguard>=3", ] requires-python = ">=3.7" readme = "README.md" diff --git a/tests/test_checker.py b/tests/test_checker.py index b03bad8..088a8b0 100644 --- a/tests/test_checker.py +++ b/tests/test_checker.py @@ -1,3 +1,4 @@ +from typing import List from .context import dargs import unittest from dargs import Argument, Variant @@ -27,6 +28,11 @@ def test_name_type(self): # special handel of int and float ca = Argument("key1", float) ca.check({"key1": 1}) + # list[int] + ca = Argument("key1", List[float]) + ca.check({"key1": [1, 2.0, 3]}) + with self.assertRaises(ArgumentTypeError): + ca.check({"key1": [1, 2.0, "3"]}) # optional case ca = Argument("key1", int, optional=True) ca.check({}) diff --git a/tests/test_docgen.py b/tests/test_docgen.py index 5dc0f37..baf4132 100644 --- a/tests/test_docgen.py +++ b/tests/test_docgen.py @@ -1,6 +1,7 @@ from .context import dargs import unittest import json +from typing import List from dargs import Argument, Variant, ArgumentEncoder @@ -22,6 +23,11 @@ def test_sub_fields(self): [Argument("subsubsub1", int, doc="subsubsub doc." * 5)], doc="subsub doc." * 5, ), + Argument( + "list_of_float", + List[float], + doc="Check if List[float] works.", + ), ], doc="sub doc." * 5, ),