From a85af4f0c017b8c03426ef7927163a33add08004 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 13 Sep 2023 15:52:28 +0200 Subject: [PATCH] Typing: Improve annotations of process functions (#6077) --- aiida/engine/processes/functions.py | 46 ++++++++++++++++++++++++++--- aiida/parsers/parser.py | 2 +- docs/source/nitpick-exceptions | 11 ++++++- 3 files changed, 53 insertions(+), 6 deletions(-) diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index cfe237e230..c91faf838e 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -41,6 +41,7 @@ from aiida.orm.utils.mixins import FunctionCalculationMixin from .process import Process +from .process_spec import ProcessSpec try: UnionType = types.UnionType @@ -48,6 +49,12 @@ # This type is not available for Python 3.9 and older UnionType = None # type: ignore[assignment,misc] # pylint: disable=invalid-name +try: + from typing import ParamSpec +except ImportError: + # Fallback for Python 3.9 and older + from typing_extensions import ParamSpec # type: ignore[assignment] + try: get_annotations = inspect.get_annotations except AttributeError: @@ -87,7 +94,38 @@ def get_stack_size(size: int = 2) -> int: # type: ignore[return] return size - 1 -def calcfunction(function: FunctionType) -> FunctionType: +P = ParamSpec('P') +R_co = t.TypeVar('R_co', covariant=True) +N = t.TypeVar('N', bound=ProcessNode) + + +class ProcessFunctionType(t.Protocol, t.Generic[P, R_co, N]): + """Protocol for a decorated process function.""" + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co: + ... + + def run(self, *args: P.args, **kwargs: P.kwargs) -> R_co: + ... + + def run_get_pk(self, *args: P.args, **kwargs: P.kwargs) -> tuple[dict[str, t.Any] | None, int]: + ... + + def run_get_node(self, *args: P.args, **kwargs: P.kwargs) -> tuple[dict[str, t.Any] | None, N]: + ... + + is_process_function: bool + + node_class: t.Type[N] + + process_class: t.Type[Process] + + recreate_from: t.Callable[[N], Process] + + spec: t.Callable[[], ProcessSpec] + + +def calcfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, CalcFunctionNode]: """ A decorator to turn a standard python function into a calcfunction. Example usage: @@ -111,10 +149,10 @@ def calcfunction(function: FunctionType) -> FunctionType: :param function: The function to decorate. :return: The decorated function. """ - return process_function(node_class=CalcFunctionNode)(function) + return process_function(node_class=CalcFunctionNode)(function) # type: ignore[arg-type] -def workfunction(function: FunctionType) -> FunctionType: +def workfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, WorkFunctionNode]: """ A decorator to turn a standard python function into a workfunction. Example usage: @@ -138,7 +176,7 @@ def workfunction(function: FunctionType) -> FunctionType: :param function: The function to decorate. :return: The decorated function. """ - return process_function(node_class=WorkFunctionNode)(function) + return process_function(node_class=WorkFunctionNode)(function) # type: ignore[arg-type] def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]: diff --git a/aiida/parsers/parser.py b/aiida/parsers/parser.py index 92e988a778..943ff1ae5b 100644 --- a/aiida/parsers/parser.py +++ b/aiida/parsers/parser.py @@ -167,7 +167,7 @@ def parse_calcfunction(**kwargs): inputs = {'metadata': {'store_provenance': store_provenance}} inputs.update(parser.get_outputs_for_parsing()) - return parse_calcfunction.run_get_node(**inputs) # type: ignore[attr-defined] + return parse_calcfunction.run_get_node(**inputs) @abstractmethod def parse(self, **kwargs) -> Optional[ExitCode]: diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index d9286bc889..da1eebfa24 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -103,6 +103,15 @@ py:class aiida.storage.sqlite_zip.models.DbGroupNode py:class aiida.engine.processes.workchains.context.ToContext py:func aiida.orm.implementation.BackendQueryBuilder +### Typing aliases +py:obj aiida.engine.processes.functions.P +py:obj aiida.engine.processes.functions.N +py:obj aiida.engine.processes.functions.R_co +py:class P +py:class N +py:class aiida.engine.processes.functions.N +py:class aiida.engine.processes.functions.R_co + ### third-party packages # Note: These exceptions are needed if # * the objects are referenced e.g. as param/return types types in method docstrings (without intersphinx mapping) @@ -117,7 +126,7 @@ py:class click.types.Path py:class click.types.File py:class click.types.StringParamType py:func click.shell_completion._start_of_option -py:meth click.Option.get_default +py:meth click.Option.get_default py:meth fail py:class requests.models.Response