diff --git a/src/caret_analyze/common/type_check_decorator.py b/src/caret_analyze/common/type_check_decorator.py index 5f319a137..d897b3bb4 100644 --- a/src/caret_analyze/common/type_check_decorator.py +++ b/src/caret_analyze/common/type_check_decorator.py @@ -16,6 +16,7 @@ from functools import wraps from inspect import Signature, signature +from re import findall from typing import Any from ..exceptions import UnsupportedTypeError @@ -24,7 +25,7 @@ try: from pydantic import validate_arguments, ValidationError - def _get_expected_types(e: ValidationError) -> str: + def _get_expected_types(e: ValidationError, signature: Signature) -> str: """ Get expected types. @@ -39,6 +40,8 @@ def _get_expected_types(e: ValidationError) -> str: (ii) Custom class type case: {'type': 'type_error.arbitrary_type', 'ctx': {'expected_arbitrary_type': ''}, ...} + signature: Signature + Signature of target function. Returns ------- @@ -50,12 +53,16 @@ def _get_expected_types(e: ValidationError) -> str: '' """ - expected_types: list[str] = [] - for error in e.errors(): - if error['type'] == 'type_error.arbitrary_type': # Custom class type case - expected_types.append(error['ctx']['expected_arbitrary_type']) - else: - expected_types.append(error['type'].replace('type_error.', '')) + error = e.errors()[0] + invalid_arg_name: str = str(error['loc'][0]) + expected_type: str = str(signature.parameters[invalid_arg_name].annotation) + + if e.title == 'IterableArg': + expected_type = str(findall(r'.*\[(.*)\]', expected_type)[0]) + if e.title == 'DictArg': + expected_type = str(findall(r'.*\[.*, (.*)\]', expected_type)[0]) + + expected_types: list[str] = expected_type.split(' | ') if len(expected_types) > 1: # Union case expected_types_str = str(expected_types) @@ -64,7 +71,7 @@ def _get_expected_types(e: ValidationError) -> str: return expected_types_str - def _get_given_arg_loc_str(given_arg_loc: tuple) -> str: + def _get_given_arg_loc_str(given_arg_loc: tuple, error_type: str) -> str: """ Get given argument location string. @@ -79,6 +86,15 @@ def _get_given_arg_loc_str(given_arg_loc: tuple) -> str: (ii) Dict case ('', '') + error_type: str + (i) Dict case + 'DictArg' + + (ii) Iterable type except for dict case + 'IterableArg' + + (iii) Not iterable type case + other Returns ------- @@ -93,7 +109,7 @@ def _get_given_arg_loc_str(given_arg_loc: tuple) -> str: ''[KEY] """ - if len(given_arg_loc) == 2: # Iterable type case + if error_type == 'IterableArg' or error_type == 'DictArg': # Iterable type case loc_str = f"'{given_arg_loc[0]}'[{given_arg_loc[1]}]" else: loc_str = f"'{given_arg_loc[0]}'" @@ -104,7 +120,8 @@ def _get_given_arg_type( signature: Signature, args: tuple[Any, ...], kwargs: dict[str, Any], - given_arg_loc: tuple + given_arg_loc: tuple, + error_type: str ) -> str: """ Get given argument type. @@ -126,6 +143,15 @@ def _get_given_arg_type( (ii) Dict case ('', '') + error_type: str + (i) Dict case + 'DictArg' + + (ii) Iterable type except for dict case + 'IterableArg' + + (iii) Not iterable type case + other Returns ------- @@ -154,11 +180,10 @@ def _get_given_arg_type( given_arg_idx = list(signature.parameters.keys()).index(arg_name) given_arg = args[given_arg_idx] - if len(given_arg_loc) == 2: # Iterable type case - if isinstance(given_arg, dict): - given_arg_type_str = f"'{given_arg[given_arg_loc[1]].__class__.__name__}'" - else: - given_arg_type_str = f"'{given_arg[int(given_arg_loc[1])].__class__.__name__}'" + if error_type == 'DictArg': + given_arg_type_str = f"'{given_arg[given_arg_loc[1]].__class__.__name__}'" + elif error_type == 'IterableArg': + given_arg_type_str = f"'{given_arg[int(given_arg_loc[1])].__class__.__name__}'" else: given_arg_type_str = f"'{given_arg.__class__.__name__}'" @@ -173,10 +198,12 @@ def _custom_wrapper(*args, **kwargs): try: return validate_arguments_wrapper(*args, **kwargs) except ValidationError as e: - expected_types = _get_expected_types(e) + expected_types = _get_expected_types(e, signature(func)) + error_type = e.title loc_tuple = e.errors()[0]['loc'] - given_arg_loc_str = _get_given_arg_loc_str(loc_tuple) - given_arg_type = _get_given_arg_type(signature(func), args, kwargs, loc_tuple) + given_arg_loc_str = _get_given_arg_loc_str(loc_tuple, error_type) + given_arg_type \ + = _get_given_arg_type(signature(func), args, kwargs, loc_tuple, error_type) msg = f'Type of argument {given_arg_loc_str} must be {expected_types}. ' msg += f'The given argument type is {given_arg_type}.' diff --git a/src/test/common/test_type_check_decorator.py b/src/test/common/test_type_check_decorator.py index fa8f923eb..353105918 100644 --- a/src/test/common/test_type_check_decorator.py +++ b/src/test/common/test_type_check_decorator.py @@ -65,6 +65,20 @@ def iterable_arg(i: list[bool]): iterable_arg([True, 10]) assert "'i'[1] must be 'bool'. The given argument type is 'int'" in str(e.value) + def test_type_check_decorator_iterable_with_union(self): + @type_check_decorator + def iterable_arg(i: list[bool | str]): + pass + + with pytest.raises(UnsupportedTypeError) as e: + iterable_arg([True, 10]) + assert "'i'[1] must be ['bool', 'str']. The given argument type is 'int'" in str(e.value) + + # TODO: test_type_check_decorator_union_with_iterable + # @type_check_decorator + # def iterable_arg(i: list[bool] | str): + # pass + def test_type_check_decorator_dict(self): @type_check_decorator def dict_arg(d: dict[str, bool]): @@ -75,6 +89,22 @@ def dict_arg(d: dict[str, bool]): 'key2': 10}) assert "'d'[key2] must be 'bool'. The given argument type is 'int'" in str(e.value) + # TODO: test_type_check_decorator_dict_key + # with pytest.raises(UnsupportedTypeError) as e: + # dict_arg({'key1': True, + # 1: 10}) + + def test_type_check_decorator_dict_with_union(self): + @type_check_decorator + def dict_arg(d: dict[str, bool | str]): + pass + + with pytest.raises(UnsupportedTypeError) as e: + dict_arg({'key1': True, + 'key2': 10}) + assert "'d'[key2] must be ['bool', 'str']. The given argument type is 'int'"\ + in str(e.value) + def test_type_check_decorator_kwargs(self): @type_check_decorator def kwarg(k: bool):