From fa78d8210cdce1beb5a5553d20abc16e7744e555 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Fri, 25 Oct 2024 13:26:52 +0200 Subject: [PATCH] Rewrite sync Assembler to improve performance. Previously, a latch was used to synchronize the user thread reading messages and the background thread reading from the network. This required two thread switches per message. Now, the background thread writes messages to queue, from which the user thread reads. This allows passing several frames at each thread switch, reducing the overhead. With this server code:: async def test(websocket): for i in range(int(await websocket.recv())): await websocket.send(f"{{\"iteration\": {i}}}") and this client code:: with connect("ws://localhost:8765", compression=None) as websocket: websocket.send("1_000_000") for message in websocket: pass an unscientific benchmark (running it on my laptop) shows a 2.5x speedup, going from 11 seconds to 4.4 seconds. Setting a very large recv_bufsize and max_size doesn't yield significant further improvement. The new implementation mirrors the asyncio implementation and gains the option to prevent or force decoding of frames. Refs #1376. --- docs/project/changelog.rst | 16 +- src/websockets/asyncio/client.py | 2 +- src/websockets/asyncio/messages.py | 60 +++-- src/websockets/asyncio/server.py | 2 +- src/websockets/sync/client.py | 12 +- src/websockets/sync/connection.py | 79 ++++-- src/websockets/sync/messages.py | 333 ++++++++++++------------ src/websockets/sync/server.py | 12 +- tests/asyncio/test_connection.py | 19 +- tests/asyncio/test_messages.py | 12 +- tests/sync/test_connection.py | 106 +++++--- tests/sync/test_messages.py | 400 ++++++++++++++--------------- 12 files changed, 586 insertions(+), 467 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f5b4812b..4e5cdc24 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -70,10 +70,21 @@ Backwards-incompatible changes If you wrote an :class:`extension ` that relies on methods not provided by these new types, you may need to update your code. +New features +............ + +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`threading` implementation; also binary frames as :class:`str`. + +* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio` + and :mod:`threading` implementations, as well as :class:`str` a binary frame. + Improvements ............ -* Sending or receiving large compressed frames is now faster. +* Sending or receiving large compressed messages is now faster. + +* The :mod:`threading` implementation receives messages faster. .. _13.1: @@ -198,6 +209,9 @@ New features * Validated compatibility with Python 3.12 and 3.13. +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`asyncio` implementation; also binary frames as :class:`str`. + * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 23b1a348..0c8bedc5 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -45,7 +45,7 @@ class ClientConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`connect`. + and ``write_limit`` arguments have the same meaning as in :func:`connect`. Args: protocol: Sans-I/O connection. diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index e3ec5062..d0738b55 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -60,6 +60,7 @@ def reset(self, items: Iterable[T]) -> None: self.queue.extend(items) def abort(self) -> None: + """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) # Clear the queue to avoid storing unnecessary data in memory. @@ -89,7 +90,7 @@ def __init__( # pragma: no cover pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: - # Queue of incoming messages. Each item is a queue of frames. + # Queue of incoming frames. self.frames: SimpleQueue[Frame] = SimpleQueue() # We cannot put a hard limit on the size of the queue because a single @@ -140,36 +141,35 @@ async def get(self, decode: bool | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # First frame + # Locking with get_in_progress prevents concurrent execution until + # get() fetches a complete message or is cancelled. + try: + # First frame frame = await self.frames.get() - except asyncio.CancelledError: - self.get_in_progress = False - raise - self.maybe_resume() - assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY - if decode is None: - decode = frame.opcode is OP_TEXT - frames = [frame] - - # Following frames, for fragmented messages - while not frame.fin: - try: - frame = await self.frames.get() - except asyncio.CancelledError: - # Put frames already received back into the queue - # so that future calls to get() can return them. - self.frames.reset(frames) - self.get_in_progress = False - raise self.maybe_resume() - assert frame.opcode is OP_CONT - frames.append(frame) - - self.get_in_progress = False + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get() + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False data = b"".join(frame.data for frame in frames) if decode: @@ -207,9 +207,14 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + # First frame try: frame = await self.frames.get() @@ -233,6 +238,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: # here will leave the assembler in a stuck state. Future calls to # get() or get_iter() will raise ConcurrencyError. frame = await self.frames.get() + self.maybe_resume() assert frame.opcode is OP_CONT if decode: diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index e11dd91f..a6ae5996 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -55,7 +55,7 @@ class ServerConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`serve`. + and ``write_limit`` arguments have the same meaning as in :func:`serve`. Args: protocol: Sans-I/O connection. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 5e1ba6d8..42daa32e 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -40,10 +40,12 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`connect`. + Args: socket: Socket connected to a WebSocket server. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. """ @@ -53,6 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -60,6 +63,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) def handshake( @@ -135,6 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -183,6 +188,10 @@ def connect( :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -287,6 +296,7 @@ def connect( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: if sock is not None: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 3f4cac09..3ab9f493 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,10 +49,14 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout + if isinstance(max_queue, int): + max_queue = (max_queue, None) + self.max_queue = max_queue # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -76,8 +80,15 @@ def __init__( # Mutex serializing interactions with the protocol. self.protocol_mutex = threading.Lock() + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = threading.Lock() + # Assembler turning frames into messages and serializing reads. - self.recv_messages = Assembler() + self.recv_messages = Assembler( + *self.max_queue, + pause=self.recv_flow_control.acquire, + resume=self.recv_flow_control.release, + ) # Whether we are busy sending a fragmented message. self.send_in_progress = False @@ -88,6 +99,10 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: dict[bytes, threading.Event] = {} + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + # Receiving events from the socket. This thread is marked as daemon to # allow creating a connection in a non-daemon thread and using it in a # daemon thread. This mustn't prevent the interpreter from exiting. @@ -97,10 +112,6 @@ def __init__( ) self.recv_events_thread.start() - # Exception raised in recv_events, to be chained to ConnectionClosed - # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - # Public attributes @property @@ -172,7 +183,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: float | None = None) -> Data: + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. @@ -191,6 +202,11 @@ def recv(self, timeout: float | None = None) -> Data: If the message is fragmented, wait until all fragments are received, reassemble them, and return the whole message. + Args: + timeout: Timeout for receiving a message in seconds. + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. @@ -198,6 +214,16 @@ def recv(self, timeout: float | None = None) -> Data: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -205,7 +231,7 @@ def recv(self, timeout: float | None = None) -> Data: """ try: - return self.recv_messages.get(timeout) + return self.recv_messages.get(timeout, decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -216,16 +242,23 @@ def recv(self, timeout: float | None = None) -> Data: "is already running recv or recv_streaming" ) from None - def recv_streaming(self) -> Iterator[Data]: + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. - If the message is fragmented, yield each fragment as it is received. - The iterator must be fully consumed, or else the connection will become + This method is designed for receiving fragmented messages. It returns an + iterator that yields each fragment as it is received. This iterator must + be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection unusable. :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. @@ -233,6 +266,15 @@ def recv_streaming(self) -> Iterator[Data]: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -240,7 +282,7 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - yield from self.recv_messages.get_iter() + yield from self.recv_messages.get_iter(decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -571,8 +613,9 @@ def recv_events(self) -> None: try: while True: try: - if self.close_deadline is not None: - self.socket.settimeout(self.close_deadline.timeout()) + with self.recv_flow_control: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: @@ -622,13 +665,9 @@ def recv_events(self) -> None: # Given that automatic responses write small amounts of data, # this should be uncommon, so we don't handle the edge case. - try: - for event in events: - # This may raise EOFError if the closing handshake - # times out while a message is waiting to be read. - self.process_event(event) - except EOFError: - break + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 997fa98d..983b114d 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,12 +3,12 @@ import codecs import queue import threading -from collections.abc import Iterator -from typing import cast +from typing import Any, Callable, Iterable, Iterator from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data +from .utils import Deadline __all__ = ["Assembler"] @@ -20,47 +20,83 @@ class Assembler: """ Assemble messages from frames. + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + """ - def __init__(self) -> None: + def __init__( + self, + high: int = 16, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: # Serialize reads and writes -- except for reads via synchronization # primitives provided by the threading and queue modules. self.mutex = threading.Lock() - # We create a latch with two events to synchronize the production of - # frames and the consumption of messages (or frames) without a buffer. - # This design requires a switch between the library thread and the user - # thread for each message; that shouldn't be a performance bottleneck. - - # put() sets this event to tell get() that a message can be fetched. - self.message_complete = threading.Event() - # get() sets this event to let put() that the message was fetched. - self.message_fetched = threading.Event() + # Queue of incoming frames. + self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if low is None: + low = high // 4 + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False - # This flag prevents concurrent calls to put() by library code. - self.put_in_progress = False - - # Decoder for text frames, None for binary frames. - self.decoder: codecs.IncrementalDecoder | None = None - - # Buffer of frames belonging to the same message. - self.chunks: list[Data] = [] - - # When switching from "buffering" to "streaming", we use a thread-safe - # queue for transferring frames from the writing thread (library code) - # to the reading thread (user code). We're buffering when chunks_queue - # is None and streaming when it's a SimpleQueue. None is a sentinel - # value marking the end of the message, superseding message_complete. - - # Stream data from frames belonging to the same message. - self.chunks_queue: queue.SimpleQueue[Data | None] | None = None # This flag marks the end of the connection. self.closed = False - def get(self, timeout: float | None = None) -> Data: + def get_next_frame(self, timeout: float | None = None) -> Frame: + # Helper to factor out the logic for getting the next frame from the + # queue, while handling timeouts and reaching the end of the stream. + try: + frame = self.frames.get(timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if frame is None: + raise EOFError("stream of frames ended") + return frame + + def reset_queue(self, frames: Iterable[Frame]) -> None: + # Helper to put frames back into the queue after they were fetched. + # This happens only when the queue is empty. However, by the time + # we acquire self.mutex, put() may have added items in the queue. + # Therefore, we must handle the case where the queue is not empty. + frame: Frame | None + with self.mutex: + queued = [] + try: + while True: + queued.append(self.frames.get_nowait()) + except queue.Empty: + pass + for frame in frames: + self.frames.put(frame) + # This loop runs only when a race condition occurs. + for frame in queued: # pragma: no cover + self.frames.put(frame) + + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Read the next message. @@ -73,11 +109,14 @@ def get(self, timeout: float | None = None) -> Data: Args: timeout: If a timeout is provided and elapses before a complete message is received, :meth:`get` raises :exc:`TimeoutError`. + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a complete message is received. @@ -89,40 +128,45 @@ def get(self, timeout: float | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") + # Locking with get_in_progress ensures only one thread can get here. self.get_in_progress = True - # If the message_complete event isn't set yet, release the lock to - # allow put() to run and eventually set it. - # Locking with get_in_progress ensures only one thread can get here. - completed = self.message_complete.wait(timeout) + try: + deadline = Deadline(timeout) + + # First frame + frame = self.get_next_frame(deadline.timeout()) + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = self.get_next_frame(deadline.timeout()) + except TimeoutError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.reset_queue(frames) + raise + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) - with self.mutex: + finally: self.get_in_progress = False - # Waiting for a complete message timed out. - if not completed: - raise TimeoutError(f"timed out in {timeout:.1f}s") - - # get() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - joiner: Data = b"" if self.decoder is None else "" - # mypy cannot figure out that chunks have the proper type. - message: Data = joiner.join(self.chunks) # type: ignore + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data - self.chunks = [] - assert self.chunks_queue is None - - assert not self.message_fetched.is_set() - self.message_fetched.set() - - return message - - def get_iter(self) -> Iterator[Data]: + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message. @@ -135,10 +179,15 @@ def get_iter(self) -> Iterator[Data]: This method only makes sense for fragmented messages. If messages aren't fragmented, use :meth:`get` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. """ with self.mutex: @@ -148,116 +197,81 @@ def get_iter(self) -> Iterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - chunks = self.chunks - self.chunks = [] - self.chunks_queue = cast( - # Remove quotes around type when dropping Python < 3.10. - "queue.SimpleQueue[Data | None]", - queue.SimpleQueue(), - ) - - # Sending None in chunk_queue supersedes setting message_complete - # when switching to "streaming". If message is already complete - # when the switch happens, put() didn't send None, so we have to. - if self.message_complete.is_set(): - self.chunks_queue.put(None) - + # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # Locking with get_in_progress ensures only one thread can get here. - chunk: Data | None - for chunk in chunks: - yield chunk - while (chunk := self.chunks_queue.get()) is not None: - yield chunk + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. - with self.mutex: - self.get_in_progress = False + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. - # get_iter() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - assert self.chunks == [] - self.chunks_queue = None + # First frame + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data - assert not self.message_fetched.is_set() - self.message_fetched.set() + self.get_in_progress = False def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. - When ``frame`` is the final frame in a message, :meth:`put` waits until - the message is fetched, which can be achieved by calling :meth:`get` or - by fully consuming the return value of :meth:`get_iter`. - - :meth:`put` assumes that the stream of frames respects the protocol. If - it doesn't, the behavior is undefined. - Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`put` concurrently. """ with self.mutex: if self.closed: raise EOFError("stream of frames ended") - if self.put_in_progress: - raise ConcurrencyError("put is already running") - - if frame.opcode is OP_TEXT: - self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is OP_BINARY: - self.decoder = None - else: - assert frame.opcode is OP_CONT - - data: Data - if self.decoder is not None: - data = self.decoder.decode(frame.data, frame.fin) - else: - data = frame.data - - if self.chunks_queue is None: - self.chunks.append(data) - else: - self.chunks_queue.put(data) - - if not frame.fin: - return - - # Message is complete. Wait until it's fetched to return. - - assert not self.message_complete.is_set() - self.message_complete.set() - - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - assert not self.message_fetched.is_set() - - self.put_in_progress = True - - # Release the lock to allow get() to run and eventually set the event. - # Locking with put_in_progress ensures only one coroutine can get here. - self.message_fetched.wait() - - with self.mutex: - self.put_in_progress = False - - # put() was unblocked by close() rather than get() or get_iter(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_fetched.is_set() - self.message_fetched.clear() - - self.decoder = None + self.frames.put(frame) + self.maybe_pause() + + # put() and get/get_iter() call maybe_pause() and maybe_resume() while + # holding self.mutex. This guarantees that the calls interleave properly. + # Specifically, it prevents a race condition where maybe_resume() would + # run before maybe_pause(), leaving the connection incorrectly paused. + + # A race condition is possible when get/get_iter() call self.frames.get() + # without holding self.mutex. However, it's harmless — and even beneficial! + # It can only result in popping an item from the queue before maybe_resume() + # runs and skipping a pause() - resume() cycle that would otherwise occur. + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + assert self.mutex.locked() + # Check for "> high" to support high = 0 + if self.frames.qsize() > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + assert self.mutex.locked() + # Check for "<= low" to support low = 0 + if self.frames.qsize() <= self.low and self.paused: + self.paused = False + self.resume() def close(self) -> None: """ @@ -273,12 +287,5 @@ def close(self) -> None: self.closed = True - # Unblock get or get_iter. - if self.get_in_progress: - self.message_complete.set() - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - # Unblock put(). - if self.put_in_progress: - self.message_fetched.set() + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 464c4a17..94f76b65 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -51,10 +51,12 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`serve`. + Args: socket: Socket connected to a WebSocket client. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. """ @@ -64,6 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -71,6 +74,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) self.username: str # see basic_auth() @@ -349,6 +353,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -427,6 +432,10 @@ def handler(websocket): :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -548,6 +557,7 @@ def protocol_select_subprotocol( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: sock.close() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 563cf2b1..12e2bd5f 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,8 +793,8 @@ async def test_close_timeout_waiting_for_connection_closed(self): self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) async def test_close_does_not_wait_for_recv(self): - # The asyncio implementation has a buffer for incoming messages. Closing - # the connection discards buffered messages. This is allowed by the RFC: + # Closing the connection discards messages buffered in the assembler. + # This is allowed by the RFC: # > However, there is no guarantee that the endpoint that has already # > sent a Close frame will continue to process data. await self.remote_connection.send("😀") @@ -1075,7 +1075,10 @@ async def test_max_queue(self): async def test_max_queue_tuple(self): """max_queue parameter configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=(4, 2)) + connection = Connection( + Protocol(self.LOCAL), + max_queue=(4, 2), + ) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) @@ -1083,14 +1086,20 @@ async def test_max_queue_tuple(self): async def test_write_limit(self): """write_limit parameter configures high-water mark of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=4096) + connection = Connection( + Protocol(self.LOCAL), + write_limit=4096, + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) async def test_write_limits(self): """write_limit parameter configures high and low-water marks of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=(4096, 2048)) + connection = Connection( + Protocol(self.LOCAL), + write_limit=(4096, 2048), + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index d2cf25c9..2ff929d3 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -350,7 +350,7 @@ async def test_cancel_get_iter_before_first_frame(self): self.assertEqual(fragments, ["café"]) async def test_cancel_get_iter_after_first_frame(self): - """get cannot be canceled after reading the first frame.""" + """get_iter cannot be canceled after reading the first frame.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = asyncio.create_task(alist(self.assembler.get_iter())) @@ -429,7 +429,7 @@ async def test_get_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" @@ -437,7 +437,7 @@ async def test_get_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" @@ -445,7 +445,7 @@ async def test_get_iter_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" @@ -453,7 +453,7 @@ async def test_get_iter_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate # Test setting limits @@ -463,7 +463,7 @@ async def test_set_high_water_mark(self): self.assertEqual(assembler.high, 10) async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water mark and low-water mark.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 87333fd3..db1cc8e9 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -154,6 +154,16 @@ def test_recv_binary(self): self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual(self.connection.recv(decode=False), "😀".encode()) + + def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual(self.connection.recv(decode=True), "😀") + def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -228,6 +238,22 @@ def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) + def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual( + list(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual( + list(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + def test_recv_streaming_fragmented_text(self): """recv_streaming receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -499,28 +525,17 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_waits_for_recv(self): - # The sync implementation doesn't have a buffer for incoming messsages. - # It requires reading incoming frames until the close frame is reached. - # This behavior — close() blocks until recv() is called — is less than - # ideal and inconsistent with the asyncio implementation. + def test_close_does_not_wait_for_recv(self): + # Closing the connection discards messages buffered in the assembler. + # This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. self.remote_connection.send("😀") + self.connection.close() close_thread = threading.Thread(target=self.connection.close) close_thread.start() - # Let close() initiate the closing handshake and send a close frame. - time.sleep(MS) - self.assertTrue(close_thread.is_alive()) - - # Connection isn't closed yet. - self.connection.recv() - - # Let close() receive a close frame and finish the closing handshake. - time.sleep(MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -528,24 +543,6 @@ def test_close_waits_for_recv(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) - def test_close_timeout_waiting_for_recv(self): - self.remote_connection.send("😀") - - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - - # Let close() time out during the closing handshake. - time.sleep(3 * MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.recv() - - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") - self.assertIsInstance(exc.__cause__, TimeoutError) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -724,6 +721,45 @@ def test_pong_unsupported_type(self): with self.assertRaises(TypeError): self.connection.pong([]) + # Test parameters. + + def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) + self.assertEqual(connection.close_timeout, 42 * MS) + + def test_max_queue(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + + def test_max_queue_tuple(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + # Test attributes. def test_id(self): diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index d44b39b8..02513894 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,4 +1,6 @@ import time +import unittest +import unittest.mock from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -9,66 +11,23 @@ class AssemblerTests(ThreadTestCase): - """ - Tests in this class interact a lot with hidden synchronization mechanisms: - - - get() / get_iter() and put() must run in separate threads when a final - frame is set because put() waits for get() / get_iter() to fetch the - message before returning. - - - run_in_thread() lets its target run before yielding back control on entry, - which guarantees the intended execution order of test cases. - - - run_in_thread() waits for its target to finish running before yielding - back control on exit, which allows making assertions immediately. - - - When the main thread performs actions that let another thread progress, it - must wait before making assertions, to avoid depending on scheduling. - - """ - def setUp(self): - self.assembler = Assembler() - - def tearDown(self): - """ - Check that the assembler goes back to its default state after each test. - - This removes the need for testing various sequences. - - """ - self.assertFalse(self.assembler.mutex.locked()) - self.assertFalse(self.assembler.get_in_progress) - self.assertFalse(self.assembler.put_in_progress) - if not self.assembler.closed: - self.assertFalse(self.assembler.message_complete.is_set()) - self.assertFalse(self.assembler.message_fetched.is_set()) - self.assertIsNone(self.assembler.decoder) - self.assertEqual(self.assembler.chunks, []) - self.assertIsNone(self.assembler.chunks_queue) + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get def test_get_text_message_already_received(self): """get returns a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_binary_message_already_received(self): """get returns a binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() self.assertEqual(message, b"tea") def test_get_text_message_not_received_yet(self): @@ -99,112 +58,145 @@ def getter(): def test_get_fragmented_text_message_already_received(self): """get reassembles a fragmented a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_fragmented_binary_message_already_received(self): """get reassembles a fragmented binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = self.assembler.get() self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_being_received(self): - """get reassembles a fragmented text message that is partially received.""" + def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_being_received(self): - """get reassembles a fragmented binary message that is partially received.""" + def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_not_received_yet(self): - """get reassembles a fragmented text message when it is received.""" + def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_not_received_yet(self): - """get reassembles a fragmented binary message when it is received.""" + def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - # Test get_iter + def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + def test_get_timeout_before_first_frame(self): + """get times out before reading the first frame.""" + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_text_message_already_received(self): - """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_timeout_after_first_frame(self): + """get times out after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(fragments, ["café"]) + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_binary_message_already_received(self): - """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + # Test get_iter + def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, [b"tea"]) def test_get_iter_text_message_not_received_yet(self): @@ -212,6 +204,7 @@ def test_get_iter_text_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -225,6 +218,7 @@ def test_get_iter_binary_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -235,121 +229,112 @@ def getter(): def test_get_iter_fragmented_text_message_already_received(self): """get_iter yields a fragmented text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, ["ca", "f", "é"]) def test_get_iter_fragmented_binary_message_already_received(self): """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") - self.assertEqual(fragments, [b"t", b"e", b"a"]) + def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - self.assertEqual(fragments, ["ca", "f", "é"]) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) - - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - def test_get_iter_fragmented_text_message_not_received_yet(self): - """get_iter yields a fragmented text message when it is received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") + + def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) + def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) - self.assertEqual(fragments, ["ca", "f", "é"]) + def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) - def test_get_iter_fragmented_binary_message_not_received_yet(self): - """get_iter yields a fragmented binary message when it is received.""" - fragments = [] + iterator = self.assembler.get_iter() - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + # queue is above the low-water mark + next(iterator) + self.resume.assert_not_called() - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) + # queue is at the low-water mark + next(iterator) + self.resume.assert_called_once_with() - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - # Test timeouts + # queue is below the low-water mark + next(iterator) + self.resume.assert_called_once_with() - def test_get_with_timeout_completes(self): - """get returns a message when it is received before the timeout.""" + # Test put - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get(MS) + def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() - self.assertEqual(message, "café") + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() - def test_get_with_timeout_times_out(self): - """get raises TimeoutError when no message is received before the timeout.""" - with self.assertRaises(TimeoutError): - self.assembler.get(MS) + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() # Test termination @@ -373,18 +358,8 @@ def closer(): with self.run_in_thread(closer): with self.assertRaises(EOFError): - list(self.assembler.get_iter()) - - def test_put_fails_when_interrupted_by_close(self): - """put raises EOFError when close is called.""" - - def closer(): - time.sleep(2 * MS) - self.assembler.close() - - with self.run_in_thread(closer): - with self.assertRaises(EOFError): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_get_fails_after_close(self): """get raises EOFError after close is called.""" @@ -396,7 +371,8 @@ def test_get_iter_fails_after_close(self): """get_iter raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): - list(self.assembler.get_iter()) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_put_fails_after_close(self): """put raises EOFError after close is called.""" @@ -439,13 +415,25 @@ def test_get_iter_fails_when_get_iter_is_running(self): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - def test_put_fails_when_put_is_running(self): - """put cannot be called concurrently.""" + # Test setting limits - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + def test_set_high_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) - with self.run_in_thread(putter): - with self.assertRaises(ConcurrencyError): - self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assembler.get() # unblock other thread + def test_set_high_and_low_water_mark(self): + """high sets the high-water mark and low-water mark.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5)