Skip to content

Commit

Permalink
fix: mypy warning in type_check_decorator by signiture (#312)
Browse files Browse the repository at this point in the history
* fix: mypy warning

Signed-off-by: rokamu623 <[email protected]>

* fix: type_check_decorator and those tests

Signed-off-by: rokamu623 <[email protected]>

* fix: flake8

Signed-off-by: rokamu623 <[email protected]>

* fix: expected type by annotation

Signed-off-by: rokamu623 <[email protected]>

* fix: adopt error type

Signed-off-by: rokamu623 <[email protected]>

* fix: doc string

Signed-off-by: rokamu623 <[email protected]>

* fix: mypy

Signed-off-by: rokamu623 <[email protected]>

* fix: mypy

Signed-off-by: rokamu623 <[email protected]>

* fix: adapt union case in itrator

Signed-off-by: rokamu623 <[email protected]>

* feat: comment for future works

Signed-off-by: rokamu623 <[email protected]>

---------

Signed-off-by: rokamu623 <[email protected]>
  • Loading branch information
rokamu623 authored Aug 10, 2023
1 parent 70e3822 commit 504ad62
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 18 deletions.
63 changes: 45 additions & 18 deletions src/caret_analyze/common/type_check_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from functools import wraps
from inspect import Signature, signature
from re import findall
from typing import Any

from ..exceptions import UnsupportedTypeError
Expand All @@ -24,7 +25,7 @@
try:
from pydantic import validate_arguments, ValidationError

def _get_expected_types(e: ValidationError) -> str:
def _get_expected_types(e: ValidationError, signature: Signature) -> str:
"""
Get expected types.
Expand All @@ -39,6 +40,8 @@ def _get_expected_types(e: ValidationError) -> str:
(ii) Custom class type case:
{'type': 'type_error.arbitrary_type',
'ctx': {'expected_arbitrary_type': '<EXPECT_TYPE>'}, ...}
signature: Signature
Signature of target function.
Returns
-------
Expand All @@ -50,12 +53,16 @@ def _get_expected_types(e: ValidationError) -> str:
'<EXPECT_TYPE>'
"""
expected_types: list[str] = []
for error in e.errors():
if error['type'] == 'type_error.arbitrary_type': # Custom class type case
expected_types.append(error['ctx']['expected_arbitrary_type'])
else:
expected_types.append(error['type'].replace('type_error.', ''))
error = e.errors()[0]
invalid_arg_name: str = str(error['loc'][0])
expected_type: str = str(signature.parameters[invalid_arg_name].annotation)

if e.title == 'IterableArg':
expected_type = str(findall(r'.*\[(.*)\]', expected_type)[0])
if e.title == 'DictArg':
expected_type = str(findall(r'.*\[.*, (.*)\]', expected_type)[0])

expected_types: list[str] = expected_type.split(' | ')

if len(expected_types) > 1: # Union case
expected_types_str = str(expected_types)
Expand All @@ -64,7 +71,7 @@ def _get_expected_types(e: ValidationError) -> str:

return expected_types_str

def _get_given_arg_loc_str(given_arg_loc: tuple) -> str:
def _get_given_arg_loc_str(given_arg_loc: tuple, error_type: str) -> str:
"""
Get given argument location string.
Expand All @@ -79,6 +86,15 @@ def _get_given_arg_loc_str(given_arg_loc: tuple) -> str:
(ii) Dict case
('<ARGUMENT_NAME>', '<KEY>')
error_type: str
(i) Dict case
'DictArg'
(ii) Iterable type except for dict case
'IterableArg'
(iii) Not iterable type case
other
Returns
-------
Expand All @@ -93,7 +109,7 @@ def _get_given_arg_loc_str(given_arg_loc: tuple) -> str:
'<ARGUMENT_NAME>'[KEY]
"""
if len(given_arg_loc) == 2: # Iterable type case
if error_type == 'IterableArg' or error_type == 'DictArg': # Iterable type case
loc_str = f"'{given_arg_loc[0]}'[{given_arg_loc[1]}]"
else:
loc_str = f"'{given_arg_loc[0]}'"
Expand All @@ -104,7 +120,8 @@ def _get_given_arg_type(
signature: Signature,
args: tuple[Any, ...],
kwargs: dict[str, Any],
given_arg_loc: tuple
given_arg_loc: tuple,
error_type: str
) -> str:
"""
Get given argument type.
Expand All @@ -126,6 +143,15 @@ def _get_given_arg_type(
(ii) Dict case
('<ARGUMENT_NAME>', '<KEY>')
error_type: str
(i) Dict case
'DictArg'
(ii) Iterable type except for dict case
'IterableArg'
(iii) Not iterable type case
other
Returns
-------
Expand Down Expand Up @@ -154,11 +180,10 @@ def _get_given_arg_type(
given_arg_idx = list(signature.parameters.keys()).index(arg_name)
given_arg = args[given_arg_idx]

if len(given_arg_loc) == 2: # Iterable type case
if isinstance(given_arg, dict):
given_arg_type_str = f"'{given_arg[given_arg_loc[1]].__class__.__name__}'"
else:
given_arg_type_str = f"'{given_arg[int(given_arg_loc[1])].__class__.__name__}'"
if error_type == 'DictArg':
given_arg_type_str = f"'{given_arg[given_arg_loc[1]].__class__.__name__}'"
elif error_type == 'IterableArg':
given_arg_type_str = f"'{given_arg[int(given_arg_loc[1])].__class__.__name__}'"
else:
given_arg_type_str = f"'{given_arg.__class__.__name__}'"

Expand All @@ -173,10 +198,12 @@ def _custom_wrapper(*args, **kwargs):
try:
return validate_arguments_wrapper(*args, **kwargs)
except ValidationError as e:
expected_types = _get_expected_types(e)
expected_types = _get_expected_types(e, signature(func))
error_type = e.title
loc_tuple = e.errors()[0]['loc']
given_arg_loc_str = _get_given_arg_loc_str(loc_tuple)
given_arg_type = _get_given_arg_type(signature(func), args, kwargs, loc_tuple)
given_arg_loc_str = _get_given_arg_loc_str(loc_tuple, error_type)
given_arg_type \
= _get_given_arg_type(signature(func), args, kwargs, loc_tuple, error_type)

msg = f'Type of argument {given_arg_loc_str} must be {expected_types}. '
msg += f'The given argument type is {given_arg_type}.'
Expand Down
30 changes: 30 additions & 0 deletions src/test/common/test_type_check_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,20 @@ def iterable_arg(i: list[bool]):
iterable_arg([True, 10])
assert "'i'[1] must be 'bool'. The given argument type is 'int'" in str(e.value)

def test_type_check_decorator_iterable_with_union(self):
@type_check_decorator
def iterable_arg(i: list[bool | str]):
pass

with pytest.raises(UnsupportedTypeError) as e:
iterable_arg([True, 10])
assert "'i'[1] must be ['bool', 'str']. The given argument type is 'int'" in str(e.value)

# TODO: test_type_check_decorator_union_with_iterable
# @type_check_decorator
# def iterable_arg(i: list[bool] | str):
# pass

def test_type_check_decorator_dict(self):
@type_check_decorator
def dict_arg(d: dict[str, bool]):
Expand All @@ -75,6 +89,22 @@ def dict_arg(d: dict[str, bool]):
'key2': 10})
assert "'d'[key2] must be 'bool'. The given argument type is 'int'" in str(e.value)

# TODO: test_type_check_decorator_dict_key
# with pytest.raises(UnsupportedTypeError) as e:
# dict_arg({'key1': True,
# 1: 10})

def test_type_check_decorator_dict_with_union(self):
@type_check_decorator
def dict_arg(d: dict[str, bool | str]):
pass

with pytest.raises(UnsupportedTypeError) as e:
dict_arg({'key1': True,
'key2': 10})
assert "'d'[key2] must be ['bool', 'str']. The given argument type is 'int'"\
in str(e.value)

def test_type_check_decorator_kwargs(self):
@type_check_decorator
def kwarg(k: bool):
Expand Down

0 comments on commit 504ad62

Please sign in to comment.