Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to force sending text or binary frames. #1516

Merged
merged 1 commit into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 70 additions & 54 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,13 @@ async def recv(self, decode: bool | None = None) -> Data:

You may override this behavior with the ``decode`` argument:

* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames
and return a bytestring (:class:`bytes`). This may be useful to
optimize performance when decoding isn't needed.
* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and
return a bytestring (:class:`bytes`). This improves performance
when decoding isn't needed, for example if the message contains
JSON and you're using a JSON library that expects a bytestring.
* Set ``decode=True`` to force UTF-8 decoding of Binary_ frames
and return a string (:class:`str`). This is useful for servers
that send binary frames instead of text frames.
and return a string (:class:`str`). This may be useful for
servers that send binary frames instead of text frames.

Raises:
ConnectionClosed: When the connection is closed.
Expand Down Expand Up @@ -333,7 +334,11 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data
"is already running recv or recv_streaming"
) from None

async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None:
async def send(
self,
message: Data | Iterable[Data] | AsyncIterable[Data],
text: bool | None = None,
) -> None:
"""
Send a message.

Expand All @@ -344,6 +349,17 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6

You may override this behavior with the ``text`` argument:

* Set ``text=True`` to send a bytestring or bytes-like object
(:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a
Text_ frame. This improves performance when the message is already
UTF-8 encoded, for example if the message contains JSON and you're
using a JSON library that produces a bytestring.
* Set ``text=False`` to send a string (:class:`str`) in a Binary_
frame. This may be useful for servers that expect binary frames
instead of text frames.

:meth:`send` also accepts an iterable or an asynchronous iterable of
strings, bytestrings, or bytes-like objects to enable fragmentation_.
Each item is treated as a message fragment and sent in its own frame.
Expand Down Expand Up @@ -393,12 +409,20 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
# strings and bytes-like objects are iterable.

if isinstance(message, str):
async with self.send_context():
self.protocol.send_text(message.encode())
if text is False:
async with self.send_context():
self.protocol.send_binary(message.encode())
else:
async with self.send_context():
self.protocol.send_text(message.encode())

elif isinstance(message, BytesLike):
async with self.send_context():
self.protocol.send_binary(message)
if text is True:
async with self.send_context():
self.protocol.send_text(message)
else:
async with self.send_context():
self.protocol.send_binary(message)

# Catch a common mistake -- passing a dict to send().

Expand All @@ -419,36 +443,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
try:
# First fragment.
if isinstance(chunk, str):
text = True
async with self.send_context():
self.protocol.send_text(
chunk.encode(),
fin=False,
)
if text is False:
async with self.send_context():
self.protocol.send_binary(chunk.encode(), fin=False)
else:
async with self.send_context():
self.protocol.send_text(chunk.encode(), fin=False)
encode = True
elif isinstance(chunk, BytesLike):
text = False
async with self.send_context():
self.protocol.send_binary(
chunk,
fin=False,
)
if text is True:
async with self.send_context():
self.protocol.send_text(chunk, fin=False)
else:
async with self.send_context():
self.protocol.send_binary(chunk, fin=False)
encode = False
else:
raise TypeError("iterable must contain bytes or str")

# Other fragments
for chunk in chunks:
if isinstance(chunk, str) and text:
if isinstance(chunk, str) and encode:
async with self.send_context():
self.protocol.send_continuation(
chunk.encode(),
fin=False,
)
elif isinstance(chunk, BytesLike) and not text:
self.protocol.send_continuation(chunk.encode(), fin=False)
elif isinstance(chunk, BytesLike) and not encode:
async with self.send_context():
self.protocol.send_continuation(
chunk,
fin=False,
)
self.protocol.send_continuation(chunk, fin=False)
else:
raise TypeError("iterable must contain uniform types")

Expand Down Expand Up @@ -481,36 +501,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
try:
# First fragment.
if isinstance(chunk, str):
text = True
async with self.send_context():
self.protocol.send_text(
chunk.encode(),
fin=False,
)
if text is False:
async with self.send_context():
self.protocol.send_binary(chunk.encode(), fin=False)
else:
async with self.send_context():
self.protocol.send_text(chunk.encode(), fin=False)
encode = True
elif isinstance(chunk, BytesLike):
text = False
async with self.send_context():
self.protocol.send_binary(
chunk,
fin=False,
)
if text is True:
async with self.send_context():
self.protocol.send_text(chunk, fin=False)
else:
async with self.send_context():
self.protocol.send_binary(chunk, fin=False)
encode = False
else:
raise TypeError("async iterable must contain bytes or str")

# Other fragments
async for chunk in achunks:
if isinstance(chunk, str) and text:
if isinstance(chunk, str) and encode:
async with self.send_context():
self.protocol.send_continuation(
chunk.encode(),
fin=False,
)
elif isinstance(chunk, BytesLike) and not text:
self.protocol.send_continuation(chunk.encode(), fin=False)
elif isinstance(chunk, BytesLike) and not encode:
async with self.send_context():
self.protocol.send_continuation(
chunk,
fin=False,
)
self.protocol.send_continuation(chunk, fin=False)
else:
raise TypeError("async iterable must contain uniform types")

Expand Down
72 changes: 64 additions & 8 deletions tests/asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,13 @@ async def test_recv_binary(self):
await self.remote_connection.send(b"\x01\x02\xfe\xff")
self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff")

async def test_recv_encoded_text(self):
"""recv receives an UTF-8 encoded text message."""
async def test_recv_text_as_bytes(self):
"""recv receives a text message as bytes."""
await self.remote_connection.send("😀")
self.assertEqual(await self.connection.recv(decode=False), "😀".encode())

async def test_recv_decoded_binary(self):
"""recv receives an UTF-8 decoded binary message."""
async def test_recv_binary_as_text(self):
"""recv receives a binary message as a str."""
await self.remote_connection.send("😀".encode())
self.assertEqual(await self.connection.recv(decode=True), "😀")

Expand Down Expand Up @@ -304,16 +304,16 @@ async def test_recv_streaming_binary(self):
[b"\x01\x02\xfe\xff"],
)

async def test_recv_streaming_encoded_text(self):
"""recv_streaming receives an UTF-8 encoded text message."""
async def test_recv_streaming_text_as_bytes(self):
"""recv_streaming receives a text message as bytes."""
await self.remote_connection.send("😀")
self.assertEqual(
await alist(self.connection.recv_streaming(decode=False)),
["😀".encode()],
)

async def test_recv_streaming_decoded_binary(self):
"""recv_streaming receives a UTF-8 decoded binary message."""
async def test_recv_streaming_binary_as_str(self):
"""recv_streaming receives a binary message as a str."""
await self.remote_connection.send("😀".encode())
self.assertEqual(
await alist(self.connection.recv_streaming(decode=True)),
Expand Down Expand Up @@ -438,6 +438,16 @@ async def test_send_binary(self):
await self.connection.send(b"\x01\x02\xfe\xff")
self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff")

async def test_send_binary_from_str(self):
"""send sends a binary message from a str."""
await self.connection.send("😀", text=False)
self.assertEqual(await self.remote_connection.recv(), "😀".encode())

async def test_send_text_from_bytes(self):
"""send sends a text message from bytes."""
await self.connection.send("😀".encode(), text=True)
self.assertEqual(await self.remote_connection.recv(), "😀")

async def test_send_fragmented_text(self):
"""send sends a fragmented text message."""
await self.connection.send(["😀", "😀"])
Expand All @@ -456,6 +466,24 @@ async def test_send_fragmented_binary(self):
[b"\x01\x02", b"\xfe\xff", b""],
)

async def test_send_fragmented_binary_from_str(self):
"""send sends a fragmented binary message from a str."""
await self.connection.send(["😀", "😀"], text=False)
# websockets sends a trailing empty fragment. That's an implementation detail.
self.assertEqual(
await alist(self.remote_connection.recv_streaming()),
["😀".encode(), "😀".encode(), b""],
)

async def test_send_fragmented_text_from_bytes(self):
"""send sends a fragmented text message from bytes."""
await self.connection.send(["😀".encode(), "😀".encode()], text=True)
# websockets sends a trailing empty fragment. That's an implementation detail.
self.assertEqual(
await alist(self.remote_connection.recv_streaming()),
["😀", "😀", ""],
)

async def test_send_async_fragmented_text(self):
"""send sends a fragmented text message asynchronously."""

Expand Down Expand Up @@ -484,6 +512,34 @@ async def fragments():
[b"\x01\x02", b"\xfe\xff", b""],
)

async def test_send_async_fragmented_binary_from_str(self):
"""send sends a fragmented binary message from a str asynchronously."""

async def fragments():
yield "😀"
yield "😀"

await self.connection.send(fragments(), text=False)
# websockets sends a trailing empty fragment. That's an implementation detail.
self.assertEqual(
await alist(self.remote_connection.recv_streaming()),
["😀".encode(), "😀".encode(), b""],
)

async def test_send_async_fragmented_text_from_bytes(self):
"""send sends a fragmented text message from bytes asynchronously."""

async def fragments():
yield "😀".encode()
yield "😀".encode()

await self.connection.send(fragments(), text=True)
# websockets sends a trailing empty fragment. That's an implementation detail.
self.assertEqual(
await alist(self.remote_connection.recv_streaming()),
["😀", "😀", ""],
)

async def test_send_connection_closed_ok(self):
"""send raises ConnectionClosedOK after a normal closure."""
await self.remote_connection.close()
Expand Down
Loading