diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 702e6999..12871e4b 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -251,12 +251,13 @@ async def recv(self, decode: bool | None = None) -> Data: You may override this behavior with the ``decode`` argument: - * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames - and return a bytestring (:class:`bytes`). This may be useful to - optimize performance when decoding isn't needed. + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames - and return a string (:class:`str`). This is useful for servers - that send binary frames instead of text frames. + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. Raises: ConnectionClosed: When the connection is closed. @@ -333,7 +334,11 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data "is already running recv or recv_streaming" ) from None - async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> None: + async def send( + self, + message: Data | Iterable[Data] | AsyncIterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -344,6 +349,17 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable or an asynchronous iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. @@ -393,12 +409,20 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No # strings and bytes-like objects are iterable. if isinstance(message, str): - async with self.send_context(): - self.protocol.send_text(message.encode()) + if text is False: + async with self.send_context(): + self.protocol.send_binary(message.encode()) + else: + async with self.send_context(): + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): - async with self.send_context(): - self.protocol.send_binary(message) + if text is True: + async with self.send_context(): + self.protocol.send_text(message) + else: + async with self.send_context(): + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -419,36 +443,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. if isinstance(chunk, str): - text = True - async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False - async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("iterable must contain uniform types") @@ -481,36 +501,32 @@ async def send(self, message: Data | Iterable[Data] | AsyncIterable[Data]) -> No try: # First fragment. if isinstance(chunk, str): - text = True - async with self.send_context(): - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False - async with self.send_context(): - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("async iterable must contain bytes or str") # Other fragments async for chunk in achunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: async with self.send_context(): - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: async with self.send_context(): - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("async iterable must contain uniform types") diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 70d9dad6..563cf2b1 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -190,13 +190,13 @@ async def test_recv_binary(self): await self.remote_connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") - async def test_recv_encoded_text(self): - """recv receives an UTF-8 encoded text message.""" + async def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) - async def test_recv_decoded_binary(self): - """recv receives an UTF-8 decoded binary message.""" + async def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual(await self.connection.recv(decode=True), "😀") @@ -304,16 +304,16 @@ async def test_recv_streaming_binary(self): [b"\x01\x02\xfe\xff"], ) - async def test_recv_streaming_encoded_text(self): - """recv_streaming receives an UTF-8 encoded text message.""" + async def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" await self.remote_connection.send("😀") self.assertEqual( await alist(self.connection.recv_streaming(decode=False)), ["😀".encode()], ) - async def test_recv_streaming_decoded_binary(self): - """recv_streaming receives a UTF-8 decoded binary message.""" + async def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" await self.remote_connection.send("😀".encode()) self.assertEqual( await alist(self.connection.recv_streaming(decode=True)), @@ -438,6 +438,16 @@ async def test_send_binary(self): await self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + async def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + await self.connection.send("😀", text=False) + self.assertEqual(await self.remote_connection.recv(), "😀".encode()) + + async def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + await self.connection.send("😀".encode(), text=True) + self.assertEqual(await self.remote_connection.recv(), "😀") + async def test_send_fragmented_text(self): """send sends a fragmented text message.""" await self.connection.send(["😀", "😀"]) @@ -456,6 +466,24 @@ async def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + await self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + await self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_async_fragmented_text(self): """send sends a fragmented text message asynchronously.""" @@ -484,6 +512,34 @@ async def fragments(): [b"\x01\x02", b"\xfe\xff", b""], ) + async def test_send_async_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments(), text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_async_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes asynchronously.""" + + async def fragments(): + yield "😀".encode() + yield "😀".encode() + + await self.connection.send(fragments(), text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + async def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" await self.remote_connection.close()