diff --git a/setup.py b/setup.py index 6d40934..b2acf92 100644 --- a/setup.py +++ b/setup.py @@ -21,5 +21,12 @@ ], python_requires='>=3.8', package_dir={'':'.'}, - install_requires=['pyzmq', 'msgpack'] -) \ No newline at end of file + install_requires=['pyzmq', 'msgpack'], + extras_require={ + 'test': [ + 'pytest', + 'pytest-asyncio', + 'pyjwt' + ] + } +) diff --git a/tests/register_rpc_method_test.py b/tests/register_rpc_method_test.py new file mode 100644 index 0000000..b37b1f6 --- /dev/null +++ b/tests/register_rpc_method_test.py @@ -0,0 +1,45 @@ +import pytest +from zero import ZeroServer +from zero.errors import ZeroException + + +class DummyService: + + def hello(self): + return 'world' + + def say_no(self) -> str: + return 'no' + + def say_yes(self, please: bool = False) -> str: + if not please: + return "ask nicely." + return "yes" + + @staticmethod + def ping() -> str: + return pong + + @classmethod + def name(cls) -> str: + return cls.__name__ + + +def test_methods(): + app = ZeroServer() + service = DummyService() + app.register_rpc(service.say_no) + app.register_rpc(service.say_yes) + app.register_rpc(service.ping) + app.register_rpc(service.name) + + +def test_methods_no_args(): + app = ZeroServer() + service = DummyService() + with pytest.raises(ZeroException) as e: + app.register_rpc(service.hello) + assert ( + str(e.value) + == "`hello` has no return type hinting; RPC functions must have type hints" + ) diff --git a/tests/register_rpc_test.py b/tests/register_rpc_test.py index c06efec..f0bcc73 100644 --- a/tests/register_rpc_test.py +++ b/tests/register_rpc_test.py @@ -72,7 +72,7 @@ def test_function_with_1_arg_no_typing(): app.register_rpc(function_with_1_arg_no_typing) assert ( str(e.value) - == "`function_with_1_arg_no_typing` has no type hinting; RPC functions must have type hints" + == "`function_with_1_arg_no_typing` argument `a` is not typed." ) diff --git a/zero/server.py b/zero/server.py index 1aef61e..39c2600 100644 --- a/zero/server.py +++ b/zero/server.py @@ -16,14 +16,7 @@ from .codegen import CodeGen from .common import get_next_available_port -from .type_util import ( - get_function_input_class, - get_function_return_class, - verify_allowed_type, - verify_function_args, - verify_function_input_type, - verify_function_return, -) +from .type_util import verify_function_typing from .zero_mq import ZeroMQ # import uvloop @@ -83,13 +76,16 @@ def register_rpc(self, func: typing.Callable): if func.__name__ == "get_rpc_contract": raise Exception("get_rpc_contract is a reserved function; cannot have `get_rpc_contract` as a RPC function") - verify_function_args(func) - verify_function_input_type(func) - verify_function_return(func) + + signature = verify_function_typing(func) + if signature.parameters: + input_type = tuple(signature.parameters.values())[0].annotation + else: + input_type = None self._rpc_router[func.__name__] = func - self._rpc_input_type_map[func.__name__] = get_function_input_class(func) - self._rpc_return_type_map[func.__name__] = get_function_return_class(func) + self._rpc_input_type_map[func.__name__] = input_type + self._rpc_return_type_map[func.__name__] = signature.return_annotation def run(self): try: diff --git a/zero/type_util.py b/zero/type_util.py index 079ce86..b0235fa 100644 --- a/zero/type_util.py +++ b/zero/type_util.py @@ -1,4 +1,5 @@ import typing +import inspect from zero.errors import ZeroException # from pydantic import BaseModel @@ -30,32 +31,49 @@ allowed_types = basic_types + typing_types + special_types + pydantic_types -def verify_function_args(func: typing.Callable): - arg_count = func.__code__.co_argcount +def verify_function_typing(func: typing.Callable): + signature = inspect.signature(func) + arg_count = len(signature.parameters) + if arg_count > 1: raise ZeroException( - f"`{func.__name__}` has more than 1 args; RPC functions can have only one arg - msg, or no arg" + f"`{func.__name__}` has more than 1 args; " + "RPC functions can have only one arg - msg, or no arg" ) - if arg_count == 1: - arg_name = func.__code__.co_varnames[0] - func_arg_type = typing.get_type_hints(func) - if arg_name not in func_arg_type: - raise ZeroException(f"`{func.__name__}` has no type hinting; RPC functions must have type hints") - + for name, param in signature.parameters.items(): + if param.annotation is inspect._empty: + raise ZeroException( + f"`{func.__name__}` argument `{name}` is not typed." + ) + if not param.annotation in allowed_types: + raise ZeroException( + f"`{func.__name__}` argument `{name}` type is not supported." + ) + + if signature.return_annotation is inspect._empty: + raise ZeroException( + f"`{func.__name__}` has no return type hinting; " + "RPC functions must have type hints" + ) + elif not signature.return_annotation in allowed_types: + raise ZeroException( + f"`{func.__name__}` return type is not supported." + ) -def verify_function_return(func: typing.Callable): - types = typing.get_type_hints(func) - if not types.get("return"): - raise ZeroException(f"`{func.__name__}` has no return type hinting; RPC functions must have type hints") + return signature def get_function_input_class(func: typing.Callable): arg_count = func.__code__.co_argcount - if arg_count == 0: + if inspect.ismethod(func): + max_argcount = 2 + else: + max_argcount = 1 + if arg_count == max_argcount - 1: return None - if arg_count == 1: - arg_name = func.__code__.co_varnames[0] + if arg_count == max_argcount: + arg_name = func.__code__.co_varnames[max_argcount - 1] func_arg_type = typing.get_type_hints(func) return func_arg_type[arg_name]