diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index f5b4812b..41067123 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -70,10 +70,21 @@ Backwards-incompatible changes If you wrote an :class:`extension ` that relies on methods not provided by these new types, you may need to update your code. +New features +............ + +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`threading` implementation; also binary frames as :class:`str`. + +* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio` + and :mod:`threading` implementations, as well as :class:`str` as a binary frame. + Improvements ............ -* Sending or receiving large compressed frames is now faster. +* The :mod:`threading` implementation receives messages faster. + +* Sending or receiving large compressed messages is now faster. .. _13.1: @@ -198,6 +209,9 @@ New features * Validated compatibility with Python 3.12 and 3.13. +* Added an option to receive text frames as :class:`bytes`, without decoding, + in the :mod:`asyncio` implementation; also binary frames as :class:`str`. + * Added :doc:`environment variables <../reference/variables>` to configure debug logs, the ``Server`` and ``User-Agent`` headers, as well as security limits. diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 23b1a348..0c8bedc5 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -45,7 +45,7 @@ class ClientConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`connect`. + and ``write_limit`` arguments have the same meaning as in :func:`connect`. Args: protocol: Sans-I/O connection. 
diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index e3ec5062..09be22ba 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -60,6 +60,7 @@ def reset(self, items: Iterable[T]) -> None: self.queue.extend(items) def abort(self) -> None: + """Close the queue, raising EOFError in get() if necessary.""" if self.get_waiter is not None and not self.get_waiter.done(): self.get_waiter.set_exception(EOFError("stream of frames ended")) # Clear the queue to avoid storing unnecessary data in memory. @@ -89,7 +90,7 @@ def __init__( # pragma: no cover pause: Callable[[], Any] = lambda: None, resume: Callable[[], Any] = lambda: None, ) -> None: - # Queue of incoming messages. Each item is a queue of frames. + # Queue of incoming frames. self.frames: SimpleQueue[Frame] = SimpleQueue() # We cannot put a hard limit on the size of the queue because a single @@ -140,36 +141,35 @@ async def get(self, decode: bool | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True - # First frame + # Locking with get_in_progress prevents concurrent execution until + # get() fetches a complete message or is cancelled. + try: + # First frame frame = await self.frames.get() - except asyncio.CancelledError: - self.get_in_progress = False - raise - self.maybe_resume() - assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY - if decode is None: - decode = frame.opcode is OP_TEXT - frames = [frame] - - # Following frames, for fragmented messages - while not frame.fin: - try: - frame = await self.frames.get() - except asyncio.CancelledError: - # Put frames already received back into the queue - # so that future calls to get() can return them. 
- self.frames.reset(frames) - self.get_in_progress = False - raise self.maybe_resume() - assert frame.opcode is OP_CONT - frames.append(frame) - - self.get_in_progress = False + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.frames.get() + except asyncio.CancelledError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.frames.reset(frames) + raise + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False data = b"".join(frame.data for frame in frames) if decode: @@ -207,9 +207,14 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - # Locking with get_in_progress ensures only one coroutine can get here. self.get_in_progress = True + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + # First frame try: frame = await self.frames.get() diff --git a/src/websockets/asyncio/server.py b/src/websockets/asyncio/server.py index e11dd91f..a6ae5996 100644 --- a/src/websockets/asyncio/server.py +++ b/src/websockets/asyncio/server.py @@ -55,7 +55,7 @@ class ServerConnection(Connection): closed with any other code. The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``, - and ``write_limit`` arguments the same meaning as in :func:`serve`. + and ``write_limit`` arguments have the same meaning as in :func:`serve`. Args: protocol: Sans-I/O connection. 
diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 5e1ba6d8..42daa32e 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -40,10 +40,12 @@ class ClientConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`connect`. + Args: socket: Socket connected to a WebSocket server. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. """ @@ -53,6 +55,7 @@ def __init__( protocol: ClientProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ClientProtocol self.response_rcvd = threading.Event() @@ -60,6 +63,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) def handshake( @@ -135,6 +139,7 @@ def connect( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -183,6 +188,10 @@ def connect( :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this client. It defaults to ``logging.getLogger("websockets.client")``. See the :doc:`logging guide <../../topics/logging>` for details. 
@@ -287,6 +296,7 @@ def connect( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: if sock is not None: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 3f4cac09..3ab9f493 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -49,10 +49,14 @@ def __init__( protocol: Protocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.socket = socket self.protocol = protocol self.close_timeout = close_timeout + if isinstance(max_queue, int): + max_queue = (max_queue, None) + self.max_queue = max_queue # Inject reference to this instance in the protocol's logger. self.protocol.logger = logging.LoggerAdapter( @@ -76,8 +80,15 @@ def __init__( # Mutex serializing interactions with the protocol. self.protocol_mutex = threading.Lock() + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = threading.Lock() + # Assembler turning frames into messages and serializing reads. - self.recv_messages = Assembler() + self.recv_messages = Assembler( + *self.max_queue, + pause=self.recv_flow_control.acquire, + resume=self.recv_flow_control.release, + ) # Whether we are busy sending a fragmented message. self.send_in_progress = False @@ -88,6 +99,10 @@ def __init__( # Mapping of ping IDs to pong waiters, in chronological order. self.ping_waiters: dict[bytes, threading.Event] = {} + # Exception raised in recv_events, to be chained to ConnectionClosed + # in the user thread in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + # Receiving events from the socket. This thread is marked as daemon to # allow creating a connection in a non-daemon thread and using it in a # daemon thread. This mustn't prevent the interpreter from exiting. 
@@ -97,10 +112,6 @@ def __init__( ) self.recv_events_thread.start() - # Exception raised in recv_events, to be chained to ConnectionClosed - # in the user thread in order to show why the TCP connection dropped. - self.recv_exc: BaseException | None = None - # Public attributes @property @@ -172,7 +183,7 @@ def __iter__(self) -> Iterator[Data]: except ConnectionClosedOK: return - def recv(self, timeout: float | None = None) -> Data: + def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Receive the next message. @@ -191,6 +202,11 @@ def recv(self, timeout: float | None = None) -> Data: If the message is fragmented, wait until all fragments are received, reassemble them, and return the whole message. + Args: + timeout: Timeout for receiving a message in seconds. + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: A string (:class:`str`) for a Text_ frame or a bytestring (:class:`bytes`) for a Binary_ frame. @@ -198,6 +214,16 @@ def recv(self, timeout: float | None = None) -> Data: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. 
ConcurrencyError: If two threads call :meth:`recv` or @@ -205,7 +231,7 @@ def recv(self, timeout: float | None = None) -> Data: """ try: - return self.recv_messages.get(timeout) + return self.recv_messages.get(timeout, decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -216,16 +242,23 @@ def recv(self, timeout: float | None = None) -> Data: "is already running recv or recv_streaming" ) from None - def recv_streaming(self) -> Iterator[Data]: + def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: """ Receive the next message frame by frame. - If the message is fragmented, yield each fragment as it is received. - The iterator must be fully consumed, or else the connection will become + This method is designed for receiving fragmented messages. It returns an + iterator that yields each fragment as it is received. This iterator must + be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection unusable. :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + Returns: An iterator of strings (:class:`str`) for a Text_ frame or bytestrings (:class:`bytes`) for a Binary_ frame. @@ -233,6 +266,15 @@ def recv_streaming(self) -> Iterator[Data]: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). 
This is useful for servers + that send binary frames instead of text frames. + Raises: ConnectionClosed: When the connection is closed. ConcurrencyError: If two threads call :meth:`recv` or @@ -240,7 +282,7 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - yield from self.recv_messages.get_iter() + yield from self.recv_messages.get_iter(decode) except EOFError: # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() @@ -571,8 +613,9 @@ def recv_events(self) -> None: try: while True: try: - if self.close_deadline is not None: - self.socket.settimeout(self.close_deadline.timeout()) + with self.recv_flow_control: + if self.close_deadline is not None: + self.socket.settimeout(self.close_deadline.timeout()) data = self.socket.recv(self.recv_bufsize) except Exception as exc: if self.debug: @@ -622,13 +665,9 @@ def recv_events(self) -> None: # Given that automatic responses write small amounts of data, # this should be uncommon, so we don't handle the edge case. - try: - for event in events: - # This may raise EOFError if the closing handshake - # times out while a message is waiting to be read. - self.process_event(event) - except EOFError: - break + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. 
diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 997fa98d..983b114d 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,12 +3,12 @@ import codecs import queue import threading -from collections.abc import Iterator -from typing import cast +from typing import Any, Callable, Iterable, Iterator from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame from ..typing import Data +from .utils import Deadline __all__ = ["Assembler"] @@ -20,47 +20,83 @@ class Assembler: """ Assemble messages from frames. + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + """ - def __init__(self) -> None: + def __init__( + self, + high: int = 16, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: # Serialize reads and writes -- except for reads via synchronization # primitives provided by the threading and queue modules. self.mutex = threading.Lock() - # We create a latch with two events to synchronize the production of - # frames and the consumption of messages (or frames) without a buffer. - # This design requires a switch between the library thread and the user - # thread for each message; that shouldn't be a performance bottleneck. - - # put() sets this event to tell get() that a message can be fetched. - self.message_complete = threading.Event() - # get() sets this event to let put() that the message was fetched. - self.message_fetched = threading.Event() + # Queue of incoming frames. 
+ self.frames: queue.SimpleQueue[Frame | None] = queue.SimpleQueue() + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if low is None: + low = high // 4 + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False - # This flag prevents concurrent calls to put() by library code. - self.put_in_progress = False - - # Decoder for text frames, None for binary frames. - self.decoder: codecs.IncrementalDecoder | None = None - - # Buffer of frames belonging to the same message. - self.chunks: list[Data] = [] - - # When switching from "buffering" to "streaming", we use a thread-safe - # queue for transferring frames from the writing thread (library code) - # to the reading thread (user code). We're buffering when chunks_queue - # is None and streaming when it's a SimpleQueue. None is a sentinel - # value marking the end of the message, superseding message_complete. - - # Stream data from frames belonging to the same message. - self.chunks_queue: queue.SimpleQueue[Data | None] | None = None # This flag marks the end of the connection. self.closed = False - def get(self, timeout: float | None = None) -> Data: + def get_next_frame(self, timeout: float | None = None) -> Frame: + # Helper to factor out the logic for getting the next frame from the + # queue, while handling timeouts and reaching the end of the stream. 
+ try: + frame = self.frames.get(timeout=timeout) + except queue.Empty: + raise TimeoutError(f"timed out in {timeout:.1f}s") from None + if frame is None: + raise EOFError("stream of frames ended") + return frame + + def reset_queue(self, frames: Iterable[Frame]) -> None: + # Helper to put frames back into the queue after they were fetched. + # This happens only when the queue is empty. However, by the time + # we acquire self.mutex, put() may have added items in the queue. + # Therefore, we must handle the case where the queue is not empty. + frame: Frame | None + with self.mutex: + queued = [] + try: + while True: + queued.append(self.frames.get_nowait()) + except queue.Empty: + pass + for frame in frames: + self.frames.put(frame) + # This loop runs only when a race condition occurs. + for frame in queued: # pragma: no cover + self.frames.put(frame) + + def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: """ Read the next message. @@ -73,11 +109,14 @@ def get(self, timeout: float | None = None) -> Data: Args: timeout: If a timeout is provided and elapses before a complete message is received, :meth:`get` raises :exc:`TimeoutError`. + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two threads run :meth:`get` or + :meth:`get_iter` concurrently. TimeoutError: If a timeout is provided and elapses before a complete message is received. @@ -89,40 +128,45 @@ def get(self, timeout: float | None = None) -> Data: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") + # Locking with get_in_progress ensures only one thread can get here. 
self.get_in_progress = True - # If the message_complete event isn't set yet, release the lock to - # allow put() to run and eventually set it. - # Locking with get_in_progress ensures only one thread can get here. - completed = self.message_complete.wait(timeout) + try: + deadline = Deadline(timeout) + + # First frame + frame = self.get_next_frame(deadline.timeout()) + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = self.get_next_frame(deadline.timeout()) + except TimeoutError: + # Put frames already received back into the queue + # so that future calls to get() can return them. + self.reset_queue(frames) + raise + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) - with self.mutex: + finally: self.get_in_progress = False - # Waiting for a complete message timed out. - if not completed: - raise TimeoutError(f"timed out in {timeout:.1f}s") - - # get() was unblocked by close() rather than put(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - joiner: Data = b"" if self.decoder is None else "" - # mypy cannot figure out that chunks have the proper type. - message: Data = joiner.join(self.chunks) # type: ignore + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data - self.chunks = [] - assert self.chunks_queue is None - - assert not self.message_fetched.is_set() - self.message_fetched.set() - - return message - - def get_iter(self) -> Iterator[Data]: + def get_iter(self, decode: bool | None = None) -> Iterator[Data]: """ Stream the next message. @@ -135,10 +179,15 @@ def get_iter(self) -> Iterator[Data]: This method only makes sense for fragmented messages. 
If messages aren't fragmented, use :meth:`get` instead. + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`get` or :meth:`get_iter` - concurrently. + ConcurrencyError: If two threads run :meth:`get` or + :meth:`get_iter` concurrently. """ with self.mutex: @@ -148,116 +197,81 @@ def get_iter(self) -> Iterator[Data]: if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") - chunks = self.chunks - self.chunks = [] - self.chunks_queue = cast( - # Remove quotes around type when dropping Python < 3.10. - "queue.SimpleQueue[Data | None]", - queue.SimpleQueue(), - ) - - # Sending None in chunk_queue supersedes setting message_complete - # when switching to "streaming". If message is already complete - # when the switch happens, put() didn't send None, so we have to. - if self.message_complete.is_set(): - self.chunks_queue.put(None) - + # Locking with get_in_progress ensures only one thread can get here. self.get_in_progress = True - # Locking with get_in_progress ensures only one thread can get here. - chunk: Data | None - for chunk in chunks: - yield chunk - while (chunk := self.chunks_queue.get()) is not None: - yield chunk + # Locking with get_in_progress prevents concurrent execution until + # get_iter() fetches a complete message or is cancelled. - with self.mutex: - self.get_in_progress = False + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. - # get_iter() was unblocked by close() rather than put(). 
- if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_complete.is_set() - self.message_complete.clear() - - assert self.chunks == [] - self.chunks_queue = None + # First frame + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + frame = self.get_next_frame() + with self.mutex: + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data - assert not self.message_fetched.is_set() - self.message_fetched.set() + self.get_in_progress = False def put(self, frame: Frame) -> None: """ Add ``frame`` to the next message. - When ``frame`` is the final frame in a message, :meth:`put` waits until - the message is fetched, which can be achieved by calling :meth:`get` or - by fully consuming the return value of :meth:`get_iter`. - - :meth:`put` assumes that the stream of frames respects the protocol. If - it doesn't, the behavior is undefined. - Raises: EOFError: If the stream of frames has ended. - ConcurrencyError: If two threads run :meth:`put` concurrently. 
""" with self.mutex: if self.closed: raise EOFError("stream of frames ended") - if self.put_in_progress: - raise ConcurrencyError("put is already running") - - if frame.opcode is OP_TEXT: - self.decoder = UTF8Decoder(errors="strict") - elif frame.opcode is OP_BINARY: - self.decoder = None - else: - assert frame.opcode is OP_CONT - - data: Data - if self.decoder is not None: - data = self.decoder.decode(frame.data, frame.fin) - else: - data = frame.data - - if self.chunks_queue is None: - self.chunks.append(data) - else: - self.chunks_queue.put(data) - - if not frame.fin: - return - - # Message is complete. Wait until it's fetched to return. - - assert not self.message_complete.is_set() - self.message_complete.set() - - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - assert not self.message_fetched.is_set() - - self.put_in_progress = True - - # Release the lock to allow get() to run and eventually set the event. - # Locking with put_in_progress ensures only one coroutine can get here. - self.message_fetched.wait() - - with self.mutex: - self.put_in_progress = False - - # put() was unblocked by close() rather than get() or get_iter(). - if self.closed: - raise EOFError("stream of frames ended") - - assert self.message_fetched.is_set() - self.message_fetched.clear() - - self.decoder = None + self.frames.put(frame) + self.maybe_pause() + + # put() and get/get_iter() call maybe_pause() and maybe_resume() while + # holding self.mutex. This guarantees that the calls interleave properly. + # Specifically, it prevents a race condition where maybe_resume() would + # run before maybe_pause(), leaving the connection incorrectly paused. + + # A race condition is possible when get/get_iter() call self.frames.get() + # without holding self.mutex. However, it's harmless — and even beneficial! + # It can only result in popping an item from the queue before maybe_resume() + # runs and skipping a pause() - resume() cycle that would otherwise occur. 
+ + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + assert self.mutex.locked() + # Check for "> high" to support high = 0 + if self.frames.qsize() > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + assert self.mutex.locked() + # Check for "<= low" to support low = 0 + if self.frames.qsize() <= self.low and self.paused: + self.paused = False + self.resume() def close(self) -> None: """ @@ -273,12 +287,5 @@ def close(self) -> None: self.closed = True - # Unblock get or get_iter. - if self.get_in_progress: - self.message_complete.set() - if self.chunks_queue is not None: - self.chunks_queue.put(None) - - # Unblock put(). - if self.put_in_progress: - self.message_fetched.set() + # Unblock get() or get_iter(). + self.frames.put(None) diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index 464c4a17..94f76b65 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -51,10 +51,12 @@ class ServerConnection(Connection): :exc:`~websockets.exceptions.ConnectionClosedError` when the connection is closed with any other code. + The ``close_timeout`` and ``max_queue`` arguments have the same meaning as + in :func:`serve`. + Args: socket: Socket connected to a WebSocket client. protocol: Sans-I/O connection. - close_timeout: Timeout for closing the connection in seconds. 
""" @@ -64,6 +66,7 @@ def __init__( protocol: ServerProtocol, *, close_timeout: float | None = 10, + max_queue: int | tuple[int, int | None] = 16, ) -> None: self.protocol: ServerProtocol self.request_rcvd = threading.Event() @@ -71,6 +74,7 @@ def __init__( socket, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) self.username: str # see basic_auth() @@ -349,6 +353,7 @@ def serve( close_timeout: float | None = 10, # Limits max_size: int | None = 2**20, + max_queue: int | tuple[int, int | None] = 16, # Logging logger: LoggerLike | None = None, # Escape hatch for advanced customization @@ -427,6 +432,10 @@ def handler(websocket): :obj:`None` disables the timeout. max_size: Maximum size of incoming messages in bytes. :obj:`None` disables the limit. + max_queue: High-water mark of the buffer where frames are received. + It defaults to 16 frames. The low-water mark defaults to ``max_queue + // 4``. You may pass a ``(high, low)`` tuple to set the high-water + and low-water marks. logger: Logger for this server. It defaults to ``logging.getLogger("websockets.server")``. See the :doc:`logging guide <../../topics/logging>` for details. @@ -548,6 +557,7 @@ def protocol_select_subprotocol( sock, protocol, close_timeout=close_timeout, + max_queue=max_queue, ) except Exception: sock.close() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 563cf2b1..12e2bd5f 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -793,8 +793,8 @@ async def test_close_timeout_waiting_for_connection_closed(self): self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) async def test_close_does_not_wait_for_recv(self): - # The asyncio implementation has a buffer for incoming messages. Closing - # the connection discards buffered messages. This is allowed by the RFC: + # Closing the connection discards messages buffered in the assembler. 
+ # This is allowed by the RFC: # > However, there is no guarantee that the endpoint that has already # > sent a Close frame will continue to process data. await self.remote_connection.send("😀") @@ -1075,7 +1075,10 @@ async def test_max_queue(self): async def test_max_queue_tuple(self): """max_queue parameter configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=(4, 2)) + connection = Connection( + Protocol(self.LOCAL), + max_queue=(4, 2), + ) transport = Mock() connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) @@ -1083,14 +1086,20 @@ async def test_max_queue_tuple(self): async def test_write_limit(self): """write_limit parameter configures high-water mark of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=4096) + connection = Connection( + Protocol(self.LOCAL), + write_limit=4096, + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) async def test_write_limits(self): """write_limit parameter configures high and low-water marks of write buffer.""" - connection = Connection(Protocol(self.LOCAL), write_limit=(4096, 2048)) + connection = Connection( + Protocol(self.LOCAL), + write_limit=(4096, 2048), + ) transport = Mock() connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index d2cf25c9..2ff929d3 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -350,7 +350,7 @@ async def test_cancel_get_iter_before_first_frame(self): self.assertEqual(fragments, ["café"]) async def test_cancel_get_iter_after_first_frame(self): - """get cannot be canceled after reading the first frame.""" + """get_iter cannot be canceled after reading the first frame.""" self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) getter_task = 
asyncio.create_task(alist(self.assembler.get_iter())) @@ -429,7 +429,7 @@ async def test_get_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_fails_when_get_iter_is_running(self): """get cannot be called concurrently with get_iter.""" @@ -437,7 +437,7 @@ async def test_get_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await self.assembler.get() - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_is_running(self): """get_iter cannot be called concurrently with get.""" @@ -445,7 +445,7 @@ async def test_get_iter_fails_when_get_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate async def test_get_iter_fails_when_get_iter_is_running(self): """get_iter cannot be called concurrently.""" @@ -453,7 +453,7 @@ async def test_get_iter_fails_when_get_iter_is_running(self): await asyncio.sleep(0) with self.assertRaises(ConcurrencyError): await alist(self.assembler.get_iter()) - self.assembler.close() # let task terminate + self.assembler.put(Frame(OP_TEXT, b"")) # let task terminate # Test setting limits @@ -463,7 +463,7 @@ async def test_set_high_water_mark(self): self.assertEqual(assembler.high, 10) async def test_set_high_and_low_water_mark(self): - """high sets the high-water mark.""" + """high sets the high-water mark and low-water mark.""" assembler = Assembler(high=10, low=5) self.assertEqual(assembler.high, 10) self.assertEqual(assembler.low, 5) diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 87333fd3..db1cc8e9 
100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -154,6 +154,16 @@ def test_recv_binary(self): self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual(self.connection.recv(decode=False), "😀".encode()) + + def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual(self.connection.recv(decode=True), "😀") + def test_recv_fragmented_text(self): """recv receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -228,6 +238,22 @@ def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) + def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + self.remote_connection.send("😀") + self.assertEqual( + list(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + self.remote_connection.send("😀".encode()) + self.assertEqual( + list(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + def test_recv_streaming_fragmented_text(self): """recv_streaming receives a fragmented text message.""" self.remote_connection.send(["😀", "😀"]) @@ -499,28 +525,17 @@ def test_close_timeout_waiting_for_connection_closed(self): # Remove socket.timeout when dropping Python < 3.10. self.assertIsInstance(exc.__cause__, (socket.timeout, TimeoutError)) - def test_close_waits_for_recv(self): - # The sync implementation doesn't have a buffer for incoming messsages. - # It requires reading incoming frames until the close frame is reached. - # This behavior — close() blocks until recv() is called — is less than - # ideal and inconsistent with the asyncio implementation. 
+ def test_close_does_not_wait_for_recv(self): + # Closing the connection discards messages buffered in the assembler. + # This is allowed by the RFC: + # > However, there is no guarantee that the endpoint that has already + # > sent a Close frame will continue to process data. self.remote_connection.send("😀") + self.connection.close() close_thread = threading.Thread(target=self.connection.close) close_thread.start() - # Let close() initiate the closing handshake and send a close frame. - time.sleep(MS) - self.assertTrue(close_thread.is_alive()) - - # Connection isn't closed yet. - self.connection.recv() - - # Let close() receive a close frame and finish the closing handshake. - time.sleep(MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -528,24 +543,6 @@ def test_close_waits_for_recv(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) - def test_close_timeout_waiting_for_recv(self): - self.remote_connection.send("😀") - - close_thread = threading.Thread(target=self.connection.close) - close_thread.start() - - # Let close() time out during the closing handshake. - time.sleep(3 * MS) - self.assertFalse(close_thread.is_alive()) - - # Connection is closed now. - with self.assertRaises(ConnectionClosedError) as raised: - self.connection.recv() - - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") - self.assertIsInstance(exc.__cause__, TimeoutError) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -724,6 +721,45 @@ def test_pong_unsupported_type(self): with self.assertRaises(TypeError): self.connection.pong([]) + # Test parameters. 
+ + def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) + self.assertEqual(connection.close_timeout, 42 * MS) + + def test_max_queue(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + + def test_max_queue_tuple(self): + """max_queue parameter configures high-water mark of frames buffer.""" + socket_, remote_socket = socket.socketpair() + self.addCleanup(socket_.close) + self.addCleanup(remote_socket.close) + connection = Connection( + socket_, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + # Test attributes. def test_id(self): diff --git a/tests/sync/test_messages.py b/tests/sync/test_messages.py index d44b39b8..02513894 100644 --- a/tests/sync/test_messages.py +++ b/tests/sync/test_messages.py @@ -1,4 +1,6 @@ import time +import unittest +import unittest.mock from websockets.exceptions import ConcurrencyError from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -9,66 +11,23 @@ class AssemblerTests(ThreadTestCase): - """ - Tests in this class interact a lot with hidden synchronization mechanisms: - - - get() / get_iter() and put() must run in separate threads when a final - frame is set because put() waits for get() / get_iter() to fetch the - message before returning. - - - run_in_thread() lets its target run before yielding back control on entry, - which guarantees the intended execution order of test cases. 
- - - run_in_thread() waits for its target to finish running before yielding - back control on exit, which allows making assertions immediately. - - - When the main thread performs actions that let another thread progress, it - must wait before making assertions, to avoid depending on scheduling. - - """ - def setUp(self): - self.assembler = Assembler() - - def tearDown(self): - """ - Check that the assembler goes back to its default state after each test. - - This removes the need for testing various sequences. - - """ - self.assertFalse(self.assembler.mutex.locked()) - self.assertFalse(self.assembler.get_in_progress) - self.assertFalse(self.assembler.put_in_progress) - if not self.assembler.closed: - self.assertFalse(self.assembler.message_complete.is_set()) - self.assertFalse(self.assembler.message_fetched.is_set()) - self.assertIsNone(self.assembler.decoder) - self.assertEqual(self.assembler.chunks, []) - self.assertIsNone(self.assembler.chunks_queue) + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) # Test get def test_get_text_message_already_received(self): """get returns a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_binary_message_already_received(self): """get returns a binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() self.assertEqual(message, b"tea") def test_get_text_message_not_received_yet(self): @@ -99,112 +58,145 @@ def getter(): def 
test_get_fragmented_text_message_already_received(self): """get reassembles a fragmented a text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = self.assembler.get() self.assertEqual(message, "café") def test_get_fragmented_binary_message_already_received(self): """get reassembles a fragmented binary message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - message = self.assembler.get() - + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = self.assembler.get() self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_being_received(self): - """get reassembles a fragmented text message that is partially received.""" + def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_being_received(self): - """get reassembles a fragmented binary message that is partially received.""" + def 
test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" message = None def getter(): nonlocal message message = self.assembler.get() - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - def test_get_fragmented_text_message_not_received_yet(self): - """get reassembles a fragmented text message when it is received.""" + def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(message, "café") - def test_get_fragmented_binary_message_not_received_yet(self): - """get reassembles a fragmented binary message when it is received.""" + def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" message = None def getter(): nonlocal message message = self.assembler.get() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(message, b"tea") - # Test get_iter + def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") 
+ + def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + def test_get_resumes_reading(self): + """get resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + self.assembler.get() + self.resume.assert_called_once_with() + + def test_get_timeout_before_first_frame(self): + """get times out before reading the first frame.""" + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_text_message_already_received(self): - """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = self.assembler.get() + self.assertEqual(message, "café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_timeout_after_first_frame(self): + """get times out after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assertEqual(fragments, ["café"]) + with self.assertRaises(TimeoutError): + self.assembler.get(timeout=MS) - def test_get_iter_binary_message_already_received(self): - """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"tea")) + message = self.assembler.get() + self.assertEqual(message, 
"café") - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + # Test get_iter + def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, [b"tea"]) def test_get_iter_text_message_not_received_yet(self): @@ -212,6 +204,7 @@ def test_get_iter_text_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -225,6 +218,7 @@ def test_get_iter_binary_message_not_received_yet(self): fragments = [] def getter(): + nonlocal fragments for fragment in self.assembler.get_iter(): fragments.append(fragment) @@ -235,121 +229,112 @@ def getter(): def test_get_iter_fragmented_text_message_already_received(self): """get_iter yields a fragmented text message that is already received.""" - - def putter(): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) - + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter()) self.assertEqual(fragments, ["ca", "f", "é"]) def test_get_iter_fragmented_binary_message_already_received(self): """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + 
self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) - def putter(): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - self.assembler.put(Frame(OP_CONT, b"a")) - - with self.run_in_thread(putter): - fragments = list(self.assembler.get_iter()) + def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") - self.assertEqual(fragments, [b"t", b"e", b"a"]) + def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = self.assembler.get_iter() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) - - self.assertEqual(fragments, ["ca", "f", "é"]) 
+ iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(next(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(next(iterator), "é") def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - with self.run_in_thread(getter): - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) - - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - def test_get_iter_fragmented_text_message_not_received_yet(self): - """get_iter yields a fragmented text message when it is received.""" - fragments = [] - - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + iterator = self.assembler.get_iter() + self.assertEqual(next(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(next(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(next(iterator), b"a") + + def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = list(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, ["ca"]) - self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) - time.sleep(MS) - 
self.assertEqual(fragments, ["ca", "f"]) - self.assembler.put(Frame(OP_CONT, b"\xa9")) + def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = list(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) - self.assertEqual(fragments, ["ca", "f", "é"]) + def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the high-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) - def test_get_iter_fragmented_binary_message_not_received_yet(self): - """get_iter yields a fragmented binary message when it is received.""" - fragments = [] + iterator = self.assembler.get_iter() - def getter(): - for fragment in self.assembler.get_iter(): - fragments.append(fragment) + # queue is above the low-water mark + next(iterator) + self.resume.assert_not_called() - with self.run_in_thread(getter): - self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t"]) - self.assembler.put(Frame(OP_CONT, b"e", fin=False)) - time.sleep(MS) - self.assertEqual(fragments, [b"t", b"e"]) - self.assembler.put(Frame(OP_CONT, b"a")) + # queue is at the low-water mark + next(iterator) + self.resume.assert_called_once_with() - self.assertEqual(fragments, [b"t", b"e", b"a"]) - - # Test timeouts + # queue is below the low-water mark + next(iterator) + self.resume.assert_called_once_with() - def test_get_with_timeout_completes(self): - """get returns a message when it is received before the timeout.""" + # Test put - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) - - with self.run_in_thread(putter): - message = self.assembler.get(MS) + def 
test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() - self.assertEqual(message, "café") + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() - def test_get_with_timeout_times_out(self): - """get raises TimeoutError when no message is received before the timeout.""" - with self.assertRaises(TimeoutError): - self.assembler.get(MS) + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() # Test termination @@ -373,18 +358,8 @@ def closer(): with self.run_in_thread(closer): with self.assertRaises(EOFError): - list(self.assembler.get_iter()) - - def test_put_fails_when_interrupted_by_close(self): - """put raises EOFError when close is called.""" - - def closer(): - time.sleep(2 * MS) - self.assembler.close() - - with self.run_in_thread(closer): - with self.assertRaises(EOFError): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_get_fails_after_close(self): """get raises EOFError after close is called.""" @@ -396,7 +371,8 @@ def test_get_iter_fails_after_close(self): """get_iter raises EOFError after close is called.""" self.assembler.close() with self.assertRaises(EOFError): - list(self.assembler.get_iter()) + for _ in self.assembler.get_iter(): + self.fail("no fragment expected") def test_put_fails_after_close(self): """put raises EOFError after close is called.""" @@ -439,13 +415,25 @@ def test_get_iter_fails_when_get_iter_is_running(self): list(self.assembler.get_iter()) self.assembler.put(Frame(OP_TEXT, b"")) # unlock other thread - def test_put_fails_when_put_is_running(self): - """put cannot be called 
concurrently.""" + # Test setting limits - def putter(): - self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + def test_set_high_water_mark(self): + """high sets the high-water mark.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) - with self.run_in_thread(putter): - with self.assertRaises(ConcurrencyError): - self.assembler.put(Frame(OP_BINARY, b"tea")) - self.assembler.get() # unblock other thread + def test_set_high_and_low_water_mark(self): + """high sets the high-water mark and low-water mark.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + def test_set_invalid_low_water_mark(self): + """low must be lower than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5)