diff --git a/README.md b/README.md index 097d309..cdae84f 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ This framework provides several error handlers to catch errors and call callback (or successes). It comes fully equipped with: - A decorator to handle errors in functions or coroutines -- A decorator to retry a function or coroutine if it fails +- A decorator to retry a function or coroutine if it fails (can be useful for network requests) - A context manager to handle errors in a block of code Additionally, if you use `aiostream` (e.g. using `pip install seviper[aiostream]`), you can use the following features: @@ -31,37 +31,40 @@ pip install seviper[aiostream] ``` ## Usage -Here is a complex example as showcase of the features of this library: +Here is a more or less complex example as showcase of the features of this library: ```python import asyncio +import logging +import sys import aiostream import error_handler -import logging +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) +logger = logging.root op = aiostream.stream.iterate(range(10)) -def log_error(error: Exception): +def log_error(error: Exception, num: int): """Only log error and reraise it""" - logging.error(error) + logger.error("double_only_odd_nums_except_5 failed for input %d. ", num) raise error @error_handler.decorator(on_error=log_error) async def double_only_odd_nums_except_5(num: int) -> int: if num % 2 == 0: raise ValueError(num) - with error_handler.context_manager(on_success=lambda: logging.info(f"Success: {num}")): + with error_handler.context_manager(on_success=lambda: logging.info("Success: %s", num)): if num == 5: raise RuntimeError("Another unexpected error. Number 5 will not be doubled.") num *= 2 return num -def catch_value_errors(error: Exception): +def catch_value_errors(error: Exception, _: int): if not isinstance(error, ValueError): raise error -def log_success(num: int): - logging.info(f"Success: {num}") +def log_success(result_num: int, provided_num: int): + logger.info("Success: %d -> %d", provided_num, result_num) op = op | error_handler.pipe.map( double_only_odd_nums_except_5, @@ -76,6 +79,25 @@ result = asyncio.run(aiostream.stream.list(op)) assert result == [2, 6, 5, 14, 18] ``` +This outputs: + +``` +ERROR:root:double_only_odd_nums_except_5 failed for input 0. +INFO:root:Success: 2 +INFO:root:Success: 1 -> 2 +ERROR:root:double_only_odd_nums_except_5 failed for input 2. +INFO:root:Success: 6 +INFO:root:Success: 3 -> 6 +ERROR:root:double_only_odd_nums_except_5 failed for input 4. +INFO:root:Success: 5 -> 5 +ERROR:root:double_only_odd_nums_except_5 failed for input 6. +INFO:root:Success: 14 +INFO:root:Success: 7 -> 14 +ERROR:root:double_only_odd_nums_except_5 failed for input 8. +INFO:root:Success: 18 +INFO:root:Success: 9 -> 18 +``` + ## How to use this Repository on Your Machine Please refer to the respective section in our [Python template repository](https://github.com/Hochfrequenz/python_template_repository?tab=readme-ov-file#how-to-use-this-repository-on-your-machine) diff --git a/src/error_handler/__init__.py b/src/error_handler/__init__.py index 99192cb..7c4dc58 100644 --- a/src/error_handler/__init__.py +++ b/src/error_handler/__init__.py @@ -4,9 +4,9 @@ """ import importlib +from typing import TYPE_CHECKING from .context_manager import context_manager -from .core import Catcher from .decorator import decorator, retry_on_error from .types import ( ERRORED, @@ -14,13 +14,13 @@ AsyncFunctionType, ErroredType, FunctionType, - NegativeResult, - PositiveResult, - ResultType, SecuredAsyncFunctionType, SecuredFunctionType, UnsetType, ) -stream = importlib.import_module("error_handler.stream") -pipe = importlib.import_module("error_handler.pipe") +if TYPE_CHECKING: + from . import pipe, stream +else: + stream = importlib.import_module("error_handler.stream") + pipe = importlib.import_module("error_handler.pipe") diff --git a/src/error_handler/callback.py b/src/error_handler/callback.py new file mode 100644 index 0000000..314f6ba --- /dev/null +++ b/src/error_handler/callback.py @@ -0,0 +1,150 @@ +""" +This module contains the Callback class, which is used to wrap a callable and its expected signature. +The expected signature is only used to give nicer error messages when the callback is called with the wrong +arguments. Just in case that the type checker is not able to spot callback functions with wrong signatures. +""" + +import inspect +from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar, cast + +from .types import UNSET + +_P = ParamSpec("_P") +_T = TypeVar("_T") +_CallbackT = TypeVar("_CallbackT", bound="Callback") +_ErrorCallbackT = TypeVar("_ErrorCallbackT", bound="ErrorCallback") +_SuccessCallbackT = TypeVar("_SuccessCallbackT", bound="SuccessCallback") + + +class Callback(Generic[_P, _T]): + """ + This class wraps a callable and its expected signature. + """ + + def __init__(self, callback: Callable[_P, _T], expected_signature: inspect.Signature): + self.callback = callback + self.expected_signature = expected_signature + self._actual_signature: inspect.Signature | None = None + + @property + def actual_signature(self) -> inspect.Signature: + """ + The actual signature of the callback + """ + if self._actual_signature is None: + self._actual_signature = inspect.signature(self.callback) + return self._actual_signature + + @property + def expected_signature_str(self) -> str: + """ + The expected signature as string + """ + return str(self.expected_signature) + + @property + def actual_signature_str(self) -> str: + """ + The actual signature as string + """ + return str(self.actual_signature) + + @classmethod + def from_callable( + cls: type[_CallbackT], + callback: Callable, + signature_from_callable: Callable[..., Any] | inspect.Signature | None = None, + add_params: Sequence[inspect.Parameter] | None = None, + return_type: Any = UNSET, + ) -> _CallbackT: + """ + Create a new Callback instance from a callable. The expected signature will be taken from the + signature_from_callable. You can add additional parameters or change the return type for the + expected signature. + """ + if signature_from_callable is None: + sig = inspect.Signature() + elif isinstance(signature_from_callable, inspect.Signature): + sig = signature_from_callable + else: + sig = inspect.signature(signature_from_callable) + if add_params is not None or return_type is not None: + params = list(sig.parameters.values()) + if add_params is not None: + params = [*add_params, *params] + if return_type is UNSET: + return_type = sig.return_annotation + sig = sig.replace(parameters=params, return_annotation=return_type) + return cls(callback, sig) + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: + """ + Call the callback with the given arguments and keyword arguments. The arguments will be checked against the + expected signature. If the callback does not match the expected signature, a TypeError explaining which + signature was expected will be raised. + """ + try: + filled_signature = self.actual_signature.bind(*args, **kwargs) + except TypeError: + # pylint: disable=raise-missing-from + # I decided to leave this out because the original exception is less helpful and spams the stack trace. + # Please read: https://docs.python.org/3/library/exceptions.html#BaseException.__suppress_context__ + raise TypeError( + f"Arguments do not match signature of callback {self.callback.__name__}{self.actual_signature_str}. " + f"Callback function must match signature: {self.callback.__name__}{self.expected_signature_str}" + ) from None + return self.callback(*filled_signature.args, **filled_signature.kwargs) + + +class ErrorCallback(Callback[_P, _T]): + """ + This class wraps an error callback. It is a subclass of Callback and adds the error parameter to the expected + signature. + """ + + _CALLBACK_ERROR_PARAM = inspect.Parameter("error", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Exception) + + @classmethod + def from_callable( + cls: type[_ErrorCallbackT], + callback: Callable, + signature_from_callable: Callable[..., Any] | inspect.Signature | None = None, + add_params: Sequence[inspect.Parameter] | None = None, + return_type: Any = UNSET, + ) -> _ErrorCallbackT: + if add_params is None: + add_params = [] + inst = cast( + _ErrorCallbackT, + super().from_callable( + callback, signature_from_callable, [cls._CALLBACK_ERROR_PARAM, *add_params], return_type + ), + ) + return inst + + +class SuccessCallback(Callback[_P, _T]): + """ + This class wraps a success callback. It is a subclass of Callback and adds the result parameter to the expected + signature. The annotation type is taken from the return annotation of the `signature_from_callable`. + """ + + @classmethod + def from_callable( + cls: type[_SuccessCallbackT], + callback: Callable, + signature_from_callable: Callable[..., Any] | inspect.Signature | None = None, + add_params: Sequence[inspect.Parameter] | None = None, + return_type: Any = UNSET, + ) -> _SuccessCallbackT: + inst = cast(_SuccessCallbackT, super().from_callable(callback, signature_from_callable, add_params)) + add_param = inspect.Parameter( + "result", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=inst.expected_signature.return_annotation + ) + if return_type is UNSET: + return_type = inst.expected_signature.return_annotation + inst.expected_signature = inst.expected_signature.replace( + parameters=[add_param, *inst.expected_signature.parameters.values()], + return_annotation=return_type, + ) + return inst diff --git a/src/error_handler/context_manager.py b/src/error_handler/context_manager.py index 23f5676..8c3c1f1 100644 --- a/src/error_handler/context_manager.py +++ b/src/error_handler/context_manager.py @@ -2,18 +2,22 @@ This module provides a context manager to handle errors in a convenient way. """ -from typing import Any, Callable, ContextManager +from contextlib import contextmanager +from typing import Any, Callable, Iterator +from .callback import Callback, ErrorCallback from .core import Catcher +from .types import UnsetType # pylint: disable=unsubscriptable-object +@contextmanager def context_manager( on_success: Callable[[], Any] | None = None, - on_error: Callable[[Exception], Any] | Callable[[], Any] | None = None, + on_error: Callable[[Exception], Any] | None = None, on_finalize: Callable[[], Any] | None = None, suppress_recalling_on_error: bool = True, -) -> ContextManager[Catcher[None]]: +) -> Iterator[Catcher[UnsetType]]: """ This context manager catches all errors inside the context and calls the corresponding callbacks. It is a shorthand for creating a Catcher instance and using its secure_context method. @@ -24,5 +28,12 @@ def context_manager( If suppress_recalling_on_error is True, the on_error callable will not be called if the error were already caught by a previous catcher. """ - catcher = Catcher[None](on_success, on_error, on_finalize, suppress_recalling_on_error=suppress_recalling_on_error) - return catcher.secure_context() + catcher = Catcher[UnsetType]( + Callback.from_callable(on_success, return_type=Any) if on_success is not None else None, + ErrorCallback.from_callable(on_error, return_type=Any) if on_error is not None else None, + Callback.from_callable(on_finalize, return_type=Any) if on_finalize is not None else None, + suppress_recalling_on_error=suppress_recalling_on_error, + ) + with catcher.secure_context(): + yield catcher + catcher.handle_result_and_call_callbacks(catcher.result) diff --git a/src/error_handler/core.py b/src/error_handler/core.py index 417aff8..1e8dd1c 100644 --- a/src/error_handler/core.py +++ b/src/error_handler/core.py @@ -7,45 +7,29 @@ # Seems like pylint doesn't like the new typing features. It has a problem with the generic T of class Catcher. import inspect from contextlib import contextmanager -from typing import Any, Awaitable, Callable, Generic, Iterator, ParamSpec, Self, TypeVar, overload - -from .types import ERRORED, UNSET, ErroredType, NegativeResult, PositiveResult, ResultType, T, UnsetType +from typing import Any, Awaitable, Callable, Generic, Iterator, ParamSpec, Self, TypeVar + +from .callback import Callback +from .result import ( + CallbackResultType, + CallbackResultTypes, + CallbackSummary, + NegativeResult, + PositiveResult, + ResultType, + ReturnValues, +) +from .types import ERRORED, UNSET, ErroredType, T, UnsetType _T = TypeVar("_T") _U = TypeVar("_U") _P = ParamSpec("_P") -@overload -def _call_callback( - callback: Callable[[_T], _U] | Callable[[], _U], - optional_arg: _T, - raise_if_arg_present: bool = False, -) -> _U: ... - - -@overload -def _call_callback( - callback: Callable[[], _U], - optional_arg: UnsetType = UNSET, - raise_if_arg_present: bool = False, -) -> _U: ... - - -def _call_callback( - callback: Callable[[_T], _U] | Callable[[], _U], - optional_arg: _T | UnsetType = UNSET, - raise_if_arg_present: bool = False, -) -> _U: - signature = inspect.signature(callback) - if len(signature.parameters) == 0: - return callback() # type: ignore[call-arg] - if len(signature.parameters) == 1 and raise_if_arg_present: - raise ValueError(f"Callback {callback.__name__} cannot receive arguments when using for a context manager.") - assert optional_arg is not UNSET, "Internal error: optional_arg is UNSET but should be." - return callback(optional_arg) # type: ignore[call-arg,arg-type] +_CALLBACK_ERROR_PARAM = inspect.Parameter("error", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=BaseException) +# pylint: disable=too-many-instance-attributes class Catcher(Generic[T]): """ After defining callbacks and other options for an instance, you can use the secure_call and secure_await methods @@ -56,11 +40,13 @@ class Catcher(Generic[T]): # pylint: disable=too-many-arguments def __init__( self, - on_success: Callable[[T], Any] | Callable[[], Any] | None = None, - on_error: Callable[[Exception], Any] | Callable[[], Any] | None = None, - on_finalize: Callable[[], Any] | None = None, + on_success: Callback | None = None, + on_error: Callback | None = None, + on_finalize: Callback | None = None, on_error_return_always: T | ErroredType = ERRORED, suppress_recalling_on_error: bool = True, + raise_callback_errors: bool = True, + no_wrap_exception_group_when_reraise: bool = True, ): self.on_success = on_success self.on_error = on_error @@ -73,8 +59,21 @@ def __init__( This is especially useful if you have nested catchers (e.g. due to nested context managers / function calls) which are re-raising the error. """ + self._result: ResultType[T] | None = None + self.raise_callback_errors = raise_callback_errors + self.no_wrap_exception_group_when_reraise = no_wrap_exception_group_when_reraise - def mark_exception(self, error: Exception) -> None: + @property + def result(self) -> ResultType[T]: + """ + This method returns the result of the last execution. If the catcher has not been executed yet, a ValueError + will be raised. + """ + if self._result is None: + raise ValueError("The catcher has not been executed yet.") + return self._result + + def _mark_exception(self, error: BaseException) -> None: """ This method marks the given exception as handled by the catcher. """ @@ -83,46 +82,113 @@ def mark_exception(self, error: Exception) -> None: error.__caught_by_catcher__.append(self) # type: ignore[attr-defined] @staticmethod - def _ensure_exception_in_cause_propagation(error_base: Exception, error_cause: Exception) -> None: - """ - This method ensures that the given error_cause is in the cause chain of the given error_base. - """ - if error_base is error_cause: + def _call_callback(callback: Callback | None, *args: Any, **kwargs: Any) -> tuple[Any, CallbackResultType]: + callback_result = CallbackResultType.SKIPPED + callback_return_value: Any = UNSET + if callback is not None: + try: + callback_return_value = callback(*args, **kwargs) + callback_result = CallbackResultType.SUCCESS + except BaseException as callback_error: # pylint: disable=broad-exception-caught + callback_return_value = callback_error + callback_result = CallbackResultType.ERROR + return callback_return_value, callback_result + + def _raise_callback_errors_if_set(self, result: CallbackSummary, raise_from: BaseException | None = None) -> None: + if not self.raise_callback_errors: return - if error_base.__cause__ is None: - error_base.__cause__ = error_cause - else: - assert isinstance(error_base.__cause__, Exception), "Internal error: __cause__ is not an Exception" - Catcher._ensure_exception_in_cause_propagation(error_base.__cause__, error_cause) - - def handle_error_case(self, error: Exception) -> ResultType[T]: + excs = [] + if result.callback_result_types.success == CallbackResultType.ERROR: + excs.append(result.callback_return_values.success) + if result.callback_result_types.error == CallbackResultType.ERROR: + excs.append(result.callback_return_values.error) + if result.callback_result_types.finalize == CallbackResultType.ERROR: + excs.append(result.callback_return_values.finalize) + + if self.no_wrap_exception_group_when_reraise and len(excs) == 1 and raise_from is excs[0]: + raise raise_from + if len(excs) > 0: + exc_group = BaseExceptionGroup("There were one or more errors while calling the callback functions.", excs) + if raise_from is not None: + exc_group.__context__ = raise_from + raise exc_group + + def _handle_error_callback(self, error: BaseException, *args: Any, **kwargs: Any) -> tuple[Any, CallbackResultType]: """ This method handles the given exception. """ + return_value = UNSET + result = CallbackResultType.SKIPPED caught_before = hasattr(error, "__caught_by_catcher__") - self.mark_exception(error) - if self.on_error is not None and not (caught_before and self.suppress_recalling_on_error): - try: - _call_callback(self.on_error, error) - except Exception as callback_error: # pylint: disable=broad-exception-caught - self._ensure_exception_in_cause_propagation(callback_error, error) - raise callback_error - return NegativeResult(error=error, result=self.on_error_return_always) + self._mark_exception(error) + if not (caught_before and self.suppress_recalling_on_error): + return_value, result = self._call_callback(self.on_error, error, *args, **kwargs) + if result == CallbackResultType.ERROR and return_value is error: + assert self.on_error is not None, "Internal error: on_error is None but result is ERROR" + error.add_note(f"This error was reraised by on_error callback {self.on_error.callback.__name__}") + + return return_value, result - def handle_success_case(self, result: T, raise_if_arg_present: bool = False) -> ResultType[T]: + def _handle_success_callback(self, *args: Any, **kwargs: Any) -> tuple[Any, CallbackResultType]: """ This method handles the given result. """ - if self.on_success is not None: - _call_callback(self.on_success, result, raise_if_arg_present=raise_if_arg_present) # type: ignore[arg-type] - return PositiveResult(result=result) + return self._call_callback(self.on_success, *args, **kwargs) - def handle_finalize_case(self) -> None: + def _handle_finalize_callback(self, *args: Any, **kwargs: Any) -> tuple[Any, CallbackResultType]: """ This method handles the finalize case. """ - if self.on_finalize is not None: - self.on_finalize() + return self._call_callback(self.on_finalize, *args, **kwargs) + + def handle_success_case(self, result: T | UnsetType, *args: Any, **kwargs: Any) -> CallbackSummary: + """ + This method handles the success case. + """ + if result is UNSET: + success_return_value, success_result = self._handle_success_callback(*args, **kwargs) + else: + success_return_value, success_result = self._handle_success_callback(result, *args, **kwargs) + finalize_return_value, finalize_result = self._handle_finalize_callback(*args, **kwargs) + callback_result = CallbackSummary( + callback_result_types=CallbackResultTypes( + success=success_result, + finalize=finalize_result, + ), + callback_return_values=ReturnValues( + success=success_return_value, + finalize=finalize_return_value, + ), + ) + self._raise_callback_errors_if_set(callback_result) + return callback_result + + def handle_error_case(self, error: BaseException, *args: Any, **kwargs: Any) -> CallbackSummary: + """ + This method handles the error case. + """ + error_return_value, error_result = self._handle_error_callback(error, *args, **kwargs) + finalize_return_value, finalize_result = self._handle_finalize_callback(*args, **kwargs) + callback_result = CallbackSummary( + callback_result_types=CallbackResultTypes( + error=error_result, + finalize=finalize_result, + ), + callback_return_values=ReturnValues( + error=error_return_value, + finalize=finalize_return_value, + ), + ) + self._raise_callback_errors_if_set(callback_result, error) + return callback_result + + def handle_result_and_call_callbacks(self, result: ResultType[T], *args: Any, **kwargs: Any) -> CallbackSummary: + """ + This method handles the last case. + """ + if isinstance(result, PositiveResult): + return self.handle_success_case(result.result, *args, **kwargs) + return self.handle_error_case(result.error, *args, **kwargs) def secure_call( # type: ignore[return] # Because mypy is stupid, idk. self, @@ -140,11 +206,10 @@ def secure_call( # type: ignore[return] # Because mypy is stupid, idk. """ try: result = callable_to_secure(*args, **kwargs) - return self.handle_success_case(result) - except Exception as error: # pylint: disable=broad-exception-caught - return self.handle_error_case(error) - finally: - self.handle_finalize_case() + self._result = PositiveResult(result=result) + except BaseException as error: # pylint: disable=broad-exception-caught + self._result = NegativeResult(error=error, result=self.on_error_return_always) + return self.result async def secure_await( # type: ignore[return] # Because mypy is stupid, idk. self, @@ -160,11 +225,10 @@ async def secure_await( # type: ignore[return] # Because mypy is stupid, idk. """ try: result = await awaitable_to_secure - return self.handle_success_case(result) - except Exception as error: # pylint: disable=broad-exception-caught - return self.handle_error_case(error) - finally: - self.handle_finalize_case() + self._result = PositiveResult(result=result) + except BaseException as error: # pylint: disable=broad-exception-caught + self._result = NegativeResult(error=error, result=self.on_error_return_always) + return self.result @contextmanager def secure_context(self) -> Iterator[Self]: @@ -179,8 +243,6 @@ def secure_context(self) -> Iterator[Self]: """ try: yield self - self.handle_success_case(None, raise_if_arg_present=True) # type: ignore[arg-type] - except Exception as error: # pylint: disable=broad-exception-caught - self.handle_error_case(error) - finally: - self.handle_finalize_case() + self._result = PositiveResult(result=UNSET) + except BaseException as error: # pylint: disable=broad-exception-caught + self._result = NegativeResult(error=error, result=self.on_error_return_always) diff --git a/src/error_handler/decorator.py b/src/error_handler/decorator.py index c283aa5..1cab875 100644 --- a/src/error_handler/decorator.py +++ b/src/error_handler/decorator.py @@ -4,27 +4,31 @@ import asyncio import functools +import inspect import logging import time -from typing import Any, Callable, TypeGuard +from typing import Any, Callable, Concatenate, Generator, ParamSpec, TypeGuard, TypeVar, cast +from .callback import Callback, ErrorCallback, SuccessCallback from .core import Catcher +from .result import CallbackResultType, PositiveResult, ResultType from .types import ( ERRORED, AsyncFunctionType, ErroredType, FunctionType, - NegativeResult, - P, - PositiveResult, - ResultType, SecuredAsyncFunctionType, SecuredFunctionType, - T, + UnsetType, ) +_P = ParamSpec("_P") +_T = TypeVar("_T") -def iscoroutinefunction(callable_: FunctionType[P, T] | AsyncFunctionType[P, T]) -> TypeGuard[AsyncFunctionType[P, T]]: + +def iscoroutinefunction( + callable_: FunctionType[_P, _T] | AsyncFunctionType[_P, _T] +) -> TypeGuard[AsyncFunctionType[_P, _T]]: """ This function checks if the given callable is a coroutine function. """ @@ -33,13 +37,13 @@ def iscoroutinefunction(callable_: FunctionType[P, T] | AsyncFunctionType[P, T]) # pylint: disable=too-many-arguments def decorator( - on_success: Callable[[T], Any] | Callable[[], Any] | None = None, - on_error: Callable[[Exception], Any] | Callable[[], Any] | None = None, - on_finalize: Callable[[], Any] | None = None, - on_error_return_always: T | ErroredType = ERRORED, + on_success: Callable[Concatenate[_T, _P], Any] | None = None, + on_error: Callable[Concatenate[Exception, _P], Any] | None = None, + on_finalize: Callable[_P, Any] | None = None, + on_error_return_always: _T | ErroredType = ERRORED, suppress_recalling_on_error: bool = True, ) -> Callable[ - [FunctionType[P, T] | AsyncFunctionType[P, T]], SecuredFunctionType[P, T] | SecuredAsyncFunctionType[P, T] + [FunctionType[_P, _T] | AsyncFunctionType[_P, _T]], SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T] ]: """ This decorator secures a callable (sync or async) and handles its errors. @@ -53,53 +57,62 @@ def decorator( caught by a previous catcher. """ # pylint: disable=unsubscriptable-object - catcher = Catcher[T](on_success, on_error, on_finalize, on_error_return_always, suppress_recalling_on_error) def decorator_inner( - callable_to_secure: FunctionType[P, T] | AsyncFunctionType[P, T] - ) -> SecuredFunctionType[P, T] | SecuredAsyncFunctionType[P, T]: + callable_to_secure: FunctionType[_P, _T] | AsyncFunctionType[_P, _T] + ) -> SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T]: + sig = inspect.signature(callable_to_secure) + catcher = Catcher[_T]( + SuccessCallback.from_callable(on_success, sig, return_type=Any) if on_success is not None else None, + ErrorCallback.from_callable(on_error, sig, return_type=Any) if on_error is not None else None, + Callback.from_callable(on_finalize, sig, return_type=Any) if on_finalize is not None else None, + on_error_return_always, + suppress_recalling_on_error, + ) if iscoroutinefunction(callable_to_secure): @functools.wraps(callable_to_secure) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | ErroredType: - return ( - await catcher.secure_await(callable_to_secure(*args, **kwargs)) - ).result # type: ignore[return-value] - - # Incompatible return value type (got "object", expected "T") [return-value] - # Seems like mypy isn't good enough for this. + async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T | ErroredType: + result = await catcher.secure_await(callable_to_secure(*args, **kwargs)) + catcher.handle_result_and_call_callbacks(result, *args, **kwargs) + assert not isinstance(result.result, UnsetType), "Internal error: result is unset" + return result.result else: @functools.wraps(callable_to_secure) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> T | ErroredType: - return catcher.secure_call( # type: ignore[return-value] + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T | ErroredType: + result = catcher.secure_call( callable_to_secure, # type: ignore[arg-type] *args, **kwargs, - ).result - # Incompatible return value type (got "object", expected "T") [return-value] - # Seems like mypy isn't good enough for this. + ) + catcher.handle_result_and_call_callbacks(result, *args, **kwargs) + assert not isinstance(result.result, UnsetType), "Internal error: result is unset" + return result.result - wrapper.__catcher__ = catcher # type: ignore[attr-defined] - wrapper.__original_callable__ = callable_to_secure # type: ignore[attr-defined] - return wrapper + return_func = cast(SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T], wrapper) + return_func.__catcher__ = catcher + return_func.__original_callable__ = callable_to_secure + return return_func return decorator_inner -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, too-many-locals def retry_on_error( - on_error: Callable[[Exception, int], bool], + on_error: Callable[Concatenate[Exception, int, _P], bool], retry_stepping_func: Callable[[int], float] = lambda retry_count: 1.71**retry_count, # <-- with max_retries = 10 the whole decorator may wait up to 5 minutes. # because sum(1.71seconds**i for i in range(10)) == 5minutes max_retries: int = 10, - on_success: Callable[[T, int], Any] | None = None, - on_fail: Callable[[Exception, int], Any] | None = None, - on_finalize: Callable[[int], Any] | None = None, + on_success: Callable[Concatenate[_T, int, _P], Any] | None = None, + on_fail: Callable[Concatenate[Exception, int, _P], Any] | None = None, + on_finalize: Callable[Concatenate[int, _P], Any] | None = None, logger: logging.Logger = logging.getLogger(__name__), -) -> Callable[[FunctionType[P, T]], FunctionType[P, T]]: +) -> Callable[ + [FunctionType[_P, _T] | AsyncFunctionType[_P, _T]], SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T] +]: """ This decorator retries a callable (sync or async) on error. The retry_stepping_func is called with the retry count and should return the time to wait until the next retry. @@ -110,50 +123,96 @@ def retry_on_error( In this case, the on_fail callback will be called and the respective error will be raised. You can additionally use the normal decorator on top of that if you don't want an exception to be raised. """ - # pylint: disable=unsubscriptable-object - catcher = Catcher[T]() - - def handle_result(result: ResultType[T], retry_count: int) -> bool: - if isinstance(result, NegativeResult): - if not on_error(result.error, retry_count): - if on_fail is not None: - on_fail(result.error, retry_count) - if on_finalize is not None: - on_finalize(retry_count) - raise result.error - return True - if on_success is not None: - on_success(result.result, retry_count) - if on_finalize is not None: - on_finalize(retry_count) - return False - - def too_many_retries_error_handler(callback_name: str, max_retries: int) -> Exception: - too_many_retries_error = RuntimeError(f"Too many retries ({max_retries}) for {callback_name}") - if on_fail is not None: - on_fail(too_many_retries_error, max_retries) - if on_finalize is not None: - on_finalize(max_retries) - return too_many_retries_error def decorator_inner( - callable_to_secure: FunctionType[P, T] | AsyncFunctionType[P, T] - ) -> SecuredFunctionType[P, T] | SecuredAsyncFunctionType[P, T]: + callable_to_secure: FunctionType[_P, _T] | AsyncFunctionType[_P, _T] + ) -> SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T]: + sig = inspect.signature(callable_to_secure) + sig = sig.replace( + parameters=[ + inspect.Parameter("retries", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int), + *sig.parameters.values(), + ], + ) + on_error_callback: ErrorCallback[Concatenate[Exception, int, _P], bool] = ErrorCallback.from_callable( + on_error, sig, return_type=bool + ) + on_success_callback: SuccessCallback[Concatenate[_T, int, _P], Any] | None = ( + SuccessCallback.from_callable(on_success, sig, return_type=Any) if on_success is not None else None + ) + on_fail_callback: ErrorCallback[Concatenate[Exception, int, _P], Any] | None = ( + ErrorCallback.from_callable(on_fail, sig, return_type=Any) if on_fail is not None else None + ) + on_finalize_callback: Callback[Concatenate[int, _P], Any] | None = ( + Callback.from_callable(on_finalize, sig, return_type=Any) if on_finalize is not None else None + ) + + # pylint: disable=unsubscriptable-object + catcher_executor = Catcher[_T](on_error=on_error_callback) + catcher_retrier = Catcher[_T]( + on_success=on_success_callback, + on_error=on_fail_callback, + on_finalize=on_finalize_callback, + suppress_recalling_on_error=False, + ) + retry_count = 0 + + def retry_generator(*args: _P.args, **kwargs: _P.kwargs) -> Generator[int, ResultType[_T], _T]: + nonlocal retry_count + for retry_count_i in range(max_retries): + result: ResultType[_T] = yield retry_count_i + retry_count = retry_count_i + if isinstance(result, PositiveResult): + assert not isinstance(result.result, UnsetType), "Internal error: result is unset" + return result.result + callback_summary = catcher_executor.handle_result_and_call_callbacks( + result, retry_count_i, *args, **kwargs + ) + assert ( + callback_summary.callback_result_types.error == CallbackResultType.SUCCESS + ), "Internal error: on_error callback was not successful but didn't raise exception" + if callback_summary.callback_return_values.error is True: + yield retry_count_i + continue + # Should not retry + raise result.error + + retry_count = max_retries + error = RuntimeError(f"Too many retries ({max_retries}) for {callable_to_secure.__name__}") + raise error + + def handle_result_and_call_callbacks(result: ResultType[_T], *args: _P.args, **kwargs: _P.kwargs) -> _T: + if isinstance(result, PositiveResult): + assert not isinstance(result.result, UnsetType), "Internal error: result is unset" + catcher_retrier.handle_success_case( + result.result, + retry_count, + *args, + **kwargs, + ) + return result.result + + catcher_retrier.handle_error_case(result.error, retry_count, *args, **kwargs) + raise result.error + if iscoroutinefunction(callable_to_secure): - @functools.wraps(callable_to_secure) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - for retry_count in range(max_retries): - result = await catcher.secure_await(callable_to_secure(*args, **kwargs)) - if handle_result(result, retry_count): - # Should retry - await asyncio.sleep(retry_stepping_func(retry_count)) - continue - # Should not retry because the result is positive - assert isinstance(result, PositiveResult), "Internal error: NegativeResult was not handled properly" - return result.result + async def retry_function_async(*args: _P.args, **kwargs: _P.kwargs) -> _T: + generator = retry_generator(*args, **kwargs) + while True: + next(generator) + try: + retry_count_ = generator.send( + await catcher_executor.secure_await(callable_to_secure(*args, **kwargs)) + ) + except StopIteration as stop_iteration: + return stop_iteration.value + await asyncio.sleep(retry_stepping_func(retry_count_)) - raise too_many_retries_error_handler(callable_to_secure.__name__, max_retries) + @functools.wraps(callable_to_secure) + async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + result = await catcher_retrier.secure_await(retry_function_async(*args, **kwargs)) + return handle_result_and_call_callbacks(result, *args, **kwargs) else: logger.warning( @@ -162,27 +221,26 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: "Please consider decorating an async function instead." ) - @functools.wraps(callable_to_secure) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - for retry_count in range(max_retries): - result = catcher.secure_call(callable_to_secure, *args, **kwargs) # type: ignore[arg-type] - if handle_result(result, retry_count): - # Should retry - time.sleep(retry_stepping_func(retry_count)) - continue - # Should not retry because the result is positive - assert isinstance(result, PositiveResult), "Internal error: NegativeResult was not handled properly" - return result.result + def retry_function_sync(*args: _P.args, **kwargs: _P.kwargs) -> _T: + generator = retry_generator(*args, **kwargs) + while True: + next(generator) + try: + retry_count_ = generator.send( + catcher_executor.secure_call(callable_to_secure, *args, **kwargs) # type: ignore[arg-type] + ) + except StopIteration as stop_iteration: + return stop_iteration.value + time.sleep(retry_stepping_func(retry_count_)) - too_many_retries_error = RuntimeError( - f"Too many retries ({max_retries}) for {callable_to_secure.__name__}" - ) - if on_fail is not None: - on_fail(too_many_retries_error, max_retries) - raise too_many_retries_error + @functools.wraps(callable_to_secure) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: + result = catcher_retrier.secure_call(retry_function_sync, *args, **kwargs) + return handle_result_and_call_callbacks(result, *args, **kwargs) - wrapper.__catcher__ = catcher # type: ignore[attr-defined] - wrapper.__original_callable__ = callable_to_secure # type: ignore[attr-defined] - return wrapper + return_func = cast(SecuredFunctionType[_P, _T] | SecuredAsyncFunctionType[_P, _T], wrapper) + return_func.__catcher__ = catcher_retrier + return_func.__original_callable__ = callable_to_secure + return return_func - return decorator_inner # type: ignore[return-value] + return decorator_inner diff --git a/src/error_handler/result.py b/src/error_handler/result.py new file mode 100644 index 0000000..5d64528 --- /dev/null +++ b/src/error_handler/result.py @@ -0,0 +1,75 @@ +""" +This module contains classes to encapsulate some information about the result of a secured context and its callbacks. +""" + +from dataclasses import dataclass +from enum import StrEnum +from typing import Any, Generic, TypeAlias + +from error_handler.types import UNSET, ErroredType, T, UnsetType + + +class CallbackResultType(StrEnum): + """ + Determines whether a callback was successful or errored or not called aka skipped. + """ + + SUCCESS = "success" + ERROR = "error" + SKIPPED = "skipped" + + +@dataclass(frozen=True) +class CallbackResultTypes: + """ + Contains the information for each callback whether a callback was successful or errored or not called aka skipped. + """ + + success: CallbackResultType = CallbackResultType.SKIPPED + error: CallbackResultType = CallbackResultType.SKIPPED + finalize: CallbackResultType = CallbackResultType.SKIPPED + + +@dataclass(frozen=True) +class ReturnValues: + """ + Contains the return values of each callback. + If a callback was not called, the value is UNSET. + If a callback errored, the value is the error. + """ + + success: Any = UNSET + error: Any = UNSET + finalize: Any = UNSET + + +@dataclass(frozen=True) +class CallbackSummary: + """ + Contains the information of the result of a secured context and its callbacks. + """ + + callback_result_types: CallbackResultTypes + callback_return_values: ReturnValues + + +@dataclass(frozen=True) +class PositiveResult(Generic[T]): + """ + Represents a successful result. + """ + + result: T | UnsetType + + +@dataclass(frozen=True) +class NegativeResult(Generic[T]): + """ + Represents an errored result. + """ + + result: T | ErroredType + error: BaseException + + +ResultType: TypeAlias = PositiveResult[T] | NegativeResult[T] diff --git a/src/error_handler/stream.py b/src/error_handler/stream.py index 462a772..f8cf790 100644 --- a/src/error_handler/stream.py +++ b/src/error_handler/stream.py @@ -8,23 +8,28 @@ from ._extra import IS_AIOSTREAM_INSTALLED from .decorator import decorator -from .types import ERRORED +from .types import ERRORED, AsyncFunctionType, FunctionType, SecuredAsyncFunctionType, SecuredFunctionType, is_secured if IS_AIOSTREAM_INSTALLED: import aiostream - from aiostream.stream.combine import MapCallable, T, U + from aiostream.stream.combine import T, U # pylint: disable=too-many-arguments, redefined-builtin @aiostream.pipable_operator def map( source: AsyncIterable[T], - func: MapCallable[T, U], + func: ( + FunctionType[[T], U] + | AsyncFunctionType[[T], U] + | SecuredFunctionType[[T], U] + | SecuredAsyncFunctionType[[T], U] + ), *more_sources: AsyncIterable[T], ordered: bool = True, task_limit: int | None = None, - on_success: Callable[[U], Any] | Callable[[], Any] | None = None, - on_error: Callable[[Exception], Any] | Callable[[], Any] | None = None, - on_finalize: Callable[[], Any] | None = None, + on_success: Callable[[U, T], Any] | None = None, + on_error: Callable[[Exception, T], Any] | None = None, + on_finalize: Callable[[T], Any] | None = None, wrap_secured_function: bool = False, suppress_recalling_on_error: bool = True, logger: logging.Logger = logging.getLogger(__name__), @@ -35,7 +40,7 @@ def map( If suppress_recalling_on_error is True, the on_error callable will not be called if the error were already caught by a previous catcher. """ - if not wrap_secured_function and hasattr(func, "__catcher__"): + if not wrap_secured_function and is_secured(func): if func.__catcher__.on_error_return_always is not ERRORED: raise ValueError( "The given function is already secured but does not return ERRORED in error case. " @@ -52,9 +57,6 @@ def map( "Please do not set on_success, on_error, on_finalize as they would be ignored. " "You can set wrap_secured_function=True to wrap the secured function with another catcher." ) - assert hasattr( - func, "__original_callable__" - ), "Internal error: The secured function has no __original_callable__ but __catcher__ defined" logger.debug( f"The given function {func.__original_callable__.__name__} is already secured. Using it as is." ) @@ -69,10 +71,10 @@ def map( )( func # type: ignore[arg-type] ) - # mypy complains because for mypy Callable[P, A | B] is not a subtype of Callable[P, A] | Callable[P, B]. - # Which is kinda true but in practice this is equivalent. So just ignore this. - next_source = aiostream.stream.map.raw( - source, secured_func, *more_sources, ordered=ordered, task_limit=task_limit + # Ignore that T | ErroredType is not compatible with T. All ErroredType results are filtered out + # in a subsequent step. + next_source: AsyncIterator[U] = aiostream.stream.map.raw( + source, secured_func, *more_sources, ordered=ordered, task_limit=task_limit # type: ignore[arg-type] ) next_source = aiostream.stream.filter.raw(next_source, lambda result: result is not ERRORED) return next_source diff --git a/src/error_handler/types.py b/src/error_handler/types.py index 9236df9..0a40eaa 100644 --- a/src/error_handler/types.py +++ b/src/error_handler/types.py @@ -3,15 +3,16 @@ """ import inspect -from abc import ABC -from dataclasses import dataclass -from typing import Awaitable, Callable, Generic, ParamSpec, TypeAlias, TypeVar +from typing import TYPE_CHECKING, Awaitable, Callable, ParamSpec, Protocol, TypeAlias, TypeGuard, TypeVar + +if TYPE_CHECKING: + from .core import Catcher T = TypeVar("T") P = ParamSpec("P") -class Singleton(type): +class SingletonMeta(type): """ A metaclass implementing the singleton pattern. """ @@ -51,7 +52,7 @@ def __singleton_new__(cls, *args, **kwargs): # pylint: disable=too-few-public-methods -class ErroredType(metaclass=Singleton): +class ErroredType(metaclass=SingletonMeta): """ This type is meant to be used as singleton. Do not instantiate it on your own. The instance below represents an errored result. @@ -59,7 +60,7 @@ class ErroredType(metaclass=Singleton): # pylint: disable=too-few-public-methods -class UnsetType(metaclass=Singleton): +class UnsetType(metaclass=SingletonMeta): """ This type is meant to be used as singleton. Do not instantiate it on your own. The instance below represents an unset value. It is needed as default value since the respective @@ -67,44 +68,56 @@ class UnsetType(metaclass=Singleton): """ -@dataclass -class Result(Generic[T], ABC): - """ - Represents a result of a function call. - """ +UNSET = UnsetType() +""" +Represents an unset value. It is used as default value for parameters that can be of any type. +""" +ERRORED = ErroredType() +""" +Represents an errored result. It is used to be able to return something in error cases. See Catcher.secure_call +for more information. +""" + +FunctionType: TypeAlias = Callable[P, T] +AsyncFunctionType: TypeAlias = Callable[P, Awaitable[T]] -@dataclass -class PositiveResult(Result[T]): + +class SecuredFunctionType(Protocol[P, T]): """ - Represents a successful result. + This type represents a secured function. """ - result: T + __catcher__: "Catcher[T]" + __original_callable__: FunctionType[P, T] + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T | ErroredType: ... -@dataclass -class NegativeResult(Result[T]): +class SecuredAsyncFunctionType(Protocol[P, T]): """ - Represents an errored result. + This type represents a secured async function. """ - result: T | ErroredType - error: Exception + __catcher__: "Catcher[T]" + __original_callable__: AsyncFunctionType[P, T] + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T | ErroredType]: ... -UNSET = UnsetType() -""" -Represents an unset value. It is used as default value for parameters that can be of any type. -""" -ERRORED = ErroredType() -""" -Represents an errored result. It is used to be able to return something in error cases. See Catcher.secure_call -for more information. -""" -FunctionType: TypeAlias = Callable[P, T] -AsyncFunctionType: TypeAlias = Callable[P, Awaitable[T]] -SecuredFunctionType: TypeAlias = Callable[P, T | ErroredType] -SecuredAsyncFunctionType: TypeAlias = Callable[P, Awaitable[T | ErroredType]] -ResultType: TypeAlias = PositiveResult[T] | NegativeResult[T] +def is_secured( + func: FunctionType[P, T] | SecuredFunctionType[P, T] | AsyncFunctionType[P, T] | SecuredAsyncFunctionType[P, T] +) -> TypeGuard[SecuredFunctionType[P, T] | SecuredAsyncFunctionType[P, T]]: + """ + Returns True if the given function is secured, False otherwise. + """ + return hasattr(func, "__catcher__") and hasattr(func, "__original_callable__") + + +def is_unsecured( + func: FunctionType[P, T] | AsyncFunctionType[P, T] | SecuredFunctionType[P, T] | SecuredAsyncFunctionType[P, T] +) -> TypeGuard[FunctionType[P, T] | AsyncFunctionType[P, T]]: + """ + Returns True if the given function is not secured, False otherwise. + """ + return not hasattr(func, "__catcher__") or not hasattr(func, "__original_callable__") diff --git a/unittests/test_callback_errors.py b/unittests/test_callback_errors.py new file mode 100644 index 0000000..5b6f351 --- /dev/null +++ b/unittests/test_callback_errors.py @@ -0,0 +1,60 @@ +import pytest + +import error_handler + +from .utils import assert_not_called, create_callback_tracker + + +class TestCallbackErrors: + def test_decorator_callbacks_wrong_signature_call_all_callbacks(self): + on_success_callback, success_tracker = create_callback_tracker() + + def on_finalize_wrong_signature(): + pass + + @error_handler.decorator( + on_success=on_success_callback, on_error=assert_not_called, on_finalize=on_finalize_wrong_signature + ) + def func(hello: str) -> str: + return f"Hello {hello}" + + with pytest.raises(BaseExceptionGroup) as error: + func("World!") + + assert len(error.value.exceptions) == 1 + assert isinstance(error.value.exceptions[0], TypeError) + assert "Arguments do not match signature of callback" in str(error.value.exceptions[0]) + assert "on_finalize_wrong_signature()" in str(error.value.exceptions[0]) + assert "on_finalize_wrong_signature(hello: str) -> Any" in str(error.value.exceptions[0]) + assert success_tracker == [(("Hello World!", "World!"), {})] + + def test_decorator_callbacks_wrong_signature_and_unexpected_error(self): + + def on_finalize_wrong_signature(): + pass + + def on_error_callback(_: BaseException, __: str): + raise ValueError("This is a test error") + + @error_handler.decorator( + on_success=assert_not_called, on_error=on_error_callback, on_finalize=on_finalize_wrong_signature + ) + def func(hello: str) -> str: + raise ValueError(f"This is a test error {hello}") + + with pytest.raises(BaseExceptionGroup) as error: + func("World!") + + assert len(error.value.exceptions) == 2 + if isinstance(error.value.exceptions[0], ValueError): + value_error, wrong_signature_error = error.value.exceptions + else: + wrong_signature_error, value_error = error.value.exceptions + + assert isinstance(wrong_signature_error, TypeError) + assert "Arguments do not match signature of callback" in str(wrong_signature_error) + assert "on_finalize_wrong_signature()" in str(wrong_signature_error) + assert "on_finalize_wrong_signature(hello: str) -> Any" in str(wrong_signature_error) + assert isinstance(value_error, ValueError) + assert str(value_error) == "This is a test error" + assert "This is a test error World!" in str(error.value.__context__) diff --git a/unittests/test_complex_use_case.py b/unittests/test_complex_use_case.py index 34b124d..d74c2eb 100644 --- a/unittests/test_complex_use_case.py +++ b/unittests/test_complex_use_case.py @@ -1,4 +1,5 @@ import logging +import sys import aiostream @@ -7,11 +8,13 @@ class TestComplexExample: async def test_complex_use_case(self): + logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) + logger = logging.root op = aiostream.stream.iterate(range(10)) - def log_error(error: Exception): + def log_error(error: Exception, num: int): """Only log error and reraise it""" - logging.error(error) + logger.error("double_only_odd_nums_except_5 failed for input %d. ", num) raise error @error_handler.decorator(on_error=log_error) @@ -24,12 +27,12 @@ async def double_only_odd_nums_except_5(num: int) -> int: num *= 2 return num - def catch_value_errors(error: Exception): + def catch_value_errors(error: Exception, _: int): if not isinstance(error, ValueError): raise error - def log_success(num: int): - logging.info("Success: %s", num) + def log_success(result_num: int, provided_num: int): + logger.info("Success: %d -> %d", provided_num, result_num) op = op | error_handler.pipe.map( double_only_odd_nums_except_5, diff --git a/unittests/test_concurrency.py b/unittests/test_concurrency.py new file mode 100644 index 0000000..8fb477b --- /dev/null +++ b/unittests/test_concurrency.py @@ -0,0 +1,75 @@ +import asyncio +from math import floor + +import error_handler + +from .utils import assert_not_called, create_callback_tracker + + +class TestConcurrency: + async def test_concurrency_decorator_callback_order(self): + counter = 0 + counter_list = [] + + def log_counter(_: str, __: str): + nonlocal counter + counter_list.append(counter) + + on_success_callback, success_tracker = create_callback_tracker(log_counter) + on_finalize_callback, finalize_tracker = create_callback_tracker() + + @error_handler.decorator( + on_success=on_success_callback, + on_finalize=on_finalize_callback, + on_error=assert_not_called, + ) + async def async_function(hello: str) -> str: + nonlocal counter + counter += 1 + await asyncio.sleep(0.1) + return f"Hello {hello}" + + results = await asyncio.gather(async_function("World!"), async_function("world...")) + assert set(results) == {"Hello World!", "Hello world..."} + assert success_tracker == [(("Hello World!", "World!"), {}), (("Hello world...", "world..."), {})] + assert finalize_tracker == [(("World!",), {}), (("world...",), {})] + assert counter_list == [2, 2] + + async def test_concurrency_retry(self): + retry_counter = 0 + + def error_callback(_: Exception, retry_count: int, __: str) -> bool: + nonlocal retry_counter + assert retry_count == floor(retry_counter / 2) + retry_counter += 1 + return True + + error_callback, error_tracker = create_callback_tracker(additional_callback=error_callback) + success_callback, success_tracker = create_callback_tracker() + finalize_callback, finalize_tracker = create_callback_tracker() + + @error_handler.retry_on_error( + on_error=error_callback, + on_success=success_callback, + on_finalize=finalize_callback, + on_fail=assert_not_called, + retry_stepping_func=lambda _: 0.01, + ) + async def async_function(hello: str) -> str: + nonlocal retry_counter + if hello == "world...": + # "guarantee" that the first raise is the one for "World!" + await asyncio.sleep(0.005) + await asyncio.sleep(0.02) + if retry_counter < 4: + raise ValueError(retry_counter) + return f"Hello {hello}" + + results = await asyncio.gather(async_function("World!"), async_function("world...")) + assert set(results) == {"Hello World!", "Hello world..."} + assert len(error_tracker) == 4 + error_tracker_args = [(error.args[0], num, hello) for (error, num, hello), _ in error_tracker] + assert set(error_tracker_args[:2]) == {(0, 0, "World!"), (1, 0, "world...")} + assert set(error_tracker_args[2:]) == {(2, 1, "World!"), (3, 1, "world...")} + assert success_tracker == [(("Hello World!", 2, "World!"), {}), (("Hello world...", 2, "world..."), {})] + assert finalize_tracker == [((2, "World!"), {}), ((2, "world..."), {})] diff --git a/unittests/test_context_manager.py b/unittests/test_context_manager.py index 2a1ed09..c35451e 100644 --- a/unittests/test_context_manager.py +++ b/unittests/test_context_manager.py @@ -2,6 +2,8 @@ import error_handler +from .utils import assert_not_called + class TestErrorHandlerContextManager: def test_context_manager_error_case(self): @@ -26,7 +28,7 @@ def succeeded_callback(): nonlocal succeeded succeeded = True - with error_handler.context_manager(on_success=succeeded_callback): + with error_handler.context_manager(on_success=succeeded_callback, on_error=assert_not_called): pass assert succeeded @@ -40,9 +42,7 @@ def store_error(error: Exception): raise error with pytest.raises(ValueError) as error: - with error_handler.context_manager( - on_error=store_error, - ): + with error_handler.context_manager(on_error=store_error): raise ValueError("This is a test error world") assert isinstance(catched_error, ValueError) diff --git a/unittests/test_decorator.py b/unittests/test_decorator.py index 6024625..0114124 100644 --- a/unittests/test_decorator.py +++ b/unittests/test_decorator.py @@ -1,30 +1,8 @@ -from typing import Any, Callable - import pytest import error_handler - -def assert_not_called(*args, **kwargs): - raise ValueError("This should not be called") - - -def create_callback_tracker( - additional_callback: Callable = lambda *args: None, -) -> tuple[Callable, list[tuple[Any, ...]]]: - """ - Creates a callback function taking any arguments and a tracker which stores all calls to the callback in a list. - The callback function will call the additional_callback with the same arguments and return its return value. - It will also store the arguments as tuple in the call_args list. - Returns the callback and the call_args list. - """ - call_args = [] - - def callback(*args): - call_args.append(args) - return additional_callback(*args) - - return callback, call_args +from .utils import assert_not_called, create_callback_tracker def retry_stepping_func(_: int) -> float: @@ -48,9 +26,10 @@ async def async_function(hello: str) -> None: awaitable = async_function("world") result = await awaitable - assert str(error_tracker[0][0]) == "This is a test error world" + assert str(error_tracker[0][0][0]) == "This is a test error world" + assert str(error_tracker[0][0][1]) == "world" assert result == error_handler.ERRORED - assert finalize_tracker == [()] + assert finalize_tracker == [(("world",), {})] async def test_decorator_coroutine_success_case(self): on_success_callback, success_tracker = create_callback_tracker() @@ -66,45 +45,38 @@ async def async_function(hello: str) -> str: awaitable = async_function("World!") result = await awaitable - assert result == success_tracker[0][0] == "Hello World!" - assert finalize_tracker == [()] + assert result == success_tracker[0][0][0] == "Hello World!" + assert success_tracker[0][0][1] == "World!" + assert finalize_tracker == [(("World!",), {})] def test_decorator_function_error_case(self): - catched_error: Exception | None = None - - def store_error(error: Exception): - nonlocal catched_error - catched_error = error + error_callback, error_tracker = create_callback_tracker() - @error_handler.decorator(on_error=store_error) + @error_handler.decorator(on_error=error_callback, on_success=assert_not_called) def func(hello: str) -> None: raise ValueError(f"This is a test error {hello}") result = func("world") - assert isinstance(catched_error, ValueError) - assert str(catched_error) == "This is a test error world" + assert isinstance(error_tracker[0][0][0], ValueError) + assert str(error_tracker[0][0][0]) == "This is a test error world" + assert error_tracker[0][0][1] == "world" assert result == error_handler.ERRORED def test_decorator_function_success_case(self): - return_value: str | None = None - - def store_return_value(value: str): - nonlocal return_value - return_value = value + on_success_callback, success_tracker = create_callback_tracker() - @error_handler.decorator( - on_success=store_return_value, - ) - def async_function(hello: str) -> str: + @error_handler.decorator(on_success=on_success_callback, on_error=assert_not_called) + def func(hello: str) -> str: return f"Hello {hello}" - result = async_function("World!") - assert result == return_value == "Hello World!" + result = func("World!") + assert result == success_tracker[0][0][0] == "Hello World!" + assert success_tracker[0][0][1] == "World!" async def test_decorator_reraise_coroutine(self): catched_error: Exception | None = None - def store_error(error: Exception): + def store_error(error: Exception, _: str): nonlocal catched_error catched_error = error raise error @@ -124,7 +96,7 @@ async def async_function(hello: str) -> None: def test_decorator_reraise_function(self): catched_error: Exception | None = None - def store_error(error: Exception): + def store_error(error: Exception, _: str): nonlocal catched_error catched_error = error raise error @@ -143,7 +115,7 @@ def func(hello: str) -> None: async def test_retry_coroutine_return_after_retries(self): retry_counter = 0 - def error_callback(_: Exception, retry_count: int) -> bool: + def error_callback(_: Exception, retry_count: int, __: str) -> bool: nonlocal retry_counter assert retry_count == retry_counter retry_counter += 1 @@ -170,10 +142,10 @@ async def async_function(hello: str) -> str: result = await awaitable assert retry_counter == 2 - assert [error.args[0] for error, _ in error_tracker] == [0, 1] + assert [error.args[0] for (error, _, __), ___ in error_tracker] == [0, 1] assert result == "Hello world" - assert success_tracker == [("Hello world", 2)] - assert finalize_tracker == [(2,)] + assert success_tracker == [(("Hello world", 2, "world"), {})] + assert finalize_tracker == [((2, "world"), {})] async def test_retry_coroutine_return_without_retries(self): @error_handler.retry_on_error(on_error=assert_not_called, retry_stepping_func=retry_stepping_func) @@ -188,7 +160,7 @@ async def async_function(hello: str) -> str: async def test_retry_coroutine_fail_too_many_retries(self): retry_counter = 0 - def error_callback(_: Exception, retry_count: int) -> bool: + def error_callback(_: Exception, retry_count: int, __: str) -> bool: nonlocal retry_counter assert retry_count == retry_counter retry_counter += 1 @@ -214,14 +186,14 @@ async def async_function(_: str) -> str: _ = await awaitable assert str(error.value) == "Too many retries (2) for async_function" - assert [error.args[0] for error, _ in error_tracker] == [0, 1] - assert fail_tracker == [(error.value, 2)] - assert finalize_tracker == [(2,)] + assert [error.args[0] for (error, _, __), ___ in error_tracker] == [0, 1] + assert fail_tracker == [((error.value, 2, "world"), {})] + assert finalize_tracker == [((2, "world"), {})] async def test_retry_coroutine_fail_callback_returns_false(self): retry_counter = 0 - def error_callback(_: Exception, retry_count: int) -> bool: + def error_callback(_: Exception, retry_count: int, __: str) -> bool: nonlocal retry_counter assert retry_count == retry_counter if retry_counter >= 2: @@ -248,14 +220,14 @@ async def async_function(_: str) -> str: _ = await awaitable assert error.value.args[0] == 2 - assert [error.args[0] for error, _ in error_tracker] == [0, 1, 2] - assert fail_tracker == [(error.value, 2)] - assert finalize_tracker == [(2,)] + assert [error.args[0] for (error, _, __), ___ in error_tracker] == [0, 1, 2] + assert fail_tracker == [((error.value, 2, "world"), {})] + assert finalize_tracker == [((2, "world"), {})] def test_retry_function_return_after_retries(self): retry_counter = 0 - def error_callback(_: Exception, retry_count: int) -> bool: + def error_callback(_: Exception, retry_count: int, __: str) -> bool: nonlocal retry_counter assert retry_count == retry_counter retry_counter += 1 @@ -272,16 +244,16 @@ def error_callback(_: Exception, retry_count: int) -> bool: on_fail=assert_not_called, retry_stepping_func=retry_stepping_func, ) - def sync_function(hello: str) -> str: + def func(hello: str) -> str: nonlocal retry_counter if retry_counter < 2: raise ValueError(retry_counter) return f"Hello {hello}" - result = sync_function("world") + result = func("world") assert retry_counter == 2 - assert [error.args[0] for error, _ in error_tracker] == [0, 1] + assert [error.args[0] for (error, _, __), ___ in error_tracker] == [0, 1] assert result == "Hello world" - assert success_tracker == [("Hello world", 2)] - assert finalize_tracker == [(2,)] + assert success_tracker == [(("Hello world", 2, "world"), {})] + assert finalize_tracker == [((2, "world"), {})] diff --git a/unittests/test_pipable_operators.py b/unittests/test_pipable_operators.py index 19d4aa5..f08e2de 100644 --- a/unittests/test_pipable_operators.py +++ b/unittests/test_pipable_operators.py @@ -3,6 +3,8 @@ import error_handler +from .utils import create_callback_tracker + class TestErrorHandlerPipableOperators: def test_aiostream_import_error(self, trigger_aiostream_import_error): @@ -51,7 +53,7 @@ def raise_for_even(num: int) -> int: raise ValueError(f"{num}") return num - def store(error: Exception): + def store(error: Exception, _: int): nonlocal errored_nums errored_nums.add(int(str(error))) @@ -65,11 +67,11 @@ async def test_secured_map_stream_double_secure_invalid_return_value(self): op = stream.iterate([1, 2, 3, 4, 5, 6]) @error_handler.decorator(on_error_return_always=0) - def raise_for_even(_: int) -> int: + def return_1(_: int) -> int: return 1 with pytest.raises(ValueError) as error: - _ = error_handler.stream.map(op, raise_for_even) + _ = error_handler.stream.map(op, return_1) assert "The given function is already secured but does not return ERRORED in error case" in str(error.value) @@ -77,44 +79,45 @@ async def test_secured_map_stream_double_secure_invalid_arguments(self): op = stream.iterate([1, 2, 3, 4, 5, 6]) @error_handler.decorator() - def raise_for_even(_: int) -> int: + def return_1(_: int) -> int: return 1 with pytest.raises(ValueError) as error: - _ = error_handler.stream.map(op, raise_for_even, on_error=lambda _: None) + _ = error_handler.stream.map(op, return_1, on_error=lambda _: None) assert "Please do not set on_success, on_error, on_finalize as they would be ignored" in str(error.value) async def test_secured_map_stream_double_secure_no_wrap(self): - errored_nums: set[int] = set() - op = stream.iterate([1, 2, 3, 4, 5, 6]) + error_callback, error_tracker = create_callback_tracker() + success_callback, success_tracker = create_callback_tracker() - def store(error: Exception): - nonlocal errored_nums - errored_nums.add(int(str(error))) + op = stream.iterate([1, 2, 3, 4, 5, 6]) - @error_handler.decorator(on_error=store) + @error_handler.decorator(on_error=error_callback, on_success=success_callback) def raise_for_even(num: int) -> int: if num % 2 == 0: - raise ValueError(f"{num}") + raise ValueError(num) return num op = error_handler.stream.map(op, raise_for_even) elements = await stream.list(op) assert set(elements) == {1, 3, 5} + errored_nums = {error.args[0] for (error, _), __ in error_tracker} assert errored_nums == {2, 4, 6} + succeeded_nums = {num_returned for (num_returned, _), __ in success_tracker} + assert succeeded_nums == {1, 3, 5} async def test_secured_map_stream_double_secure_wrap(self): errored_nums_from_map: set[int] = set() errored_nums_from_decorator: set[int] = set() op = stream.iterate([1, 2, 3, 4, 5, 6]) - def store_from_map(error: Exception): + def store_from_map(error: Exception, _: int): nonlocal errored_nums_from_map errored_nums_from_map.add(int(str(error))) - def store_from_decorator(error: Exception): + def store_from_decorator(error: Exception, _: int): nonlocal errored_nums_from_decorator errored_nums_from_decorator.add(int(str(error))) raise error @@ -142,7 +145,7 @@ def raise_for_even(num: int) -> int: raise ValueError(f"{num}") return num - def store(error: Exception): + def store(error: Exception, _: int): nonlocal errored_nums errored_nums.add(int(str(error))) diff --git a/unittests/test_singleton_pattern.py b/unittests/test_singleton_pattern.py index 5f1ccb6..a7449dc 100644 --- a/unittests/test_singleton_pattern.py +++ b/unittests/test_singleton_pattern.py @@ -1,11 +1,11 @@ import pytest -from error_handler.types import Singleton +from error_handler.types import SingletonMeta class TestSingleton: def test_singleton(self): - class MySingleton(metaclass=Singleton): + class MySingleton(metaclass=SingletonMeta): def __init__(self): self.x = 7 @@ -20,7 +20,7 @@ def get_me(self) -> int: def test_singleton_with_args(self): with pytest.raises(AttributeError) as error_info: - class _(metaclass=Singleton): + class _(metaclass=SingletonMeta): def __init__(self, x: int): self.x = x diff --git a/unittests/utils.py b/unittests/utils.py new file mode 100644 index 0000000..8561502 --- /dev/null +++ b/unittests/utils.py @@ -0,0 +1,23 @@ +from typing import Any, Callable + + +def assert_not_called(*_, **__): + raise ValueError("This should not be called") + + +def create_callback_tracker( + additional_callback: Callable = lambda *args, **kwargs: None, +) -> tuple[Callable, list[tuple[tuple[Any, ...], dict[str, Any]]]]: + """ + Creates a callback function taking any arguments and a tracker which stores all calls to the callback in a list. + The callback function will call the additional_callback with the same arguments and return its return value. + It will also store the arguments as tuple in the call_args list. + Returns the callback and the call_args list. + """ + call_args = [] + + def callback(*args, **kwargs): + call_args.append((args, kwargs)) + return additional_callback(*args, **kwargs) + + return callback, call_args