Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite sync Assembler to improve performance. #1530

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,21 @@ Backwards-incompatible changes
If you wrote an :class:`extension <extensions.Extension>` that relies on
methods not provided by these new types, you may need to update your code.

New features
............

* Added an option to receive text frames as :class:`bytes`, without decoding,
in the :mod:`threading` implementation; also binary frames as :class:`str`.

* Added an option to send :class:`bytes` as a text frame in the :mod:`asyncio`
and :mod:`threading` implementations, as well as :class:`str` as a binary frame.

Improvements
............

* Sending or receiving large compressed frames is now faster.
* The :mod:`threading` implementation receives messages faster.

* Sending or receiving large compressed messages is now faster.

.. _13.1:

Expand Down Expand Up @@ -198,6 +209,9 @@ New features

* Validated compatibility with Python 3.12 and 3.13.

* Added an option to receive text frames as :class:`bytes`, without decoding,
in the :mod:`asyncio` implementation; also binary frames as :class:`str`.

* Added :doc:`environment variables <../reference/variables>` to configure debug
logs, the ``Server`` and ``User-Agent`` headers, as well as security limits.

Expand Down
2 changes: 1 addition & 1 deletion src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ClientConnection(Connection):
closed with any other code.

The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
and ``write_limit`` arguments the same meaning as in :func:`connect`.
and ``write_limit`` arguments have the same meaning as in :func:`connect`.

Args:
protocol: Sans-I/O connection.
Expand Down
59 changes: 32 additions & 27 deletions src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def reset(self, items: Iterable[T]) -> None:
self.queue.extend(items)

def abort(self) -> None:
"""Close the queue, raising EOFError in get() if necessary."""
if self.get_waiter is not None and not self.get_waiter.done():
self.get_waiter.set_exception(EOFError("stream of frames ended"))
# Clear the queue to avoid storing unnecessary data in memory.
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__( # pragma: no cover
pause: Callable[[], Any] = lambda: None,
resume: Callable[[], Any] = lambda: None,
) -> None:
# Queue of incoming messages. Each item is a queue of frames.
# Queue of incoming frames.
self.frames: SimpleQueue[Frame] = SimpleQueue()

# We cannot put a hard limit on the size of the queue because a single
Expand Down Expand Up @@ -140,36 +141,35 @@ async def get(self, decode: bool | None = None) -> Data:
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

# Locking with get_in_progress ensures only one coroutine can get here.
self.get_in_progress = True

# First frame
# Locking with get_in_progress prevents concurrent execution until
# get() fetches a complete message or is cancelled.

try:
# First frame
frame = await self.frames.get()
except asyncio.CancelledError:
self.get_in_progress = False
raise
self.maybe_resume()
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
if decode is None:
decode = frame.opcode is OP_TEXT
frames = [frame]

# Following frames, for fragmented messages
while not frame.fin:
try:
frame = await self.frames.get()
except asyncio.CancelledError:
# Put frames already received back into the queue
# so that future calls to get() can return them.
self.frames.reset(frames)
self.get_in_progress = False
raise
self.maybe_resume()
assert frame.opcode is OP_CONT
frames.append(frame)

self.get_in_progress = False
assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY
if decode is None:
decode = frame.opcode is OP_TEXT
frames = [frame]

# Following frames, for fragmented messages
while not frame.fin:
try:
frame = await self.frames.get()
except asyncio.CancelledError:
# Put frames already received back into the queue
# so that future calls to get() can return them.
self.frames.reset(frames)
raise
self.maybe_resume()
assert frame.opcode is OP_CONT
frames.append(frame)

finally:
self.get_in_progress = False

data = b"".join(frame.data for frame in frames)
if decode:
Expand Down Expand Up @@ -207,9 +207,14 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
if self.get_in_progress:
raise ConcurrencyError("get() or get_iter() is already running")

# Locking with get_in_progress ensures only one coroutine can get here.
self.get_in_progress = True

# Locking with get_in_progress prevents concurrent execution until
# get_iter() fetches a complete message or is cancelled.

# If get_iter() raises an exception, e.g., in decoder.decode(),
# get_in_progress remains set and the connection becomes unusable.

