From c4b3d2bad27ff5f970b3625e49f82695b19ca212 Mon Sep 17 00:00:00 2001 From: Benjamin Thomas Schwertfeger Date: Thu, 28 Nov 2024 18:11:05 +0100 Subject: [PATCH] add a unit test --- src/websockets/asyncio/messages.py | 1 + tests/asyncio/test_connection.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 5a7bfbb8..936adf63 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -292,6 +292,7 @@ def prepare_close(self) -> None: # Resuming the writer to avoid deadlocks if self.paused: + self.paused = False self.resume() def close(self) -> None: diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 5a0b61bf..95f4e56a 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -806,6 +806,27 @@ async def test_close_preserves_queued_messages(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) + async def test_close_preserves_queued_messages_gt_max_queue(self): + """ + close preserves messages buffered in the assembler, even if they + exceed the default buffer size. + """ + + for _ in range(100): + await self.remote_connection.send("😀") + + await self.connection.close() + + for _ in range(100): + self.assertEqual(await self.connection.recv(), "😀") + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + async def test_close_idempotency(self): """close does nothing if the connection is already closed.""" await self.connection.close()