Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

👌 Improve calcfunction/workfunction typing #6077

Merged
merged 2 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,20 @@
from aiida.orm.utils.mixins import FunctionCalculationMixin

from .process import Process
from .process_spec import ProcessSpec

try:
UnionType = types.UnionType
except AttributeError:
# This type is not available for Python 3.9 and older
UnionType = None # type: ignore[assignment,misc] # pylint: disable=invalid-name

try:
from typing import ParamSpec
except ImportError:
# Fallback for Python 3.9 and older
from typing_extensions import ParamSpec # type: ignore[assignment]

try:
get_annotations = inspect.get_annotations
except AttributeError:
Expand Down Expand Up @@ -87,7 +94,38 @@ def get_stack_size(size: int = 2) -> int: # type: ignore[return]
return size - 1


def calcfunction(function: FunctionType) -> FunctionType:
P = ParamSpec('P')
R_co = t.TypeVar('R_co', covariant=True)
N = t.TypeVar('N', bound=ProcessNode)


class ProcessFunctionType(t.Protocol, t.Generic[P, R_co, N]):
"""Protocol for a decorated process function."""

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
...

def run(self, *args: P.args, **kwargs: P.kwargs) -> R_co:
...

def run_get_pk(self, *args: P.args, **kwargs: P.kwargs) -> tuple[dict[str, t.Any] | None, int]:
...

def run_get_node(self, *args: P.args, **kwargs: P.kwargs) -> tuple[dict[str, t.Any] | None, N]:
...

is_process_function: bool

node_class: t.Type[N]

process_class: t.Type[Process]

recreate_from: t.Callable[[N], Process]

spec: t.Callable[[], ProcessSpec]


def calcfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, CalcFunctionNode]:
"""
A decorator to turn a standard python function into a calcfunction.
Example usage:
Expand All @@ -111,10 +149,10 @@ def calcfunction(function: FunctionType) -> FunctionType:
:param function: The function to decorate.
:return: The decorated function.
"""
return process_function(node_class=CalcFunctionNode)(function)
return process_function(node_class=CalcFunctionNode)(function) # type: ignore[arg-type]


def workfunction(function: FunctionType) -> FunctionType:
def workfunction(function: t.Callable[P, R_co]) -> ProcessFunctionType[P, R_co, WorkFunctionNode]:
"""
A decorator to turn a standard python function into a workfunction.
Example usage:
Expand All @@ -138,7 +176,7 @@ def workfunction(function: FunctionType) -> FunctionType:
:param function: The function to decorate.
:return: The decorated function.
"""
return process_function(node_class=WorkFunctionNode)(function)
return process_function(node_class=WorkFunctionNode)(function) # type: ignore[arg-type]


def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]:
Expand Down
2 changes: 1 addition & 1 deletion aiida/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def parse_calcfunction(**kwargs):
inputs = {'metadata': {'store_provenance': store_provenance}}
inputs.update(parser.get_outputs_for_parsing())

return parse_calcfunction.run_get_node(**inputs) # type: ignore[attr-defined]
return parse_calcfunction.run_get_node(**inputs)

@abstractmethod
def parse(self, **kwargs) -> Optional[ExitCode]:
Expand Down
11 changes: 10 additions & 1 deletion docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ py:class aiida.storage.sqlite_zip.models.DbGroupNode
py:class aiida.engine.processes.workchains.context.ToContext
py:func aiida.orm.implementation.BackendQueryBuilder

### Typing aliases
py:obj aiida.engine.processes.functions.P
py:obj aiida.engine.processes.functions.N
py:obj aiida.engine.processes.functions.R_co
py:class P
py:class N
py:class aiida.engine.processes.functions.N
py:class aiida.engine.processes.functions.R_co

### third-party packages
# Note: These exceptions are needed if
# * the objects are referenced e.g. as param/return types types in method docstrings (without intersphinx mapping)
Expand All @@ -117,7 +126,7 @@ py:class click.types.Path
py:class click.types.File
py:class click.types.StringParamType
py:func click.shell_completion._start_of_option
py:meth click.Option.get_default
py:meth click.Option.get_default
py:meth fail

py:class requests.models.Response
Expand Down
Loading