Skip to content

Commit

Permalink
Typing: Improve annotations of process functions (#6077)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell authored Sep 13, 2023
1 parent f41c8ac commit a85af4f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 6 deletions.
46 changes: 42 additions & 4 deletions aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,20 @@
from aiida.orm.utils.mixins import FunctionCalculationMixin

from .process import Process
from .process_spec import ProcessSpec

try:
UnionType = types.UnionType
except AttributeError:
# This type is not available for Python 3.9 and older
UnionType = None # type: ignore[assignment,misc] # pylint: disable=invalid-name

try:
from typing import ParamSpec
except ImportError:
# Fallback for Python 3.9 and older
from typing_extensions import ParamSpec # type: ignore[assignment]

try:
get_annotations = inspect.get_annotations
except AttributeError:
Expand Down Expand Up @@ -87,7 +94,38 @@ def get_stack_size(size: int = 2) -> int: # type: ignore[return]
return size - 1


def calcfunction(function: FunctionType) -> FunctionType:
P = ParamSpec('P')
R_co = t.TypeVar('R_co', covariant=True)
N = t.TypeVar('N', bound=ProcessNode)


class ProcessFunctionType(t.Protocol, t.Generic[P, R_co, N]):
"""Protocol for a decorated process function."""

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
...

def run(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
...

def run_get_pk(self, *args: P.args, **kwargs: P.kwargs) -> tuple[dict[str, t.Any] | None, int]:
...

def run_get_node(self, *args: P.args, **kwargs: P.kwargs) -> tuple[dict[str, t.Any] | None, N]:
...

is_process_function: bool

node_class: t.Type[N]

process_class: t.Type[Process]

recreate_from: t.Callable[[N], Process]

spec: t.Callable[[], ProcessSpec]


def calcfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, CalcFunctionNode]:
"""
A decorator to turn a standard python function into a calcfunction.
Example usage:
Expand All @@ -111,10 +149,10 @@ def calcfunction(function: FunctionType) -> FunctionType:
:param function: The function to decorate.
:return: The decorated function.
"""
return process_function(node_class=CalcFunctionNode)(function)
return process_function(node_class=CalcFunctionNode)(function) # type: ignore[arg-type]


def workfunction(function: FunctionType) -> FunctionType:
def workfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, WorkFunctionNode]:
"""
A decorator to turn a standard python function into a workfunction.
Example usage:
Expand All @@ -138,7 +176,7 @@ def workfunction(function: FunctionType) -> FunctionType:
:param function: The function to decorate.
:return: The decorated function.
"""
return process_function(node_class=WorkFunctionNode)(function)
return process_function(node_class=WorkFunctionNode)(function) # type: ignore[arg-type]


def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]:
Expand Down
2 changes: 1 addition & 1 deletion aiida/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def parse_calcfunction(**kwargs):
inputs = {'metadata': {'store_provenance': store_provenance}}
inputs.update(parser.get_outputs_for_parsing())

return parse_calcfunction.run_get_node(**inputs) # type: ignore[attr-defined]
return parse_calcfunction.run_get_node(**inputs)

@abstractmethod
def parse(self, **kwargs) -> Optional[ExitCode]:
Expand Down
11 changes: 10 additions & 1 deletion docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ py:class aiida.storage.sqlite_zip.models.DbGroupNode
py:class aiida.engine.processes.workchains.context.ToContext
py:func aiida.orm.implementation.BackendQueryBuilder

### Typing aliases
py:obj aiida.engine.processes.functions.P
py:obj aiida.engine.processes.functions.N
py:obj aiida.engine.processes.functions.R_co
py:class P
py:class N
py:class aiida.engine.processes.functions.N
py:class aiida.engine.processes.functions.R_co

### third-party packages
# Note: These exceptions are needed if
# * the objects are referenced e.g. as param/return types types in method docstrings (without intersphinx mapping)
Expand All @@ -117,7 +126,7 @@ py:class click.types.Path
py:class click.types.File
py:class click.types.StringParamType
py:func click.shell_completion._start_of_option
py:meth click.Option.get_default
py:meth click.Option.get_default
py:meth fail

py:class requests.models.Response
Expand Down

0 comments on commit a85af4f

Please sign in to comment.