Skip to content

Commit

Permalink
Rewrite sync Assembler to improve performance.
Browse files Browse the repository at this point in the history
Previously, a latch was used to synchronize the user thread reading messages and
the background thread reading from the network. This required two thread switches
per message.

Now, the background thread writes messages to queue, from which the user thread
reads. This allows passing several frames at each thread switch, reducing the
overhead.

With this server code::

    async def test(websocket):
        for i in range(int(await websocket.recv())):
            await websocket.send(f"{{\"iteration\": {i}}}")

and this client code::

    with connect("ws://localhost:8765", compression=None) as websocket:
        websocket.send("1_000_000")
        for message in websocket:
            pass

an unscientific benchmark (running it on my laptop) shows a 2.5x speedup, going
from 11 seconds to 4.4 seconds. Setting a very large recv_bufsize and max_size doesn't yield significant further improvement.

The new implementation mirrors the asyncio implementation and gains the
option to prevent or force decoding of frames. Refs #1376.
  • Loading branch information
aaugustin committed Oct 25, 2024
1 parent e5182c9 commit fa78d82
Show file tree
Hide file tree
Showing 12 changed files with 586 additions and 467 deletions.
16 changes: 15 additions & 1 deletion docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,21 @@ Backwards-incompatible changes
If you wrote an :class:`extension <extensions.Extension>` that relies on
methods not provided by these new types, you may need to update your code.

New features
............

* Added an option to receive text frames as :class:`bytes`, without decoding,
in the :mod:`threading` implementation; also binary frames as :class:`str`.

* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio`
and :mod:`threading` implementations, as well as :class:`str` a binary frame.

Improvements
............

* Sending or receiving large compressed frames is now faster.
* Sending or receiving large compressed messages is now faster.

* The :mod:`threading` implementation receives messages faster.

.. _13.1:

Expand Down Expand Up @@ -198,6 +209,9 @@ New features

* Validated compatibility with Python 3.12 and 3.13.

* Added an option to receive text frames as :class:`bytes`, without decoding,
in the :mod:`asyncio` implementation; also binary frames as :class:`str`.

* Added :doc:`environment variables <../reference/variables>` to configure debug
logs, the ``Server`` and ``User-Agent`` headers, as well as security limits.

Expand Down
2 changes: 1 addition & 1 deletion src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ClientConnection(Connection):
closed with any other code.
The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
and ``write_limit`` arguments the same meaning as in :func:`connect`.
and ``write_limit`` arguments have the same meaning as in :func:`connect`.
Args:
protocol: Sans-I/O connection.
Expand Down
60 changes: 33 additions & 27 deletions src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def reset(self, items: Iterable[T]) -> None:
self.queue.extend(items)

def abort(self) -> None:
"""Close the queue, raising EOFError in get() if necessary."""
if self.get_waiter is not None and not self.get_waiter.done():
self.get_waiter.set_exception(EOFError("stream of frames ended"))
# Clear the queue to avoid storing unnecessary data in memory.
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__( # pragma: no cover
pause: Callable[[], Any] = lambda: None,
resume: Callable[[], Any] = lambda: None,
) -> None:
# Queue of incoming messages. Each item is a queue of frames.
# Queue of incoming frames.
self.frames: SimpleQueue[Frame] = SimpleQueue()

# We cannot put a hard limit on the size of the queue because a single
Expand Down Expand Up @@ -140,36 +141,35 @@ async def get(self, decode: bool | None = None) -> Data:
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

# Locking with get_in_progress ensures only one coroutine can get here.
self.get_in_progress = True

# First frame
# Locking with get_in_progress prevents concurrent execution until
# get() fetches a complete message or is cancelled.

try:
# First frame
frame = await self.frames.get()
except asyncio.CancelledError:
self.get_in_progress = False
raise
self.maybe_resume()
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
if decode is None:
decode = frame.opcode is OP_TEXT
frames = [frame]

# Following frames, for fragmented messages
while not frame.fin:
try:
frame = await self.frames.get()
except asyncio.CancelledError:
# Put frames already received back into the queue
# so that future calls to get() can return them.
self.frames.reset(frames)
self.get_in_progress = False
raise
self.maybe_resume()
assert frame.opcode is OP_CONT
frames.append(frame)

self.get_in_progress = False
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
if decode is None:
decode = frame.opcode is OP_TEXT
frames = [frame]

# Following frames, for fragmented messages
while not frame.fin:
try:
frame = await self.frames.get()
except asyncio.CancelledError:
# Put frames already received back into the queue
# so that future calls to get() can return them.
self.frames.reset(frames)
raise
self.maybe_resume()
assert frame.opcode is OP_CONT
frames.append(frame)

finally:
self.get_in_progress = False

data = b"".join(frame.data for frame in frames)
if decode:
Expand Down Expand Up @@ -207,9 +207,14 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

# Locking with get_in_progress ensures only one coroutine can get here.
self.get_in_progress = True

# Locking with get_in_progress prevents concurrent execution until
# get_iter() fetches a complete message or is cancelled.

# If get_iter() raises an exception e.g. in decoder.decode(),
# get_in_progress remains set and the connection becomes unusable.

# First frame
try:
frame = await self.frames.get()
Expand All @@ -233,6 +238,7 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
# here will leave the assembler in a stuck state. Future calls to
# get() or get_iter() will raise ConcurrencyError.
frame = await self.frames.get()

self.maybe_resume()
assert frame.opcode is OP_CONT
if decode:
Expand Down
2 changes: 1 addition & 1 deletion src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ServerConnection(Connection):
closed with any other code.
The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
and ``write_limit`` arguments the same meaning as in :func:`serve`.
and ``write_limit`` arguments have the same meaning as in :func:`serve`.
Args:
protocol: Sans-I/O connection.
Expand Down
12 changes: 11 additions & 1 deletion src/websockets/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ class ClientConnection(Connection):
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
closed with any other code.
The ``close_timeout`` and ``max_queue`` arguments have the same meaning as
in :func:`connect`.
Args:
socket: Socket connected to a WebSocket server.
protocol: Sans-I/O connection.
close_timeout: Timeout for closing the connection in seconds.
"""

Expand All @@ -53,13 +55,15 @@ def __init__(
protocol: ClientProtocol,
*,
close_timeout: float | None = 10,
max_queue: int | tuple[int, int | None] = 16,
) -> None:
self.protocol: ClientProtocol
self.response_rcvd = threading.Event()
super().__init__(
socket,
protocol,
close_timeout=close_timeout,
max_queue=max_queue,
)

def handshake(
Expand Down Expand Up @@ -135,6 +139,7 @@ def connect(
close_timeout: float | None = 10,
# Limits
max_size: int | None = 2**20,
max_queue: int | tuple[int, int | None] = 16,
# Logging
logger: LoggerLike | None = None,
# Escape hatch for advanced customization
Expand Down Expand Up @@ -183,6 +188,10 @@ def connect(
:obj:`None` disables the timeout.
max_size: Maximum size of incoming messages in bytes.
:obj:`None` disables the limit.
max_queue: High-water mark of the buffer where frames are received.
It defaults to 16 frames. The low-water mark defaults to ``max_queue
// 4``. You may pass a ``(high, low)`` tuple to set the high-water
and low-water marks.
logger: Logger for this client.
It defaults to ``logging.getLogger("websockets.client")``.
See the :doc:`logging guide <../../topics/logging>` for details.
Expand Down Expand Up @@ -287,6 +296,7 @@ def connect(
sock,
protocol,
close_timeout=close_timeout,
max_queue=max_queue,
)
except Exception:
if sock is not None:
Expand Down
79 changes: 59 additions & 20 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,14 @@ def __init__(
protocol: Protocol,
*,
close_timeout: float | None = 10,
max_queue: int | tuple[int, int | None] = 16,
) -> None:
self.socket = socket
self.protocol = protocol
self.close_timeout = close_timeout
if isinstance(max_queue, int):
max_queue = (max_queue, None)
self.max_queue = max_queue

# Inject reference to this instance in the protocol's logger.
self.protocol.logger = logging.LoggerAdapter(
Expand All @@ -76,8 +80,15 @@ def __init__(
# Mutex serializing interactions with the protocol.
self.protocol_mutex = threading.Lock()

# Lock stopping reads when the assembler buffer is full.
self.recv_flow_control = threading.Lock()

# Assembler turning frames into messages and serializing reads.
self.recv_messages = Assembler()
self.recv_messages = Assembler(
*self.max_queue,
pause=self.recv_flow_control.acquire,
resume=self.recv_flow_control.release,
)

# Whether we are busy sending a fragmented message.
self.send_in_progress = False
Expand All @@ -88,6 +99,10 @@ def __init__(
# Mapping of ping IDs to pong waiters, in chronological order.
self.ping_waiters: dict[bytes, threading.Event] = {}

# Exception raised in recv_events, to be chained to ConnectionClosed
# in the user thread in order to show why the TCP connection dropped.
self.recv_exc: BaseException | None = None

# Receiving events from the socket. This thread is marked as daemon to
# allow creating a connection in a non-daemon thread and using it in a
# daemon thread. This mustn't prevent the interpreter from exiting.
Expand All @@ -97,10 +112,6 @@ def __init__(
)
self.recv_events_thread.start()

# Exception raised in recv_events, to be chained to ConnectionClosed
# in the user thread in order to show why the TCP connection dropped.
self.recv_exc: BaseException | None = None

# Public attributes

@property
Expand Down Expand Up @@ -172,7 +183,7 @@ def __iter__(self) -> Iterator[Data]:
except ConnectionClosedOK:
return

def recv(self, timeout: float | None = None) -> Data:
def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data:
"""
Receive the next message.
Expand All @@ -191,21 +202,36 @@ def recv(self, timeout: float | None = None) -> Data:
If the message is fragmented, wait until all fragments are received,
reassemble them, and return the whole message.
Args:
timeout: Timeout for receiving a message in seconds.
decode: Set this flag to override the default behavior of returning
:class:`str` or :class:`bytes`. See below for details.
Returns:
A string (:class:`str`) for a Text_ frame or a bytestring
(:class:`bytes`) for a Binary_ frame.
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
You may override this behavior with the ``decode`` argument:
* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and
return a bytestring (:class:`bytes`). This improves performance
when decoding isn't needed, for example if the message contains
JSON and you're using a JSON library that expects a bytestring.
* Set ``decode=True`` to force UTF-8 decoding of Binary_ frames
and return a string (:class:`str`). This may be useful for
servers that send binary frames instead of text frames.
Raises:
ConnectionClosed: When the connection is closed.
ConcurrencyError: If two threads call :meth:`recv` or
:meth:`recv_streaming` concurrently.
"""
try:
return self.recv_messages.get(timeout)
return self.recv_messages.get(timeout, decode)
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
Expand All @@ -216,31 +242,47 @@ def recv(self, timeout: float | None = None) -> Data:
"is already running recv or recv_streaming"
) from None

def recv_streaming(self) -> Iterator[Data]:
def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]:
"""
Receive the next message frame by frame.
If the message is fragmented, yield each fragment as it is received.
The iterator must be fully consumed, or else the connection will become
This method is designed for receiving fragmented messages. It returns an
iterator that yields each fragment as it is received. This iterator must
be fully consumed. Else, future calls to :meth:`recv` or
:meth:`recv_streaming` will raise
:exc:`~websockets.exceptions.ConcurrencyError`, making the connection
unusable.
:meth:`recv_streaming` raises the same exceptions as :meth:`recv`.
Args:
decode: Set this flag to override the default behavior of returning
:class:`str` or :class:`bytes`. See below for details.
Returns:
An iterator of strings (:class:`str`) for a Text_ frame or
bytestrings (:class:`bytes`) for a Binary_ frame.
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
You may override this behavior with the ``decode`` argument:
* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames
and return bytestrings (:class:`bytes`). This may be useful to
optimize performance when decoding isn't needed.
* Set ``decode=True`` to force UTF-8 decoding of Binary_ frames
and return strings (:class:`str`). This is useful for servers
that send binary frames instead of text frames.
Raises:
ConnectionClosed: When the connection is closed.
ConcurrencyError: If two threads call :meth:`recv` or
:meth:`recv_streaming` concurrently.
"""
try:
yield from self.recv_messages.get_iter()
yield from self.recv_messages.get_iter(decode)
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
Expand Down Expand Up @@ -571,8 +613,9 @@ def recv_events(self) -> None:
try:
while True:
try:
if self.close_deadline is not None:
self.socket.settimeout(self.close_deadline.timeout())
with self.recv_flow_control:
if self.close_deadline is not None:
self.socket.settimeout(self.close_deadline.timeout())
data = self.socket.recv(self.recv_bufsize)
except Exception as exc:
if self.debug:
Expand Down Expand Up @@ -622,13 +665,9 @@ def recv_events(self) -> None:
# Given that automatic responses write small amounts of data,
# this should be uncommon, so we don't handle the edge case.

try:
for event in events:
# This may raise EOFError if the closing handshake
# times out while a message is waiting to be read.
self.process_event(event)
except EOFError:
break
for event in events:
# This isn't expected to raise an exception.
self.process_event(event)

# Breaking out of the while True: ... loop means that we believe
# that the socket doesn't work anymore.
Expand Down
Loading

0 comments on commit fa78d82

Please sign in to comment.