# First frame
try:
frame = await self.frames.get()
Expand Down
2 changes: 1 addition & 1 deletion src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ServerConnection(Connection):
closed with any other code.

The ``ping_interval``, ``ping_timeout``, ``close_timeout``, ``max_queue``,
and ``write_limit`` arguments the same meaning as in :func:`serve`.
and ``write_limit`` arguments have the same meaning as in :func:`serve`.

Args:
protocol: Sans-I/O connection.
Expand Down
12 changes: 11 additions & 1 deletion src/websockets/sync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ class ClientConnection(Connection):
:exc:`~websockets.exceptions.ConnectionClosedError` when the connection is
closed with any other code.

The ``close_timeout`` and ``max_queue`` arguments have the same meaning as
in :func:`connect`.

Args:
socket: Socket connected to a WebSocket server.
protocol: Sans-I/O connection.
close_timeout: Timeout for closing the connection in seconds.

"""

Expand All @@ -53,13 +55,15 @@ def __init__(
protocol: ClientProtocol,
*,
close_timeout: float | None = 10,
max_queue: int | tuple[int, int | None] = 16,
) -> None:
self.protocol: ClientProtocol
self.response_rcvd = threading.Event()
super().__init__(
socket,
protocol,
close_timeout=close_timeout,
max_queue=max_queue,
)

def handshake(
Expand Down Expand Up @@ -135,6 +139,7 @@ def connect(
close_timeout: float | None = 10,
# Limits
max_size: int | None = 2**20,
max_queue: int | tuple[int, int | None] = 16,
# Logging
logger: LoggerLike | None = None,
# Escape hatch for advanced customization
Expand Down Expand Up @@ -183,6 +188,10 @@ def connect(
:obj:`None` disables the timeout.
max_size: Maximum size of incoming messages in bytes.
:obj:`None` disables the limit.
max_queue: High-water mark of the buffer where frames are received.
It defaults to 16 frames. The low-water mark defaults to ``max_queue
// 4``. You may pass a ``(high, low)`` tuple to set the high-water
and low-water marks.
logger: Logger for this client.
It defaults to ``logging.getLogger("websockets.client")``.
See the :doc:`logging guide <../../topics/logging>` for details.
Expand Down Expand Up @@ -287,6 +296,7 @@ def connect(
sock,
protocol,
close_timeout=close_timeout,
max_queue=max_queue,
)
except Exception:
if sock is not None:
Expand Down
79 changes: 59 additions & 20 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,14 @@ def __init__(
protocol: Protocol,
*,
close_timeout: float | None = 10,
max_queue: int | tuple[int, int | None] = 16,
) -> None:
self.socket = socket
self.protocol = protocol
self.close_timeout = close_timeout
if isinstance(max_queue, int):
max_queue = (max_queue, None)
self.max_queue = max_queue

# Inject reference to this instance in the protocol's logger.
self.protocol.logger = logging.LoggerAdapter(
Expand All @@ -76,8 +80,15 @@ def __init__(
# Mutex serializing interactions with the protocol.
self.protocol_mutex = threading.Lock()

# Lock stopping reads when the assembler buffer is full.
self.recv_flow_control = threading.Lock()

# Assembler turning frames into messages and serializing reads.
self.recv_messages = Assembler()
self.recv_messages = Assembler(
*self.max_queue,
pause=self.recv_flow_control.acquire,
resume=self.recv_flow_control.release,
)

# Whether we are busy sending a fragmented message.
self.send_in_progress = False
Expand All @@ -88,6 +99,10 @@ def __init__(
# Mapping of ping IDs to pong waiters, in chronological order.
self.ping_waiters: dict[bytes, threading.Event] = {}

# Exception raised in recv_events, to be chained to ConnectionClosed
# in the user thread in order to show why the TCP connection dropped.
self.recv_exc: BaseException | None = None

# Receiving events from the socket. This thread is marked as daemon to
# allow creating a connection in a non-daemon thread and using it in a
# daemon thread. This mustn't prevent the interpreter from exiting.
Expand All @@ -97,10 +112,6 @@ def __init__(
)
self.recv_events_thread.start()

# Exception raised in recv_events, to be chained to ConnectionClosed
# in the user thread in order to show why the TCP connection dropped.
self.recv_exc: BaseException | None = None

# Public attributes

@property
Expand Down Expand Up @@ -172,7 +183,7 @@ def __iter__(self) -> Iterator[Data]:
except ConnectionClosedOK:
return

def recv(self, timeout: float | None = None) -> Data:
def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data:
"""
Receive the next message.

