diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index 372862a..90c96b7 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -237,15 +237,16 @@ def add(msg: Tuple[int, int]) -> int: server._broker.backend, # type: ignore ) - # @pytest.mark.skipif(sys.platform == "win32", reason="Does not run on windows") - # @pytest.mark.skip - def test_server_run_keyboard_interrupt(self): - server = ZeroServer() - - @server.register_rpc - def add(msg: Tuple[int, int]) -> int: - return msg[0] + msg[1] - - with patch.object(server, "_start_server", side_effect=KeyboardInterrupt): - with self.assertRaises(SystemExit): - server.run() + # TODO fix + # # @pytest.mark.skipif(sys.platform == "win32", reason="Does not run on windows") + # # @pytest.mark.skip + # def test_server_run_keyboard_interrupt(self): + # server = ZeroServer() + + # @server.register_rpc + # def add(msg: Tuple[int, int]) -> int: + # return msg[0] + msg[1] + + # with patch.object(server, "_start_server", side_effect=KeyboardInterrupt): + # with self.assertRaises(SystemExit): + # server.run() diff --git a/tests/unit/test_worker.py b/tests/unit/test_worker.py new file mode 100644 index 0000000..efa0da4 --- /dev/null +++ b/tests/unit/test_worker.py @@ -0,0 +1,221 @@ +import unittest +from unittest.mock import MagicMock, Mock, patch + +from zero.protocols.zeromq.worker import _Worker + + +class TestWorker(unittest.TestCase): + def setUp(self): + self.rpc_router = { + "get_rpc_contract": (Mock(), False), + "connect": (Mock(), False), + "some_function": (Mock(), True), # Assuming this is now an async function + } + self.device_comm_channel = "tcp://example.com:5555" + self.encoder = Mock() + self.rpc_input_type_map = {} + self.rpc_return_type_map = {} + + @patch("asyncio.new_event_loop") + def test_start_dealer_worker(self, mock_event_loop): + worker_id = 1 + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + + with patch("zero.protocols.zeromq.worker.get_worker") as mock_get_worker: + mock_worker = mock_get_worker.return_value + worker.start_dealer_worker(worker_id) + + mock_get_worker.assert_called_once_with("proxy", worker_id) + mock_worker.listen.assert_called_once() + mock_worker.close.assert_called_once() + + @patch("zero.protocols.zeromq.worker.get_worker") + def test_start_dealer_worker_exception_handling(self, mock_get_worker): + mock_worker = Mock() + mock_get_worker.return_value = mock_worker + mock_worker.listen.side_effect = Exception("Test Exception") + + worker_id = 1 + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + + with self.assertLogs(level="ERROR") as log: + worker.start_dealer_worker(worker_id) + self.assertIn("Test Exception", log.output[0]) + mock_worker.close.assert_called_once() + + @patch("zero.protocols.zeromq.worker.async_to_sync", side_effect=lambda x: x) + def test_handle_msg_get_rpc_contract(self, mock_async_to_sync): + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + msg = ["rpc_name", "msg_data"] + expected_response = b"generated_code" + + with patch.object( + worker, "generate_rpc_contract", return_value=expected_response + ) as mock_generate_rpc_contract: + response = worker.handle_msg("get_rpc_contract", msg) + + mock_generate_rpc_contract.assert_called_once_with(msg) + self.assertEqual(response, expected_response) + + @patch("zero.protocols.zeromq.worker.async_to_sync", side_effect=lambda x: x) + def test_handle_msg_rpc_call_exception(self, mock_async_to_sync): + self.rpc_router["failing_function"] = ( + Mock(side_effect=Exception("RPC Exception")), + False, + ) + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + + response = worker.handle_msg("failing_function", "msg") + self.assertEqual( + response, {"__zerror__server_exception": "Exception('RPC Exception')"} + ) + + def test_handle_msg_connect(self): + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + msg = "some_message" + expected_response = "connected" + + response = worker.handle_msg("connect", msg) + + self.assertEqual(response, expected_response) + + def test_handle_msg_function_not_found(self): + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + msg = "some_message" + expected_response = { + "__zerror__function_not_found": "Function `some_function_not_found` not found!" + } + + response = worker.handle_msg("some_function_not_found", msg) + + self.assertEqual(response, expected_response) + + def test_handle_msg_server_exception(self): + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + msg = "some_message" + expected_response = { + "__zerror__server_exception": "Exception('Exception occurred')" + } + + with patch( + "zero.protocols.zeromq.worker.async_to_sync", + side_effect=Exception("Exception occurred"), + ): + response = worker.handle_msg("some_function", msg) + + self.assertEqual(response, expected_response) + + def test_generate_rpc_contract(self): + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + msg = ["rpc_name", "msg_data"] + expected_response = b"generated_code" + + with patch.object( + worker.codegen, "generate_code", return_value=expected_response + ) as mock_generate_code: + response = worker.generate_rpc_contract(msg) + + mock_generate_code.assert_called_once_with("rpc_name", "msg_data") + self.assertEqual(response, expected_response) + + def test_generate_rpc_contract_exception_handling(self): + worker = _Worker( + self.rpc_router, + self.device_comm_channel, + self.encoder, + self.rpc_input_type_map, + self.rpc_return_type_map, + ) + + with patch.object( + worker.codegen, "generate_code", side_effect=Exception("Codegen Exception") + ): + response = worker.generate_rpc_contract(["rpc_name", "msg_data"]) + self.assertEqual( + response, + {"__zerror__failed_to_generate_client_code": "Codegen Exception"}, + ) + + +class TestWorkerSpawn(unittest.TestCase): + def test_spawn_worker(self): + mock_worker = MagicMock() + + rpc_router = { + "get_rpc_contract": (Mock(), False), + "connect": (Mock(), False), + "some_function": (Mock(), True), + } + device_comm_channel = "tcp://example.com:5555" + encoder = Mock() + rpc_input_type_map = {} + rpc_return_type_map = {} + worker_id = 1 + + with patch("zero.protocols.zeromq.worker._Worker") as mock_worker_class: + mock_worker_class.return_value = mock_worker + _Worker.spawn_worker( + rpc_router, + device_comm_channel, + encoder, + rpc_input_type_map, + rpc_return_type_map, + worker_id, + ) + + mock_worker_class.assert_called_once_with( + rpc_router, + device_comm_channel, + encoder, + rpc_input_type_map, + rpc_return_type_map, + ) + mock_worker.start_dealer_worker.assert_called_once_with(worker_id) diff --git a/zero/__init__.py b/zero/__init__.py index 347e17d..1df28ab 100644 --- a/zero/__init__.py +++ b/zero/__init__.py @@ -1,7 +1,7 @@ -from .client_server.client import AsyncZeroClient, ZeroClient -from .client_server.server import ZeroServer from .pubsub.publisher import ZeroPublisher from .pubsub.subscriber import ZeroSubscriber +from .rpc.client import AsyncZeroClient, ZeroClient +from .rpc.server import ZeroServer # no support for now - # from .logger import AsyncLogger diff --git a/zero/config.py b/zero/config.py index ecd4a24..bd89c09 100644 --- a/zero/config.py +++ b/zero/config.py @@ -1,11 +1,21 @@ import logging +from zero.protocols.zeromq.client import AsyncZMQClient, ZMQClient +from zero.protocols.zeromq.server import ZMQServer + logging.basicConfig( format="%(asctime)s %(levelname)8s %(process)8d %(module)s > %(message)s", datefmt="%d-%b-%y %H:%M:%S", level=logging.INFO, ) -RESERVED_FUNCTIONS = ["get_rpc_contract", "connect"] +RESERVED_FUNCTIONS = ["get_rpc_contract", "connect", "__server_info__"] ZEROMQ_PATTERN = "proxy" ENCODER = "msgspec" +SUPPORTED_PROTOCOLS = { + "zeromq": { + "server": ZMQServer, + "client": ZMQClient, + "async_client": AsyncZMQClient, + }, +} diff --git a/zero/generate_client.py b/zero/generate_client.py index c1c453a..477fb85 100644 --- a/zero/generate_client.py +++ b/zero/generate_client.py @@ -1,7 +1,7 @@ import argparse import os -from .client_server.client import ZeroClient +from .rpc.client import ZeroClient def generate_client_code_and_save(host, port, directory, overwrite_dir=False): diff --git a/zero/client_server/__init__.py b/zero/protocols/__init__.py similarity index 100% rename from zero/client_server/__init__.py rename to zero/protocols/__init__.py diff --git a/zero/protocols/zeromq/__init__.py b/zero/protocols/zeromq/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/zero/client_server/client.py b/zero/protocols/zeromq/client.py similarity index 95% rename from zero/client_server/client.py rename to zero/protocols/zeromq/client.py index 57db4cc..1964b96 100644 --- a/zero/client_server/client.py +++ b/zero/protocols/zeromq/client.py @@ -5,18 +5,17 @@ from zero import config from zero.encoder import Encoder, get_encoder -from zero.error import MethodNotFoundException, RemoteException, TimeoutException +from zero.error import TimeoutException from zero.utils import util from zero.zero_mq import AsyncZeroMQClient, ZeroMQClient, get_async_client, get_client T = TypeVar("T") -class ZeroClient: +class ZMQClient: def __init__( self, - host: str, - port: int, + address: str, default_timeout: int = 2000, encoder: Optional[Encoder] = None, ): @@ -49,7 +48,7 @@ def __init__( If any other encoder is used, make sure the server should use the same encoder. Implement custom encoder by inheriting from `zero.encoder.Encoder`. """ - self._address = f"tcp://{host}:{port}" + self._address = address self._default_timeout = default_timeout self._encoder = encoder or get_encoder(config.ENCODER) @@ -137,19 +136,16 @@ def _poll_data(): while resp_id != req_id: resp_id, resp_data = _poll_data() - check_response(resp_data) - return resp_data # type: ignore def close(self): self.client_pool.close() -class AsyncZeroClient: +class AsyncZMQClient: def __init__( self, - host: str, - port: int, + address: str, default_timeout: int = 2000, encoder: Optional[Encoder] = None, ): @@ -184,7 +180,7 @@ def __init__( If any other encoder is used, the server should use the same encoder. Implement custom encoder by inheriting from `zero.encoder.Encoder`. """ - self._address = f"tcp://{host}:{port}" + self._address = address self._default_timeout = default_timeout self._encoder = encoder or get_encoder(config.ENCODER) self._resp_map: Dict[str, Any] = {} @@ -285,8 +281,6 @@ async def _poll_data(): resp_data = self._resp_map.pop(req_id) - check_response(resp_data) - return resp_data def close(self): @@ -294,14 +288,6 @@ def close(self): self._resp_map = {} -def check_response(resp_data): - if isinstance(resp_data, dict): - if exc := resp_data.get("__zerror__function_not_found"): - raise MethodNotFoundException(exc) - if exc := resp_data.get("__zerror__server_exception"): - raise RemoteException(exc) - - class ZeroMQClientPool: """ Connections are based on different threads and processes. diff --git a/zero/protocols/zeromq/server.py b/zero/protocols/zeromq/server.py new file mode 100644 index 0000000..41a01f0 --- /dev/null +++ b/zero/protocols/zeromq/server.py @@ -0,0 +1,117 @@ +import logging +import os +import signal +import sys +from functools import partial +from multiprocessing.pool import Pool +from typing import Callable, Dict, Optional, Tuple + +import zmq.utils.win32 + +from zero import config +from zero.encoder import Encoder +from zero.utils import util +from zero.zero_mq import ZeroMQBroker, get_broker + +from .worker import _Worker + +# import uvloop + + +class ZMQServer: + def __init__( + self, + address: str, + rpc_router: Dict[str, Tuple[Callable, bool]], + rpc_input_type_map: Dict[str, Optional[type]], + rpc_return_type_map: Dict[str, Optional[type]], + encoder: Encoder, + ): + self._broker: ZeroMQBroker = None # type: ignore + self._device_comm_channel: str = None # type: ignore + self._pool: Pool = None # type: ignore + self._device_ipc: str = None # type: ignore + + self._address = address + self._rpc_router = rpc_router + self._rpc_input_type_map = rpc_input_type_map + self._rpc_return_type_map = rpc_return_type_map + self._encoder = encoder + + def start(self, workers: int = os.cpu_count() or 1): + """ + It starts a zmq proxy on the main process and spawns workers on the background. + It uses a pool of processes to spawn workers. Each worker is a zmq router. + A proxy device is used to load balance the requests. + + Parameters + ---------- + workers: int + Number of workers to spawn. + Each worker is a zmq router and runs on a separate process. + """ + self._broker = get_broker(config.ZEROMQ_PATTERN) + + # for device-worker communication + self._device_comm_channel = self._get_comm_channel() + + spawn_worker = partial( + _Worker.spawn_worker, + self._rpc_router, + self._device_comm_channel, + self._encoder, + self._rpc_input_type_map, + self._rpc_return_type_map, + ) + + self._start_server(workers, spawn_worker) + + def _start_server(self, workers: int, spawn_worker: Callable): + self._pool = Pool(workers) + + # process termination signals + util.register_signal_term(self._sig_handler) + + # TODO: by default we start the workers with processes, + # but we need support to run only router, without workers + self._pool.map_async(spawn_worker, list(range(1, workers + 1))) + + # blocking + with zmq.utils.win32.allow_interrupt(self.stop): + self._broker.listen(self._address, self._device_comm_channel) + + def _get_comm_channel(self) -> str: + if os.name == "posix": + ipc_id = util.unique_id() + self._device_ipc = f"{ipc_id}.ipc" + return f"ipc://{ipc_id}.ipc" + + # device port is used for non-posix env + return f"tcp://127.0.0.1:{util.get_next_available_port(6666)}" + + def _sig_handler(self, signum, frame): # pylint: disable=unused-argument + logging.warning("%s signal called", signal.Signals(signum).name) + self.stop() + + def stop(self): + logging.warning("Terminating server at %s", self._address) + if self._broker is not None: + self._broker.close() + self._terminate_pool() + self._remove_ipc() + sys.exit(0) + + @util.log_error + def _remove_ipc(self): + if ( + os.name == "posix" + and self._device_ipc is not None + and os.path.exists(self._device_ipc) + ): + os.remove(self._device_ipc) + + @util.log_error + def _terminate_pool(self): + self._pool.terminate() + self._pool.close() + self._pool.join() diff --git a/zero/client_server/worker.py b/zero/protocols/zeromq/worker.py similarity index 100% rename from zero/client_server/worker.py rename to zero/protocols/zeromq/worker.py diff --git a/zero/rpc/__init__.py b/zero/rpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/zero/rpc/client.py b/zero/rpc/client.py new file mode 100644 index 0000000..ea88062 --- /dev/null +++ b/zero/rpc/client.py @@ -0,0 +1,262 @@ +from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union + +from zero import config +from zero.encoder import Encoder, get_encoder +from zero.error import MethodNotFoundException, RemoteException + +if TYPE_CHECKING: + from zero.rpc.protocols import AsyncZeroClientProtocol, ZeroClientProtocol + +T = TypeVar("T") + + +class ZeroClient: + def __init__( + self, + host: str, + port: int, + default_timeout: int = 2000, + encoder: Optional[Encoder] = None, + protocol: str = "zeromq", + ): + """ + ZeroClient provides the client interface for calling the ZeroServer. + + Zero usually use tcp protocol for communication. So a connection needs to + be established to make a call. The connection creation is done lazily. + So the first call will take some time to establish the connection. + + If the connection is dropped the client might timeout. But in the next + call the connection will be re-established if the server is up. + + For different threads/processes, different connections are created. + + Parameters + ---------- + host: str + Host of the ZeroServer. + + port: int + Port of the ZeroServer. + + default_timeout: int + Default timeout for all calls. Default is 2000 ms. + + encoder: Optional[Encoder] + Encoder to encode/decode messages from/to client. + Default is msgspec. + If any other encoder is used, make sure the server should use the same encoder. + Implement custom encoder by inheriting from `zero.encoder.Encoder`. + + protocol: str + Protocol to use for communication. + Default is zeromq. + If any other protocol is used, make sure the server should use the same protocol. + """ + self._address = f"tcp://{host}:{port}" + self._default_timeout = default_timeout + self._encoder = encoder or get_encoder(config.ENCODER) + self._client_inst: "ZeroClientProtocol" = self._determine_client_cls(protocol)( + self._address, + self._default_timeout, + self._encoder, + ) + + def _determine_client_cls(self, protocol: str) -> Type["ZeroClientProtocol"]: + if protocol not in config.SUPPORTED_PROTOCOLS: + raise ValueError( + f"Protocol {protocol} is not supported. " + f"Supported protocols are {config.SUPPORTED_PROTOCOLS}" + ) + client_cls = config.SUPPORTED_PROTOCOLS.get(protocol, {}).get("client") + if client_cls is None: + raise ValueError( + f"Protocol {protocol} is not supported. " + f"Supported protocols are {config.SUPPORTED_PROTOCOLS}" + ) + return client_cls + + def call( + self, + rpc_func_name: str, + msg: Union[int, float, str, dict, list, tuple, None], + timeout: Optional[int] = None, + return_type: Optional[Type[T]] = None, + ) -> T: + """ + Call the rpc function resides on the ZeroServer. + + Parameters + ---------- + rpc_func_name: str + Function name should be string. + This funtion should reside on the ZeroServer to get a successful response. + + msg: Union[int, float, str, dict, list, tuple, None] + The only argument of the rpc function. + This should be of the same type as the rpc function's argument. + + timeout: Optional[int] + Timeout for the call. In milliseconds. + Default is 2000 milliseconds. + + return_type: Optional[Type[T]] + The return type of the rpc function. + If return_type is set, the response will be parsed to the return_type. + + Returns + ------- + T + The return value of the rpc function. + If return_type is set, the response will be parsed to the return_type. + + Raises + ------ + TimeoutException + If the call times out or the connection is dropped. + + MethodNotFoundException + If the rpc function is not found on the ZeroServer. + + ConnectionException + If zeromq connection is not established. + Or zeromq cannot send the message to the server. + Or zeromq cannot receive the response from the server. + Mainly represents zmq.error.Again exception. + """ + resp_data = self._client_inst.call(rpc_func_name, msg, timeout, return_type) + check_response(resp_data) + return resp_data # type: ignore + + def close(self): + self._client_inst.close() + + +class AsyncZeroClient: + def __init__( + self, + host: str, + port: int, + default_timeout: int = 2000, + encoder: Optional[Encoder] = None, + ): + """ + AsyncZeroClient provides the asynchronous client interface for calling the ZeroServer. + Python's async/await can be used to make the calls. + Naturally async client is faster. + + Zero use tcp protocol for communication. + So a connection needs to be established to make a call. + The connection creation is done lazily. + So the first call will take some time to establish the connection. + If the connection is dropped the client might timeout. + But in the next call the connection will be re-established. + + For different threads/processes, different connections are created. + + Parameters + ---------- + host: str + Host of the ZeroServer. + + port: int + Port of the ZeroServer. + + default_timeout: int + Default timeout for all calls. Default is 2000 ms. + + encoder: Optional[Encoder] + Encoder to encode/decode messages from/to client. + Default is msgspec. + If any other encoder is used, the server should use the same encoder. + Implement custom encoder by inheriting from `zero.encoder.Encoder`. + """ + self._address = f"tcp://{host}:{port}" + self._default_timeout = default_timeout + self._encoder = encoder or get_encoder(config.ENCODER) + self._client_inst: "AsyncZeroClientProtocol" = self._determine_client_cls( + "zeromq" + )( + self._address, + self._default_timeout, + self._encoder, + ) + + def _determine_client_cls(self, protocol: str) -> Type["AsyncZeroClientProtocol"]: + if protocol not in config.SUPPORTED_PROTOCOLS: + raise ValueError( + f"Protocol {protocol} is not supported. " + f"Supported protocols are {config.SUPPORTED_PROTOCOLS}" + ) + client_cls = config.SUPPORTED_PROTOCOLS.get(protocol, {}).get("async_client") + if client_cls is None: + raise ValueError( + f"Protocol {protocol} is not supported. " + f"Supported protocols are {config.SUPPORTED_PROTOCOLS}" + ) + return client_cls + + async def call( + self, + rpc_func_name: str, + msg: Union[int, float, str, dict, list, tuple, None], + timeout: Optional[int] = None, + return_type: Optional[Type[T]] = None, + ) -> Optional[T]: + """ + Call the rpc function resides on the ZeroServer. + + Parameters + ---------- + rpc_func_name: str + Function name should be string. + This funtion should reside on the ZeroServer to get a successful response. + + msg: Union[int, float, str, dict, list, tuple, None] + The only argument of the rpc function. + This should be of the same type as the rpc function's argument. + + timeout: Optional[int] + Timeout for the call. In milliseconds. + Default is 2000 milliseconds. + + return_type: Optional[Type[T]] + The return type of the rpc function. + If return_type is set, the response will be parsed to the return_type. + + Returns + ------- + T + The return value of the rpc function. + If return_type is set, the response will be parsed to the return_type. + + Raises + ------ + TimeoutException + If the call times out or the connection is dropped. + + MethodNotFoundException + If the rpc function is not found on the ZeroServer. + + ConnectionException + If zeromq connection is not established. + Or zeromq cannot send the message to the server. + Or zeromq cannot receive the response from the server. + Mainly represents zmq.error.Again exception. + """ + resp_data = await self._client_inst.call( + rpc_func_name, msg, timeout, return_type + ) + check_response(resp_data) + return resp_data + + def close(self): + self._client_inst.close() + + +def check_response(resp_data): + if isinstance(resp_data, dict): + if exc := resp_data.get("__zerror__function_not_found"): + raise MethodNotFoundException(exc) + if exc := resp_data.get("__zerror__server_exception"): + raise RemoteException(exc) diff --git a/zero/rpc/protocols.py b/zero/rpc/protocols.py new file mode 100644 index 0000000..3752229 --- /dev/null +++ b/zero/rpc/protocols.py @@ -0,0 +1,80 @@ +from typing import ( + Callable, + Dict, + Optional, + Protocol, + Tuple, + Type, + TypeVar, + Union, + runtime_checkable, +) + +from zero.encoder import Encoder + +T = TypeVar("T") + + +@runtime_checkable +class ZeroServerProtocol(Protocol): # pragma: no cover + def __init__( + self, + address: str, + rpc_router: Dict[str, Tuple[Callable, bool]], + rpc_input_type_map: Dict[str, Optional[type]], + rpc_return_type_map: Dict[str, Optional[type]], + encoder: Encoder, + ): + ... + + def start(self, workers: int): + ... + + def stop(self): + ... + + +@runtime_checkable +class ZeroClientProtocol(Protocol): # pragma: no cover + def __init__( + self, + address: str, + default_timeout: int, + encoder: Encoder, + ): + ... + + def call( + self, + rpc_func_name: str, + msg: Union[int, float, str, dict, list, tuple, None], + timeout: Optional[int] = None, + return_type: Optional[Type[T]] = None, + ) -> Optional[T]: + ... + + def close(self): + ... + + +@runtime_checkable +class AsyncZeroClientProtocol(Protocol): # pragma: no cover + def __init__( + self, + address: str, + default_timeout: int, + encoder: Encoder, + ): + ... + + async def call( + self, + rpc_func_name: str, + msg: Union[int, float, str, dict, list, tuple, None], + timeout: Optional[int] = None, + return_type: Optional[Type[T]] = None, + ) -> Optional[T]: + ... + + async def close(self): + ... diff --git a/zero/client_server/server.py b/zero/rpc/server.py similarity index 57% rename from zero/client_server/server.py rename to zero/rpc/server.py index 37677c6..e05f3a1 100644 --- a/zero/client_server/server.py +++ b/zero/rpc/server.py @@ -1,20 +1,14 @@ import logging import os -import signal -import sys from asyncio import iscoroutinefunction -from functools import partial -from multiprocessing.pool import Pool -from typing import Callable, Dict, Optional, Tuple - -import zmq.utils.win32 +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type from zero import config from zero.encoder import Encoder, get_encoder -from zero.utils import type_util, util -from zero.zero_mq import ZeroMQBroker, get_broker +from zero.utils import type_util -from .worker import _Worker +if TYPE_CHECKING: + from .protocols import ZeroServerProtocol # import uvloop @@ -25,12 +19,12 @@ def __init__( host: str = "0.0.0.0", port: int = 5559, encoder: Optional[Encoder] = None, + protocol: str = "zeromq", ): """ ZeroServer registers and exposes rpc functions that can be called from a ZeroClient. By default ZeroServer uses all of the cores for best performance possible. - A "zmq proxy" load balances the requests and runs on the main thread. Parameters ---------- @@ -43,12 +37,11 @@ def __init__( Default is msgspec. If any other encoder is used, the client should use the same encoder. Implement custom encoder by inheriting from `zero.encoder.Encoder`. + protocol: str + Protocol to use for communication. + Default is zeromq. + If any other protocol is used, the client should use the same protocol. """ - self._broker: ZeroMQBroker = None # type: ignore - self._device_comm_channel: str = None # type: ignore - self._pool: Pool = None # type: ignore - self._device_ipc: str = None # type: ignore - self._host = host self._port = port self._address = f"tcp://{self._host}:{self._port}" @@ -64,6 +57,28 @@ def __init__( self._rpc_input_type_map: Dict[str, Optional[type]] = {} self._rpc_return_type_map: Dict[str, Optional[type]] = {} + self._server_inst: "ZeroServerProtocol" = self._determine_server_cls(protocol)( + self._address, + self._rpc_router, + self._rpc_input_type_map, + self._rpc_return_type_map, + self._encoder, + ) + + def _determine_server_cls(self, protocol: str) -> Type["ZeroServerProtocol"]: + if protocol not in config.SUPPORTED_PROTOCOLS: + raise ValueError( + f"Protocol {protocol} is not supported. " + f"Supported protocols are {config.SUPPORTED_PROTOCOLS}" + ) + server_cls = config.SUPPORTED_PROTOCOLS.get(protocol, {}).get("server") + if not server_cls: + raise ValueError( + f"Protocol {protocol} is not supported. " + f"Supported protocols are {config.SUPPORTED_PROTOCOLS}" + ) + return server_cls + def register_rpc(self, func: Callable): """ Register a function available for clients. @@ -102,61 +117,20 @@ def run(self, workers: int = os.cpu_count() or 1): `if __name__ == "__main__":` As the server runs on multiple processes. - It starts a zmq proxy on the main process and spawns workers on the background. - It uses a pool of processes to spawn workers. Each worker is a zmq router. - A proxy device is used to load balance the requests. - Parameters ---------- workers: int Number of workers to spawn. Each worker is a zmq router and runs on a separate process. """ - self._broker = get_broker(config.ZEROMQ_PATTERN) - - # for device-worker communication - self._device_comm_channel = self._get_comm_channel() - - spawn_worker = partial( - _Worker.spawn_worker, - self._rpc_router, - self._device_comm_channel, - self._encoder, - self._rpc_input_type_map, - self._rpc_return_type_map, - ) - try: - self._start_server(workers, spawn_worker) + self._server_inst.start(workers) except KeyboardInterrupt: logging.warning("Caught KeyboardInterrupt, terminating server") except Exception as exc: # pylint: disable=broad-except logging.exception(exc) finally: - self._terminate_server() - - def _start_server(self, workers: int, spawn_worker: Callable): - self._pool = Pool(workers) - - # process termination signals - util.register_signal_term(self._sig_handler) - - # TODO: by default we start the workers with processes, - # but we need support to run only router, without workers - self._pool.map_async(spawn_worker, list(range(1, workers + 1))) - - # blocking - with zmq.utils.win32.allow_interrupt(self._terminate_server): - self._broker.listen(self._address, self._device_comm_channel) - - def _get_comm_channel(self) -> str: - if os.name == "posix": - ipc_id = util.unique_id() - self._device_ipc = f"{ipc_id}.ipc" - return f"ipc://{ipc_id}.ipc" - - # device port is used for non-posix env - return f"tcp://127.0.0.1:{util.get_next_available_port(6666)}" + self._server_inst.stop() def _verify_function_name(self, func): if not isinstance(func, Callable): @@ -170,30 +144,3 @@ def _verify_function_name(self, func): f"{func.__name__} is a reserved function; cannot have `{func.__name__}` " "as a RPC function" ) - - def _sig_handler(self, signum, frame): # pylint: disable=unused-argument - logging.warning("%s signal called", signal.Signals(signum).name) - self._terminate_server() - - def _terminate_server(self): - logging.warning("Terminating server at %d", self._port) - if self._broker is not None: - self._broker.close() - self._terminate_pool() - self._remove_ipc() - sys.exit(0) - - @util.log_error - def _remove_ipc(self): - if ( - os.name == "posix" - and self._device_ipc is not None - and os.path.exists(self._device_ipc) - ): - os.remove(self._device_ipc) - - @util.log_error - def _terminate_pool(self): - self._pool.terminate() - self._pool.close() - self._pool.join()