From 8906ecf8f56f01aca0f5bf8f471c4bac0161c931 Mon Sep 17 00:00:00 2001 From: Azizul Haque Ananto Date: Tue, 25 Jun 2024 08:10:26 +0200 Subject: [PATCH] Add schema validation --- examples/basic/client.py | 36 ++++ examples/basic/schema.py | 9 + examples/basic/server.py | 17 +- tests/concurrency/rps_async.py | 3 +- tests/concurrency/rps_sync.py | 3 +- .../single_server/client_server_test.py | 14 +- tests/unit/test_server.py | 2 +- tests/unit/test_zero_mq_worker.py | 2 +- zero/encoder/protocols.py | 2 +- zero/error.py | 6 +- zero/protocols/zeromq/client.py | 202 ++++-------------- zero/protocols/zeromq/server.py | 2 +- zero/protocols/zeromq/worker.py | 25 ++- zero/rpc/client.py | 4 +- zero/{zero_mq => zeromq_patterns}/__init__.py | 0 zero/{zero_mq => zeromq_patterns}/factory.py | 2 +- zero/{zero_mq => zeromq_patterns}/helpers.py | 2 +- .../{zero_mq => zeromq_patterns}/protocols.py | 14 +- .../queue_device/__init__.py | 0 .../queue_device/broker.py | 0 .../queue_device/client.py | 32 +++ .../queue_device/worker.py | 25 ++- 22 files changed, 216 insertions(+), 186 deletions(-) create mode 100644 examples/basic/schema.py rename zero/{zero_mq => zeromq_patterns}/__init__.py (100%) rename zero/{zero_mq => zeromq_patterns}/factory.py (95%) rename zero/{zero_mq => zeromq_patterns}/helpers.py (90%) rename zero/{zero_mq => zeromq_patterns}/protocols.py (80%) rename zero/{zero_mq => zeromq_patterns}/queue_device/__init__.py (100%) rename zero/{zero_mq => zeromq_patterns}/queue_device/broker.py (100%) rename zero/{zero_mq => zeromq_patterns}/queue_device/client.py (79%) rename zero/{zero_mq => zeromq_patterns}/queue_device/worker.py (58%) diff --git a/examples/basic/client.py b/examples/basic/client.py index 3c28dca..6db2b43 100644 --- a/examples/basic/client.py +++ b/examples/basic/client.py @@ -5,6 +5,8 @@ from zero import AsyncZeroClient from zero.error import ZeroException +from .schema import User + zero_client = AsyncZeroClient("localhost", 5559) @@ -37,6 +39,37 @@ async def two_rets(): print(resp) +async def hello_user(): + resp = await zero_client.call( + "hello_user", + User( + name="John", + age=25, + emails=["hello@hello.com"], + ), + ) + print(resp) + + +async def hello_users(): + resp = await zero_client.call( + "hello_users", + [ + User( + name="John", + age=25, + emails=["hello@hello.com"], + ), + User( + name="Jane", + age=30, + emails=["hello@hello.com"], + ), + ], + ) + print(resp) + + if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(echo()) @@ -44,3 +77,6 @@ async def two_rets(): loop.run_until_complete(sum_list()) loop.run_until_complete(necho()) loop.run_until_complete(two_rets()) + loop.run_until_complete(hello_user()) + loop.run_until_complete(hello_users()) + loop.close() diff --git a/examples/basic/schema.py b/examples/basic/schema.py new file mode 100644 index 0000000..ecb8e66 --- /dev/null +++ b/examples/basic/schema.py @@ -0,0 +1,9 @@ +from typing import List + +import msgspec + + +class User(msgspec.Struct): + name: str + age: int + emails: List[str] diff --git a/examples/basic/server.py b/examples/basic/server.py index 659e7d5..ffc3c36 100644 --- a/examples/basic/server.py +++ b/examples/basic/server.py @@ -5,6 +5,10 @@ from zero import ZeroServer +from .schema import User + +app = ZeroServer(port=5559) + async def echo(msg: str) -> str: return msg @@ -24,12 +28,21 @@ async def sum_list(msg: typing.List[int]) -> int: return sum(msg) -async def two_rets(msg: typing.List) -> typing.Tuple[int, int]: +async def two_rets(msg: str) -> typing.Tuple[int, int]: return 1, 2 +@app.register_rpc +def hello_user(user: User) -> str: + return f"Hello {user.name}! You are {user.age} years old. Your email is {user.emails[0]}!" + + +@app.register_rpc +def hello_users(users: typing.List[User]) -> str: + return f"Hello {', '.join([user.name for user in users])}! Your emails are {', '.join([email for user in users for email in user.emails])}!" + + if __name__ == "__main__": - app = ZeroServer(port=5559) app.register_rpc(echo) app.register_rpc(hello_world) app.register_rpc(decode_jwt) diff --git a/tests/concurrency/rps_async.py b/tests/concurrency/rps_async.py index fcfad47..b8ad5e1 100644 --- a/tests/concurrency/rps_async.py +++ b/tests/concurrency/rps_async.py @@ -11,7 +11,8 @@ async def task(semaphore, items): async with semaphore: try: - res = await async_client.call("sum_async", items) + await async_client.call("sum_async", items) + # res = await async_client.call("sum_async", items) # print(res) except Exception as e: print(e) diff --git a/tests/concurrency/rps_sync.py b/tests/concurrency/rps_sync.py index 454545b..6b4f8b5 100644 --- a/tests/concurrency/rps_sync.py +++ b/tests/concurrency/rps_sync.py @@ -12,7 +12,8 @@ def get_and_sum(msg): - resp = sum_func(msg) + sum_func(msg) + # resp = sum_func(msg) # print(resp) diff --git a/tests/functional/single_server/client_server_test.py b/tests/functional/single_server/client_server_test.py index 1dbf8de..6d79246 100644 --- a/tests/functional/single_server/client_server_test.py +++ b/tests/functional/single_server/client_server_test.py @@ -5,6 +5,7 @@ import zero.error from zero import AsyncZeroClient, ZeroClient +from zero.error import ValidationException from . import server from .server import Message @@ -38,8 +39,17 @@ def test_sum_list(): def test_echo_dict(): zero_client = ZeroClient(server.HOST, server.PORT) - msg = zero_client.call("echo_dict", {"a": "b"}) - assert msg == {"a": "b"} + msg = zero_client.call("echo_dict", {1: "b"}) + assert msg == {1: "b"} + + +def test_echo_dict_validation_error(): + zero_client = ZeroClient(server.HOST, server.PORT) + with pytest.raises(ValidationException): + msg = zero_client.call("echo_dict", {"a": "b"}) + assert msg == { + "__zerror__validation_error": "Expected `int`, got `str` - at `key` in `$`" + } def test_echo_tuple(): diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index 90c96b7..f09fd9e 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -8,7 +8,7 @@ from zero import ZeroServer from zero.encoder.protocols import Encoder -from zero.zero_mq.protocols import ZeroMQBroker +from zero.zeromq_patterns.protocols import ZeroMQBroker DEFAULT_PORT = 5559 DEFAULT_HOST = "0.0.0.0" diff --git a/tests/unit/test_zero_mq_worker.py b/tests/unit/test_zero_mq_worker.py index 0a7523b..b5b0af8 100644 --- a/tests/unit/test_zero_mq_worker.py +++ b/tests/unit/test_zero_mq_worker.py @@ -4,7 +4,7 @@ import pytest import zmq -from zero.zero_mq.queue_device.worker import ZeroMQWorker +from zero.zeromq_patterns.queue_device.worker import ZeroMQWorker class TestWorker(unittest.TestCase): diff --git a/zero/encoder/protocols.py b/zero/encoder/protocols.py index b602ee9..59c2b20 100644 --- a/zero/encoder/protocols.py +++ b/zero/encoder/protocols.py @@ -2,7 +2,7 @@ @runtime_checkable -class Encoder(Protocol): +class Encoder(Protocol): # pragma: no cover def encode(self, data: Any) -> bytes: ... diff --git a/zero/error.py b/zero/error.py index 74f949b..5b5b2c7 100644 --- a/zero/error.py +++ b/zero/error.py @@ -19,5 +19,9 @@ class ConnectionException(ZeroException): pass -class RemoteException(Exception): +class RemoteException(ZeroException): + pass + + +class ValidationException(ZeroException): pass diff --git a/zero/protocols/zeromq/client.py b/zero/protocols/zeromq/client.py index 1964b96..a78598d 100644 --- a/zero/protocols/zeromq/client.py +++ b/zero/protocols/zeromq/client.py @@ -1,13 +1,18 @@ import asyncio import logging import threading -from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Optional, Type, TypeVar, Union from zero import config from zero.encoder import Encoder, get_encoder from zero.error import TimeoutException from zero.utils import util -from zero.zero_mq import AsyncZeroMQClient, ZeroMQClient, get_async_client, get_client +from zero.zeromq_patterns import ( + AsyncZeroMQClient, + ZeroMQClient, + get_async_client, + get_client, +) T = TypeVar("T") @@ -19,35 +24,6 @@ def __init__( default_timeout: int = 2000, encoder: Optional[Encoder] = None, ): - """ - ZeroClient provides the client interface for calling the ZeroServer. - - Zero use tcp protocol for communication. - So a connection needs to be established to make a call. - The connection creation is done lazily. - So the first call will take some time to establish the connection. - If the connection is dropped the client might timeout. - But in the next call the connection will be re-established. - - For different threads/processes, different connections are created. - - Parameters - ---------- - host: str - Host of the ZeroServer. - - port: int - Port of the ZeroServer. - - default_timeout: int - Default timeout for all calls. Default is 2000 ms. - - encoder: Optional[Encoder] - Encoder to encode/decode messages from/to client. - Default is msgspec. - If any other encoder is used, make sure the server should use the same encoder. - Implement custom encoder by inheriting from `zero.encoder.Encoder`. - """ self._address = address self._default_timeout = default_timeout self._encoder = encoder or get_encoder(config.ENCODER) @@ -65,47 +41,6 @@ def call( timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> T: - """ - Call the rpc function resides on the ZeroServer. - - Parameters - ---------- - rpc_func_name: str - Function name should be string. - This funtion should reside on the ZeroServer to get a successful response. - - msg: Union[int, float, str, dict, list, tuple, None] - The only argument of the rpc function. - This should be of the same type as the rpc function's argument. - - timeout: Optional[int] - Timeout for the call. In milliseconds. - Default is 2000 milliseconds. - - return_type: Optional[Type[T]] - The return type of the rpc function. - If return_type is set, the response will be parsed to the return_type. - - Returns - ------- - T - The return value of the rpc function. - If return_type is set, the response will be parsed to the return_type. - - Raises - ------ - TimeoutException - If the call times out or the connection is dropped. - - MethodNotFoundException - If the rpc function is not found on the ZeroServer. - - ConnectionException - If zeromq connection is not established. - Or zeromq cannot send the message to the server. - Or zeromq cannot receive the response from the server. - Mainly represents zmq.error.Again exception. - """ zmqc = self.client_pool.get() _timeout = self._default_timeout if timeout is None else timeout @@ -117,16 +52,28 @@ def _poll_data(): f"Timeout while sending message at {self._address}" ) - resp_id, resp_data = ( - self._encoder.decode(zmqc.recv()) + rcv_data = zmqc.recv() + + # first 32 bytes as response id + resp_id = rcv_data[:32].decode() + + # the rest is response data + resp_data_encoded = rcv_data[32:] + resp_data = ( + self._encoder.decode(resp_data_encoded) if return_type is None - else self._encoder.decode_type(zmqc.recv(), Tuple[str, return_type]) + else self._encoder.decode_type(resp_data_encoded, return_type) ) + return resp_id, resp_data req_id = util.unique_id() - frames = [req_id, rpc_func_name, "" if msg is None else msg] - zmqc.send(self._encoder.encode(frames)) + + # function name exactly 120 bytes + func_name_bytes = rpc_func_name.ljust(120).encode() + + msg_bytes = b"" if msg is None else self._encoder.encode(msg) + zmqc.send(req_id.encode() + func_name_bytes + msg_bytes) resp_id, resp_data = None, None # as the client is synchronous, we know that the response will be available any next poll @@ -149,37 +96,6 @@ def __init__( default_timeout: int = 2000, encoder: Optional[Encoder] = None, ): - """ - AsyncZeroClient provides the asynchronous client interface for calling the ZeroServer. - Python's async/await can be used to make the calls. - Naturally async client is faster. - - Zero use tcp protocol for communication. - So a connection needs to be established to make a call. - The connection creation is done lazily. - So the first call will take some time to establish the connection. - If the connection is dropped the client might timeout. - But in the next call the connection will be re-established. - - For different threads/processes, different connections are created. - - Parameters - ---------- - host: str - Host of the ZeroServer. - - port: int - Port of the ZeroServer. - - default_timeout: int - Default timeout for all calls. Default is 2000 ms. - - encoder: Optional[Encoder] - Encoder to encode/decode messages from/to client. - Default is msgspec. - If any other encoder is used, the server should use the same encoder. - Implement custom encoder by inheriting from `zero.encoder.Encoder`. - """ self._address = address self._default_timeout = default_timeout self._encoder = encoder or get_encoder(config.ENCODER) @@ -198,47 +114,6 @@ async def call( timeout: Optional[int] = None, return_type: Optional[Type[T]] = None, ) -> T: - """ - Call the rpc function resides on the ZeroServer. - - Parameters - ---------- - rpc_func_name: str - Function name should be string. - This funtion should reside on the ZeroServer to get a successful response. - - msg: Union[int, float, str, dict, list, tuple, None] - The only argument of the rpc function. - This should be of the same type as the rpc function's argument. - - timeout: Optional[int] - Timeout for the call. In milliseconds. - Default is 2000 milliseconds. - - return_type: Optional[Type[T]] - The return type of the rpc function. - If return_type is set, the response will be parsed to the return_type. - - Returns - ------- - T - The return value of the rpc function. - If return_type is set, the response will be parsed to the return_type. - - Raises - ------ - TimeoutException - If the call times out or the connection is dropped. - - MethodNotFoundException - If the rpc function is not found on the ZeroServer. - - ConnectionException - If zeromq connection is not established. - Or zeromq cannot send the message to the server. - Or zeromq cannot receive the response from the server. - Mainly represents zmq.error.Again exception. - """ zmqc = await self.client_pool.get() _timeout = self._default_timeout if timeout is None else timeout @@ -249,11 +124,16 @@ async def _poll_data(): # if not await zmqc.poll(_timeout): # raise TimeoutException(f"Timeout while sending message at {self._address}") + # first 32 bytes as response id resp = await zmqc.recv() - resp_id, resp_data = ( - self._encoder.decode(resp) + resp_id = resp[:32].decode() + + # the rest is response data + resp_data_encoded = resp[32:] + resp_data = ( + self._encoder.decode(resp_data_encoded) if return_type is None - else self._encoder.decode_type(resp, Tuple[str, return_type]) + else self._encoder.decode_type(resp_data_encoded, return_type) ) self._resp_map[resp_id] = resp_data @@ -261,8 +141,12 @@ async def _poll_data(): # await self.peer1.send(b"") req_id = util.unique_id() - frames = [req_id, rpc_func_name, "" if msg is None else msg] - await zmqc.send(self._encoder.encode(frames)) + + # function name exactly 120 bytes + func_name_bytes = rpc_func_name.ljust(120).encode() + + msg_bytes = b"" if msg is None else self._encoder.encode(msg) + await zmqc.send(req_id.encode() + func_name_bytes + msg_bytes) # every request poll the data, so whenever a response comes, it will be stored in __resps # dont need to poll again in the while loop @@ -316,9 +200,8 @@ def get(self) -> ZeroMQClient: return self._pool[thread_id] def _try_connect_ping(self, client: ZeroMQClient): - frames = [util.unique_id(), "connect", ""] - client.send(self._encoder.encode(frames)) - self._encoder.decode(client.recv()) + client.send(util.unique_id().encode() + b"connect" + b"") + client.recv() logging.info("Connected to server at %s", self._address) def close(self): @@ -357,9 +240,8 @@ async def get(self) -> AsyncZeroMQClient: return self._pool[thread_id] async def _try_connect_ping(self, client: AsyncZeroMQClient): - frames = [util.unique_id(), "connect", ""] - await client.send(self._encoder.encode(frames)) - self._encoder.decode(await client.recv()) + await client.send(util.unique_id().encode() + b"connect" + b"") + await client.recv() logging.info("Connected to server at %s", self._address) def close(self): diff --git a/zero/protocols/zeromq/server.py b/zero/protocols/zeromq/server.py index 41a01f0..766a506 100644 --- a/zero/protocols/zeromq/server.py +++ b/zero/protocols/zeromq/server.py @@ -11,7 +11,7 @@ from zero import config from zero.encoder import Encoder from zero.utils import util -from zero.zero_mq import ZeroMQBroker, get_broker +from zero.zeromq_patterns import ZeroMQBroker, get_broker from .worker import _Worker diff --git a/zero/protocols/zeromq/worker.py b/zero/protocols/zeromq/worker.py index feec284..4fb6139 100644 --- a/zero/protocols/zeromq/worker.py +++ b/zero/protocols/zeromq/worker.py @@ -3,12 +3,14 @@ import time from typing import Optional +from msgspec import ValidationError + from zero import config from zero.codegen.codegen import CodeGen from zero.encoder.protocols import Encoder from zero.error import SERVER_PROCESSING_ERROR from zero.utils.async_to_sync import async_to_sync -from zero.zero_mq.factory import get_worker +from zero.zeromq_patterns.factory import get_worker class _Worker: @@ -35,18 +37,29 @@ def __init__( ) def start_dealer_worker(self, worker_id): - def process_message(data: bytes) -> Optional[bytes]: + def process_message(func_name_encoded: bytes, data: bytes) -> Optional[bytes]: try: - decoded = self._encoder.decode(data) - req_id, func_name, msg = decoded + func_name = func_name_encoded.decode() + input_type = self._rpc_input_type_map.get(func_name) + + msg = "" + if data: + if input_type: + msg = self._encoder.decode_type(data, input_type) + else: + msg = self._encoder.decode(data) + response = self.handle_msg(func_name, msg) - return self._encoder.encode([req_id, response]) + return self._encoder.encode(response) + except ValidationError as exc: + logging.exception(exc) + return self._encoder.encode({"__zerror__validation_error": str(exc)}) except ( Exception ) as inner_exc: # pragma: no cover pylint: disable=broad-except logging.exception(inner_exc) return self._encoder.encode( - ["", {"__zerror__server_exception": SERVER_PROCESSING_ERROR}] + {"__zerror__server_exception": SERVER_PROCESSING_ERROR} ) worker = get_worker(config.ZEROMQ_PATTERN, worker_id) diff --git a/zero/rpc/client.py b/zero/rpc/client.py index ea88062..77d54e8 100644 --- a/zero/rpc/client.py +++ b/zero/rpc/client.py @@ -2,7 +2,7 @@ from zero import config from zero.encoder import Encoder, get_encoder -from zero.error import MethodNotFoundException, RemoteException +from zero.error import MethodNotFoundException, RemoteException, ValidationException if TYPE_CHECKING: from zero.rpc.protocols import AsyncZeroClientProtocol, ZeroClientProtocol @@ -260,3 +260,5 @@ def check_response(resp_data): raise MethodNotFoundException(exc) if exc := resp_data.get("__zerror__server_exception"): raise RemoteException(exc) + if exc := resp_data.get("__zerror__validation_error"): + raise ValidationException(exc) diff --git a/zero/zero_mq/__init__.py b/zero/zeromq_patterns/__init__.py similarity index 100% rename from zero/zero_mq/__init__.py rename to zero/zeromq_patterns/__init__.py diff --git a/zero/zero_mq/factory.py b/zero/zeromq_patterns/factory.py similarity index 95% rename from zero/zero_mq/factory.py rename to zero/zeromq_patterns/factory.py index d1da02c..f59d02c 100644 --- a/zero/zero_mq/factory.py +++ b/zero/zeromq_patterns/factory.py @@ -1,4 +1,4 @@ -from zero.zero_mq import queue_device +from zero.zeromq_patterns import queue_device from .protocols import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker diff --git a/zero/zero_mq/helpers.py b/zero/zeromq_patterns/helpers.py similarity index 90% rename from zero/zero_mq/helpers.py rename to zero/zeromq_patterns/helpers.py index 20fb695..404c66b 100644 --- a/zero/zero_mq/helpers.py +++ b/zero/zeromq_patterns/helpers.py @@ -8,7 +8,7 @@ def zpipe_async( ctx: zmq.asyncio.Context, timeout: int = 1000 -) -> Tuple[zmq.asyncio.Socket, zmq.asyncio.Socket]: +) -> Tuple[zmq.asyncio.Socket, zmq.asyncio.Socket]: # pragma: no cover """ Build inproc pipe for talking to threads diff --git a/zero/zero_mq/protocols.py b/zero/zeromq_patterns/protocols.py similarity index 80% rename from zero/zero_mq/protocols.py rename to zero/zeromq_patterns/protocols.py index 92e9a7b..3af0d1c 100644 --- a/zero/zero_mq/protocols.py +++ b/zero/zeromq_patterns/protocols.py @@ -19,12 +19,18 @@ def close(self) -> None: def send(self, message: bytes) -> None: ... + def send_multipart(self, message: list) -> None: + ... + def poll(self, timeout: int) -> bool: ... def recv(self) -> bytes: ... + def recv_multipart(self) -> list: + ... + def request(self, message: bytes) -> Any: ... @@ -44,12 +50,18 @@ def close(self) -> None: async def send(self, message: bytes) -> None: ... + async def send_multipart(self, message: list) -> None: + ... + async def poll(self, timeout: int) -> bool: ... async def recv(self) -> bytes: ... + async def recv_multipart(self) -> list: + ... + async def request(self, message: bytes) -> Any: ... @@ -66,7 +78,7 @@ def close(self) -> None: @runtime_checkable class ZeroMQWorker(Protocol): # pragma: no cover def listen( - self, address: str, msg_handler: Callable[[bytes], Optional[bytes]] + self, address: str, msg_handler: Callable[[bytes, bytes], Optional[bytes]] ) -> None: ... diff --git a/zero/zero_mq/queue_device/__init__.py b/zero/zeromq_patterns/queue_device/__init__.py similarity index 100% rename from zero/zero_mq/queue_device/__init__.py rename to zero/zeromq_patterns/queue_device/__init__.py diff --git a/zero/zero_mq/queue_device/broker.py b/zero/zeromq_patterns/queue_device/broker.py similarity index 100% rename from zero/zero_mq/queue_device/broker.py rename to zero/zeromq_patterns/queue_device/broker.py diff --git a/zero/zero_mq/queue_device/client.py b/zero/zeromq_patterns/queue_device/client.py similarity index 79% rename from zero/zero_mq/queue_device/client.py rename to zero/zeromq_patterns/queue_device/client.py index c768651..997ce13 100644 --- a/zero/zero_mq/queue_device/client.py +++ b/zero/zeromq_patterns/queue_device/client.py @@ -42,6 +42,14 @@ def send(self, message: bytes) -> None: f"Connection error for send at {self._address}" ) from exc + def send_multipart(self, message: list) -> None: + try: + self.socket.send_multipart(message, copy=False) + except zmqerr.Again as exc: + raise ConnectionException( + f"Connection error for send at {self._address}" + ) from exc + def poll(self, timeout: int) -> bool: socks = dict(self.poller.poll(timeout)) return self.socket in socks @@ -54,6 +62,14 @@ def recv(self) -> bytes: f"Connection error for recv at {self._address}" ) from exc + def recv_multipart(self) -> list: + try: + return self.socket.recv_multipart() + except zmqerr.Again as exc: + raise ConnectionException( + f"Connection error for recv at {self._address}" + ) from exc + def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: try: self.send(message) @@ -103,6 +119,14 @@ async def send(self, message: bytes) -> None: f"Connection error for send at {self._address}" ) from exc + async def send_multipart(self, message: list) -> None: + try: + await self.socket.send_multipart(message, copy=False) + except zmqerr.Again as exc: + raise ConnectionException( + f"Connection error for send at {self._address}" + ) from exc + async def poll(self, timeout: int) -> bool: socks = dict(await self.poller.poll(timeout)) return self.socket in socks @@ -115,6 +139,14 @@ async def recv(self) -> bytes: f"Connection error for recv at {self._address}" ) from exc + async def recv_multipart(self) -> list: + try: + return await self.socket.recv_multipart() + except zmqerr.Again as exc: + raise ConnectionException( + f"Connection error for recv at {self._address}" + ) from exc + async def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: try: await self.send(message) diff --git a/zero/zero_mq/queue_device/worker.py b/zero/zeromq_patterns/queue_device/worker.py similarity index 58% rename from zero/zero_mq/queue_device/worker.py rename to zero/zeromq_patterns/queue_device/worker.py index 5377adc..9d50aed 100644 --- a/zero/zero_mq/queue_device/worker.py +++ b/zero/zeromq_patterns/queue_device/worker.py @@ -17,7 +17,7 @@ def __init__(self, worker_id: int): self.socket.setsockopt(zmq.SNDTIMEO, 2000) def listen( - self, address: str, msg_handler: Callable[[bytes], Optional[bytes]] + self, address: str, msg_handler: Callable[[bytes, bytes], Optional[bytes]] ) -> None: self.socket.connect(address) logging.info("Starting worker %d", self.worker_id) @@ -28,17 +28,32 @@ def listen( except zmq.error.Again: continue - def _recv_and_process(self, msg_handler: Callable[[bytes], Optional[bytes]]): + def _recv_and_process(self, msg_handler: Callable[[bytes, bytes], Optional[bytes]]): + # multipart because first frame is ident, set by the broker frames = self.socket.recv_multipart() if len(frames) != 2: logging.error("invalid message received: %s", frames) return - ident, message = frames - response = msg_handler(message) + # ident is set by the broker, because it is a DEALER socket + # so the broker knows who to send the response to + ident, data = frames + + # first 32 bytes is request id + req_id = data[:32] + + # then 120 bytes is function name + func_name = data[32:152].strip() + + # the rest is message + message = data[152:] + + response = msg_handler(func_name, message) # TODO send is slow, need to find a way to make it faster - self.socket.send_multipart([ident, response], zmq.NOBLOCK) + self.socket.send_multipart( + [ident, req_id + response if response else b""], copy=False + ) def close(self) -> None: self.socket.close()