Skip to content

Commit

Permalink
refactor: type check decorator (#489)
Browse files Browse the repository at this point in the history
* use get_annotations instead of signature

Signed-off-by: rokamu623 <[email protected]>

* refact: get kwargs if exists

Signed-off-by: rokamu623 <[email protected]>

* refact: check iterable and dict without e.title

Signed-off-by: rokamu623 <[email protected]>

* refact: rename variable of varargs

Signed-off-by: rokamu623 <[email protected]>

* fix: flake8

Signed-off-by: rokamu623 <[email protected]>

* fix: doc string

Signed-off-by: rokamu623 <[email protected]>

---------

Signed-off-by: rokamu623 <[email protected]>
  • Loading branch information
rokamu623 authored Mar 26, 2024
1 parent 0f2af22 commit 7c66d96
Showing 1 changed file with 44 additions and 63 deletions.
107 changes: 44 additions & 63 deletions src/caret_analyze/common/type_check_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.
from __future__ import annotations

from collections.abc import Collection
from collections.abc import Collection, Sequence
from functools import wraps
from inspect import getfullargspec, Signature, signature
from inspect import get_annotations, getfullargspec
from logging import getLogger
from re import findall
from typing import Any
Expand All @@ -27,19 +27,19 @@
from pydantic.deprecated.decorator import validate_arguments

def _get_given_arg(
signature: Signature,
annotations: dict[str, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
given_arg_loc: tuple,
varargs: None | str
varargs_name: None | str
) -> Any:
"""
Get an argument which validation error occurs.
Parameters
----------
signature: Signature
Signature of target function.
annotations: dict[str, Any]
Dict of annotations of target function.
args: tuple[Any, ...]
Arguments of target function.
kwargs: dict[str, Any]
Expand All @@ -53,52 +53,51 @@ def _get_given_arg(
(ii) Dict case
('<ARGUMENT_NAME>', '<KEY>')
varargs: None | str
varargs_name: None | str
The name of the variable length argument if the function has one, otherwise None.
Returns
-------
str
Any
The argument which validation error occurs.
"""
arg_name = given_arg_loc[0]
given_arg: Any = None

# Check kwargs
for k, v in kwargs.items():
if k == arg_name:
given_arg = v
break
if arg_name in kwargs:
given_arg = kwargs.get(arg_name)

if given_arg is None:
# Check args
given_arg_idx = list(signature.parameters.keys()).index(arg_name)
given_arg_idx = list(annotations.keys()).index(arg_name)

# for variable length arguments
if arg_name == varargs:
if arg_name == varargs_name:
given_arg = args[given_arg_idx:]
else:
given_arg = args[given_arg_idx]

return given_arg

def _get_expected_types(e: ValidationError, signature: Signature) -> str:
def _get_expected_types(given_arg_loc: tuple, annotations: dict[str, Any]) -> str:
"""
Get expected types.
Parameters
----------
e: ValidationError
ValidationError instance has one or more ErrorDict instances.
Example of ErrorDict structure is as follows.
(i) Build-in type case:
{'type': 'type_error.<EXPECT_TYPE>', ...}
(ii) Custom class type case:
{'type': 'type_error.arbitrary_type',
'ctx': {'expected_arbitrary_type': '<EXPECT_TYPE>'}, ...}
signature: Signature
Signature of target function.
given_arg_loc: tuple
(i) Not iterable type case
('<ARGUMENT_NAME>,')
(ii) Iterable type except for dict case
('<ARGUMENT_NAME>', '<INDEX>')
(ii) Dict case
('<ARGUMENT_NAME>', '<KEY>')
annotations: dict[str, Any]
Dict of annotations of target function.
Returns
-------
Expand All @@ -110,15 +109,16 @@ def _get_expected_types(e: ValidationError, signature: Signature) -> str:
'<EXPECT_TYPE>'
"""
error = e.errors()[0]
invalid_arg_name: str = str(error['loc'][0])
expected_type: str = str(signature.parameters[invalid_arg_name].annotation)
invalid_arg_name: str = given_arg_loc[0]
expected_type: str = str(annotations[invalid_arg_name])

if e.title == 'IterableArg':
# for list and dict
if 'list[' in expected_type:
expected_type = str(findall(r'.*\[(.*)\]', expected_type)[0])
if e.title == 'DictArg':
if 'dict[' in expected_type:
expected_type = str(findall(r'.*\[.*, (.*)\]', expected_type)[0])

# for union annotations
expected_types: list[str] = expected_type.split(' | ')

if len(expected_types) > 1: # Union case
Expand Down Expand Up @@ -160,19 +160,14 @@ def _get_given_arg_loc_str(given_arg_loc: tuple, given_arg: Any) -> str:
"""
# Iterable or dict type case
if isinstance(given_arg, Collection) or isinstance(given_arg, dict):
if isinstance(given_arg, Sequence) or isinstance(given_arg, dict):
loc_str = f"'{given_arg_loc[0]}'[{given_arg_loc[1]}]"
else:
loc_str = f"'{given_arg_loc[0]}'"

return loc_str

def _get_given_arg_type(
given_arg: Any,
given_arg_loc: tuple,
error_type: str,
varargs: None | str
) -> str:
def _get_given_arg_type(given_arg: Any, given_arg_loc: tuple) -> str:
"""
Get given argument type.
Expand All @@ -189,17 +184,6 @@ def _get_given_arg_type(
(ii) Dict case
('<ARGUMENT_NAME>', '<KEY>')
error_type: str
(i) Dict case
'DictArg'
(ii) Iterable type except for dict case
'IterableArg'
(iii) Not iterable type case
other
varargs: None | str
The name of the variable length argument if the function has one, otherwise None.
Returns
-------
Expand All @@ -214,15 +198,10 @@ def _get_given_arg_type(
Class name input for argument <ARGUMENT_NAME>[<KEY>]
"""
if error_type == 'DictArg':
if isinstance(given_arg, Sequence) or isinstance(given_arg, dict):
given_arg_type_str = f"'{given_arg[given_arg_loc[1]].__class__.__name__}'"
elif error_type == 'IterableArg':
given_arg_type_str = f"'{given_arg[int(given_arg_loc[1])].__class__.__name__}'"
elif varargs is None:
given_arg_type_str = f"'{given_arg.__class__.__name__}'"
else:
# For functions with variable length arguments,
given_arg_type_str = f"'{given_arg[given_arg_loc[1]].__class__.__name__}'"
given_arg_type_str = f"'{given_arg.__class__.__name__}'"

return given_arg_type_str

Expand Down Expand Up @@ -262,19 +241,21 @@ def _custom_wrapper(*args, **kwargs):
try:
# Checks whether the arguments of a given func have variable length arguments
arg_spec = getfullargspec(func)
varargs_name = arg_spec.varargs
arg_len = len(arg_spec.args)
if arg_spec.varargs is not None:

if varargs_name is not None:
args = args[:arg_len] + _parse_collection_or_unpack(args[arg_len:])
return validate_arguments_wrapper(*args, **kwargs)

except ValidationError as e:
loc_tuple = e.errors()[0]['loc']
given_arg = _get_given_arg(signature(func), args, kwargs,
loc_tuple, arg_spec.varargs)
expected_types = _get_expected_types(e, signature(func))
error_type = e.title
annotations = get_annotations(func)

given_arg = _get_given_arg(annotations, args, kwargs, loc_tuple, varargs_name)
expected_types = _get_expected_types(loc_tuple, annotations)
given_arg_loc_str = _get_given_arg_loc_str(loc_tuple, given_arg)
given_arg_type = _get_given_arg_type(given_arg, loc_tuple,
error_type, arg_spec.varargs)
given_arg_type = _get_given_arg_type(given_arg, loc_tuple)

msg = f'Type of argument {given_arg_loc_str} must be {expected_types}. '
msg += f'The given argument type is {given_arg_type}.'
Expand Down

0 comments on commit 7c66d96

Please sign in to comment.