Skip to content

Commit

Permalink
Add option to force sending text or binary frames.
Browse files Browse the repository at this point in the history
Fix #1515.
  • Loading branch information
aaugustin committed Sep 30, 2024
1 parent 21987f9 commit bc4b8f2
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 62 deletions.
124 changes: 70 additions & 54 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,13 @@ async def recv(self, decode: bool | None = None) -> Data:
You may override this behavior with the ``decode`` argument:
* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames
and return a bytestring (:class:`bytes`). This may be useful to
optimize performance when decoding isn't needed.
* Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and
return a bytestring (:class:`bytes`). This improves performance
when decoding isn't needed, for example if the message contains
JSON and you're using a JSON library that expects a bytestring.
* Set ``decode=True`` to force UTF-8 decoding of Binary_ frames
and return a string (:class:`str`). This is useful for servers
that send binary frames instead of text frames.
and return a string (:class:`str`). This may be useful for
servers that send binary frames instead of text frames.
Raises:
ConnectionClosed: When the connection is closed.
Expand Down Expand Up @@ -333,7 +334,11 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data
"is already running recv or recv_streaming"
) from None

async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None:
async def send(
self,
message: Data | Iterable[Data] | AsyncIterable[Data],
text: bool | None = None,
) -> None:
"""
Send a message.
Expand All @@ -344,6 +349,17 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
You may override this behavior with the ``text`` argument:
* Set ``text=True`` to send a bytestring or bytes-like object
(:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a
Text_ frame. This improves performance when the message is already
UTF-8 encoded, for example if the message contains JSON and you're
using a JSON library that produces a bytestring.
* Set ``text=False`` to send a string (:class:`str`) in a Binary_
frame. This may be useful for servers that expect binary frames
instead of text frames.
:meth:`send` also accepts an iterable or an asynchronous iterable of
strings, bytestrings, or bytes-like objects to enable fragmentation_.
Each item is treated as a message fragment and sent in its own frame.
Expand Down Expand Up @@ -393,12 +409,20 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
# strings and bytes-like objects are iterable.

if isinstance(message, str):
async with self.send_context():
self.protocol.send_text(message.encode())
if text is False:
async with self.send_context():
self.protocol.send_binary(message.encode())
else:
async with self.send_context():
self.protocol.send_text(message.encode())

elif isinstance(message, BytesLike):
async with self.send_context():
self.protocol.send_binary(message)
if text is True:
async with self.send_context():
self.protocol.send_text(message)
else:
async with self.send_context():
self.protocol.send_binary(message)

# Catch a common mistake -- passing a dict to send().

Expand All @@ -419,36 +443,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
try:
# First fragment.
if isinstance(chunk, str):
text = True
async with self.send_context():
self.protocol.send_text(
chunk.encode(),
fin=False,
)
if text is False:
async with self.send_context():
self.protocol.send_binary(chunk.encode(), fin=False)
else:
async with self.send_context():
self.protocol.send_text(chunk.encode(), fin=False)
encode = True
elif isinstance(chunk, BytesLike):
text = False
async with self.send_context():
self.protocol.send_binary(
chunk,
fin=False,
)
if text is True:
async with self.send_context():
self.protocol.send_text(chunk, fin=False)
else:
async with self.send_context():
self.protocol.send_binary(chunk, fin=False)
encode = False
else:
raise TypeError("iterable must contain bytes or str")

# Other fragments
for chunk in chunks:
if isinstance(chunk, str) and text:
if isinstance(chunk, str) and encode:
async with self.send_context():
self.protocol.send_continuation(
chunk.encode(),
fin=False,
)
elif isinstance(chunk, BytesLike) and not text:
self.protocol.send_continuation(chunk.encode(), fin=False)
elif isinstance(chunk, BytesLike) and not encode:
async with self.send_context():
self.protocol.send_continuation(
chunk,
fin=False,
)
self.protocol.send_continuation(chunk, fin=False)
else:
raise TypeError("iterable must contain uniform types")

Expand Down Expand Up @@ -481,36 +501,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No
try:
# First fragment.
if isinstance(chunk, str):
text = True
async with self.send_context():
self.protocol.send_text(
chunk.encode(),
fin=False,
)
if text is False:
async with self.send_context():
self.protocol.send_binary(chunk.encode(), fin=False)
else:
async with self.send_context():
self.protocol.send_text(chunk.encode(), fin=False)
encode = True
elif isinstance(chunk, BytesLike):
text = False
async with self.send_context():
self.protocol.send_binary(
chunk,
fin=False,
)
if text is True:
async with self.send_context():
self.protocol.send_text(chunk, fin=False)
else:
async with self.send_context():
self.protocol.send_binary(chunk, fin=False)
encode = False
else:
raise TypeError("async iterable must contain bytes or str")

# Other fragments
async for chunk in achunks:
if isinstance(chunk, str) and text:
if isinstance(chunk, str) and encode:
async with self.send_context():
self.protocol.send_continuation(
chunk.encode(),
fin=False,
)
elif isinstance(chunk, BytesLike) and not text:
self.protocol.send_continuation(chunk.encode(), fin=False)
elif isinstance(chunk, BytesLike) and not encode:
async with self.send_context():
self.protocol.send_continuation(
chunk,
fin=False,
)
self.protocol.send_continuation(chunk, fin=False)
else:
raise TypeError("async iterable must contain uniform types")

Expand Down
72 changes: 64 additions & 8 deletions tests/asyncio/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,13 @@ async def test_recv_binary(self):
await self.remote_connection.send(b"\x01\x02\xfe\xff")
self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff")

async def test_recv_encoded_text(self):
"""recv receives an UTF-8 encoded text message."""
async def test_recv_text_as_bytes(self):
"""recv receives a text message as bytes."""
await self.remote_connection.send("😀")
self.assertEqual(await self.connection.recv(decode=False), "😀".encode())

async def test_recv_decoded_binary(self):
"""recv receives an UTF-8 decoded binary message."""
async def test_recv_binary_as_text(self):
"""recv receives a binary message as a str."""
await self.remote_connection.send("😀".encode())
self.assertEqual(await self.connection.recv(decode=True), "😀")

Expand Down Expand Up @@ -304,16 +304,16 @@ async def test_recv_streaming_binary(self):
[b"\x01\x02\xfe\xff"],
)

async def test_recv_streaming_encoded_text(self):
"""recv_streaming receives an UTF-8 encoded text message."""
async def test_recv_streaming_text_as_bytes(self):
"""recv_streaming receives a text message as bytes."""
await self.remote_connection.send("😀")
self.assertEqual(
await alist(self.connection.recv_streaming(decode=False)),
["😀".encode()],
)

async def test_recv_streaming_decoded_binary(self):
"""recv_streaming receives a UTF-8 decoded binary message."""
async def test_recv_streaming_binary_as_str(self):
"""recv_streaming receives a binary message as a str."""
await self.remote_connection.send("😀".encode())
self.assertEqual(
await alist(self.connection.recv_streaming(decode=True)),
Expand Down Expand Up @@ -438,6 +438,16 @@ async def test_send_binary(self):
await self.connection.send(b"\x01\x02\xfe\xff")
self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff")

async def test_send_binary_from_str(self):
"""send sends a binary message from a str."""
await self.connection.send("😀", text=False)
self.assertEqual(await self.remote_connection.recv(), "😀".encode())

async def test_send_text_from_bytes(self):
"""send sends a text message from bytes."""
await self.connection.send("😀".encode(), text=True)
self.assertEqual(await self.remote_connection.recv(), "😀")

async def test_send_fragmented_text(self):
"""send sends a fragmented text message."""
await self.connection.send(["😀", "😀"])
Expand All @@ -456,6 +466,24 @@ async def test_send_fragmented_binary(self):
[b"\x01\x02", b"\xfe\xff", b""],
)

async def test_send_fragmented_binary_from_str(self):
"""send sends a fragmented binary message from a str."""
await self.connection.send(["😀", "😀"], text=False)
# websockets sends an trailing empty fragment. That's an implementation detail.
self.assertEqual(
await alist(self.remote_connection.recv_streaming()),
["😀".encode(), "😀".encode(), b""],
)

async def test_send_fragmented_text_from_bytes(self):
"""send sends a fragmented text message from bytes."""
await self.connection.send(["😀".encode(), "😀".encode()], text=True)
# websockets sends an trailing empty fragment. That's an implementation detail.
self.assertEqual(
await alist(self.remote_connection.recv_streaming()),
["😀", "😀", ""],
)

async def test_send_async_fragmented_text(self):
"""send sends a fragmented text message asynchronously."""

Expand Down Expand Up @@ -484,6 +512,34 @@ async def fragments():
[b"\x01\x02", b"\xfe\xff", b""],
)

async def test_send_async_fragmented_binary_from_str(self):
"""send sends a fragmented binary message from a str asynchronously."""

async def fragments():
yield "😀"
yield "😀"

await self.connection.send(fragments(), text=False)
# websockets sends an trailing empty fragment. That's an implementation detail.
self.assertEqual(
await alist(self.remote_connection.recv_streaming()),
["😀".encode(), "😀".encode(), b""],
)

async def test_send_async_fragmented_text_from_bytes(self):
"""send sends a fragmented text message from bytes asynchronously."""

async def fragments():
yield "😀".encode()
yield "😀".encode()

await self.connection.send(fragments(), text=True)
# websockets sends an trailing empty fragment. That's an implementation detail.
self.assertEqual(
await alist(self.remote_connection.recv_streaming()),
["😀", "😀", ""],
)

async def test_send_connection_closed_ok(self):
"""send raises ConnectionClosedOK after a normal closure."""
await self.remote_connection.close()
Expand Down

0 comments on commit bc4b8f2

Please sign in to comment.