Expand All @@ -191,21 +202,36 @@ def recv(self, timeout: float | None = None) -> Data:
If the message is fragmented, wait until all fragments are received,
reassemble them, and return the whole message.

Args:
timeout: Timeout for receiving a message in seconds.
decode: Set this flag to override the default behavior of returning
:class:`str` or :class:`bytes`. See below for details.

Returns:
A string (:class:`str`) for a Text_ frame or a bytestring
(:class:`bytes`) for a Binary_ frame.

.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6

You may override this behavior with the ``decode`` argument:

* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and
return a bytestring (:class:`bytes`). This improves performance
when decoding isn't needed, for example if the message contains
JSON and you're using a JSON library that expects a bytestring.
* Set ``decode=True`` to force UTF-8 decoding of Binary_ frames
and return a string (:class:`str`). This may be useful for
servers that send binary frames instead of text frames.

Raises:
ConnectionClosed: When the connection is closed.
ConcurrencyError: If two threads call :meth:`recv` or
:meth:`recv_streaming` concurrently.

"""
try:
return self.recv_messages.get(timeout)
return self.recv_messages.get(timeout, decode)
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
Expand All @@ -216,31 +242,47 @@ def recv(self, timeout: float | None = None) -> Data:
"is already running recv or recv_streaming"
) from None

def recv_streaming(self) -> Iterator[Data]:
def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]:
"""
Receive the next message frame by frame.

If the message is fragmented, yield each fragment as it is received.
The iterator must be fully consumed, or else the connection will become
This method is designed for receiving fragmented messages. It returns an
iterator that yields each fragment as it is received. This iterator must
be fully consumed. Otherwise, future calls to :meth:`recv` or
:meth:`recv_streaming` will raise
:exc:`~websockets.exceptions.ConcurrencyError`, making the connection
unusable.

:meth:`recv_streaming` raises the same exceptions as :meth:`recv`.

Args:
decode: Set this flag to override the default behavior of returning
:class:`str` or :class:`bytes`. See below for details.

Returns:
An iterator of strings (:class:`str`) for a Text_ frame or
bytestrings (:class:`bytes`) for a Binary_ frame.

.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6

You may override this behavior with the ``decode`` argument:

* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames
and return bytestrings (:class:`bytes`). This may be useful to
optimize performance when decoding isn't needed.
* Set ``decode=True`` to force UTF-8 decoding of Binary_ frames
and return strings (:class:`str`). This is useful for servers
that send binary frames instead of text frames.

Raises:
ConnectionClosed: When the connection is closed.
ConcurrencyError: If two threads call :meth:`recv` or
:meth:`recv_streaming` concurrently.

"""
try:
yield from self.recv_messages.get_iter()
yield from self.recv_messages.get_iter(decode)
except EOFError:
# Wait for the protocol state to be CLOSED before accessing close_exc.
self.recv_events_thread.join()
Expand Down Expand Up @@ -571,8 +613,9 @@ def recv_events(self) -> None:
try:
while True:
try:
if self.close_deadline is not None:
self.socket.settimeout(self.close_deadline.timeout())
with self.recv_flow_control:
if self.close_deadline is not None:
self.socket.settimeout(self.close_deadline.timeout())
data = self.socket.recv(self.recv_bufsize)
except Exception as exc:
if self.debug:
Expand Down Expand Up @@ -622,13 +665,9 @@ def recv_events(self) -> None:
# Given that automatic responses write small amounts of data,
# this should be uncommon, so we don't handle the edge case.

try:
for event in events:
# This may raise EOFError if the closing handshake
# times out while a message is waiting to be read.
self.process_event(event)
except EOFError:
break
for event in events:
# This isn't expected to raise an exception.
self.process_event(event)

# Breaking out of the while True: ... loop means that we believe
# that the socket doesn't work anymore.
Expand Down
Loading