Skip to content

Commit

Permalink
Async Stub Annotations (#611)
Browse files Browse the repository at this point in the history
* Generate a set of TypeVars and a generic class

TypeVars have defaults to match expected behavior

Overload init methods to get expected types back

Create AsyncStub type alias

Signed-off-by: Aidan Jensen <[email protected]>

---------

Signed-off-by: Aidan Jensen <[email protected]>
  • Loading branch information
artificial-aidan authored Oct 21, 2024
1 parent f20607f commit 54d184c
Show file tree
Hide file tree
Showing 10 changed files with 311 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ per-file-ignores =
*.py: E203, E301, E302, E305, E501
*.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037
*_pb2.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037, Y021
*_pb2_grpc.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037, Y021
*_pb2_grpc.pyi: E301, E302, E305, E501, E701, E741, F401, F403, F405, F822, Y037, Y021, Y023

extend_exclude = venv*,*_pb2.py,*_pb2_grpc.py,build/
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

- Mark top-level mangled identifiers as `TypeAlias`.
- Change the top-level mangling prefix from `global___` to `Global___` to respect
[Y042](https://github.com/PyCQA/flake8-pyi/blob/main/ERRORCODES.md#list-of-warnings) naming convention.
[Y042](https://github.com/PyCQA/flake8-pyi/blob/main/ERRORCODES.md#list-of-warnings) naming convention.
- Support client stub async typing overloads

## 3.6.0

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ black .
- [@fergyfresh](https://github.com/fergyfresh)
- [@AlexWaygood](https://github.com/AlexWaygood)
- [@Avasam](https://github.com/Avasam)
- [@artificial-aidan](https://github.com/artificial-aidan)

## Licence etc.

Expand Down
107 changes: 83 additions & 24 deletions mypy_protobuf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
Iterator,
List,
Optional,
Set,
Sequence,
Set,
Tuple,
)

import google.protobuf.descriptor_pb2 as d
from google.protobuf.compiler import plugin_pb2 as plugin_pb2
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from google.protobuf.internal.well_known_types import WKTBASES

from . import extensions_pb2

__version__ = "3.6.0"
Expand Down Expand Up @@ -85,6 +86,11 @@
}


def _build_typevar_name(service_name: str, method_name: str) -> str:
# Prefix with underscore to avoid public api error: https://stackoverflow.com/a/78871465
return f"_{service_name}{method_name}Type"


def _mangle_global_identifier(name: str) -> str:
"""
Module level identifiers are mangled and aliased so that they can be disambiguated
Expand Down Expand Up @@ -168,9 +174,7 @@ def _import(self, path: str, name: str) -> str:
eg. self._import("typing", "Literal") -> "Literal"
"""
if path == "typing_extensions":
stabilization = {
"TypeAlias": (3, 10),
}
stabilization = {"TypeAlias": (3, 10), "TypeVar": (3, 13)}
assert name in stabilization
if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
self.typing_extensions_min = stabilization[name]
Expand Down Expand Up @@ -732,6 +736,46 @@ def write_grpc_async_hacks(self) -> None:
wl("...")
wl("")

def write_grpc_type_vars(self, service: d.ServiceDescriptorProto) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
if not methods:
return
for _, method in methods:
wl("{} = {}(", _build_typevar_name(service.name, method.name), self._import("typing_extensions", "TypeVar"))
with self._indent():
wl("'{}',", _build_typevar_name(service.name, method.name))
wl("{}[", self._callable_type(method, is_async=False))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("],")
wl("{}[", self._callable_type(method, is_async=True))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("],")
wl("default={}[", self._callable_type(method, is_async=False))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("],")
wl(")")
wl("")

def write_self_types(self, service: d.ServiceDescriptorProto, is_async: bool) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
if not methods:
return
for _, method in methods:
with self._indent():
wl("{}[", self._callable_type(method, is_async=is_async))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("],")

def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
wl = self._write_line
methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
Expand Down Expand Up @@ -769,11 +813,7 @@ def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix:
for i, method in methods:
scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async))
with self._indent():
wl("{},", self._input_type(method))
wl("{},", self._output_type(method))
wl("]")
wl("{}: {}", method.name, f"{_build_typevar_name(service.name, method.name)}")
self._write_comments(scl)
wl("")

Expand All @@ -791,29 +831,48 @@ def write_grpc_services(

scl = scl_prefix + [i]

# Type vars
self.write_grpc_type_vars(service)

# The stub client
class_name = f"{service.name}Stub"
wl(
"class {}Stub:",
service.name,
"class {}({}[{}]):",
class_name,
self._import("typing", "Generic"),
", ".join(f"{_build_typevar_name(service.name, method.name)}" for method in service.method),
)
with self._indent():
if self._write_comments(scl):
wl("")
# To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
wl("def __init__(self, channel: {}) -> None: ...", channel)

# Write sync overload
wl("@{}", self._import("typing", "overload"))
wl("def __init__(self: {}[", class_name)
self.write_self_types(service, False)
wl(
"], channel: {}) -> None: ...",
self._import("grpc", "Channel"),
)
wl("")

# Write async overload
wl("@{}", self._import("typing", "overload"))
wl("def __init__(self: {}[", class_name)
self.write_self_types(service, True)
wl(
"], channel: {}) -> None: ...",
self._import("grpc.aio", "Channel"),
)
wl("")

self.write_grpc_stub_methods(service, scl)

# The (fake) async stub client
wl(
"class {}AsyncStub:",
service.name,
)
with self._indent():
if self._write_comments(scl):
wl("")
# No __init__ since this isn't a real class (yet), and requires manual casting to work.
self.write_grpc_stub_methods(service, scl, is_async=True)
# Write AsyncStub alias
wl("{}AsyncStub: {} = {}[", service.name, self._import("typing_extensions", "TypeAlias"), class_name)
self.write_self_types(service, True)
wl("]")
wl("")

# The service definition interface
wl(
Expand Down
3 changes: 2 additions & 1 deletion run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ for PY_VER in $PY_VER_UNIT_TESTS; do
# Write output to file. Make variant w/ omitted line numbers for easy diffing / CR
PY_VER_MYPY_TARGET=$(echo "$1" | cut -d. -f1-2)
export MYPYPATH=$MYPYPATH:test/generated
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable="venv_$1/bin/python" --python-version="$PY_VER_MYPY_TARGET" "${@: 2}" > "$MYPY_OUTPUT/mypy_output" || true
# Use --no-incremental to avoid caching issues: https://github.com/python/mypy/issues/16363
mypy --custom-typeshed-dir="$CUSTOM_TYPESHED_DIR" --python-executable="venv_$1/bin/python" --no-incremental --python-version="$PY_VER_MYPY_TARGET" "${@: 2}" > "$MYPY_OUTPUT/mypy_output" || true
cut -d: -f1,3- "$MYPY_OUTPUT/mypy_output" > "$MYPY_OUTPUT/mypy_output.omit_linenos"
}

Expand Down
153 changes: 121 additions & 32 deletions test/generated/testproto/grpc/dummy_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@ import abc
import collections.abc
import grpc
import grpc.aio
import sys
import testproto.grpc.dummy_pb2
import typing

if sys.version_info >= (3, 13):
import typing as typing_extensions
else:
import typing_extensions

_T = typing.TypeVar("_T")

class _MaybeAsyncIterator(collections.abc.AsyncIterator[_T], collections.abc.Iterator[_T], metaclass=abc.ABCMeta): ...
Expand All @@ -19,60 +25,143 @@ class _ServicerContext(grpc.ServicerContext, grpc.aio.ServicerContext): # type:

GRPC_GENERATED_VERSION: str
GRPC_VERSION: str
class DummyServiceStub:
"""DummyService"""

def __init__(self, channel: typing.Union[grpc.Channel, grpc.aio.Channel]) -> None: ...
UnaryUnary: grpc.UnaryUnaryMultiCallable[
_DummyServiceUnaryUnaryType = typing_extensions.TypeVar(
'_DummyServiceUnaryUnaryType',
grpc.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""UnaryUnary"""
],
grpc.aio.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
default=grpc.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
)

UnaryStream: grpc.UnaryStreamMultiCallable[
_DummyServiceUnaryStreamType = typing_extensions.TypeVar(
'_DummyServiceUnaryStreamType',
grpc.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""UnaryStream"""
],
grpc.aio.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
default=grpc.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
)

StreamUnary: grpc.StreamUnaryMultiCallable[
_DummyServiceStreamUnaryType = typing_extensions.TypeVar(
'_DummyServiceStreamUnaryType',
grpc.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""StreamUnary"""
],
grpc.aio.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
default=grpc.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
)

StreamStream: grpc.StreamStreamMultiCallable[
_DummyServiceStreamStreamType = typing_extensions.TypeVar(
'_DummyServiceStreamStreamType',
grpc.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""StreamStream"""
],
grpc.aio.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
default=grpc.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
)

class DummyServiceAsyncStub:
class DummyServiceStub(typing.Generic[_DummyServiceUnaryUnaryType, _DummyServiceUnaryStreamType, _DummyServiceStreamUnaryType, _DummyServiceStreamStreamType]):
"""DummyService"""

UnaryUnary: grpc.aio.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
@typing.overload
def __init__(self: DummyServiceStub[
grpc.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
], channel: grpc.Channel) -> None: ...

@typing.overload
def __init__(self: DummyServiceStub[
grpc.aio.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
], channel: grpc.aio.Channel) -> None: ...

UnaryUnary: _DummyServiceUnaryUnaryType
"""UnaryUnary"""

UnaryStream: grpc.aio.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
UnaryStream: _DummyServiceUnaryStreamType
"""UnaryStream"""

StreamUnary: grpc.aio.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
StreamUnary: _DummyServiceStreamUnaryType
"""StreamUnary"""

StreamStream: grpc.aio.StreamStreamMultiCallable[
StreamStream: _DummyServiceStreamStreamType
"""StreamStream"""

DummyServiceAsyncStub: typing_extensions.TypeAlias = DummyServiceStub[
grpc.aio.UnaryUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
]
"""StreamStream"""
],
grpc.aio.UnaryStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.StreamUnaryMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
grpc.aio.StreamStreamMultiCallable[
testproto.grpc.dummy_pb2.DummyRequest,
testproto.grpc.dummy_pb2.DummyReply,
],
]

class DummyServiceServicer(metaclass=abc.ABCMeta):
"""DummyService"""
Expand Down
Loading

0 comments on commit 54d184c

Please sign in to comment.