From 59d4dcf779fe7d2b0302083b072d8b03adce2f61 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Wed, 13 Nov 2024 23:00:22 +0100 Subject: [PATCH] Reintroduce InvalidMessage. This improves compatibility with the legacy implementation and clarifies error reporting. Fix #1548. --- docs/project/changelog.rst | 8 ++++++++ docs/reference/exceptions.rst | 4 ++-- src/websockets/__init__.py | 4 +++- src/websockets/asyncio/client.py | 6 ++++-- src/websockets/client.py | 6 +++++- src/websockets/exceptions.py | 11 +++++++++-- src/websockets/legacy/client.py | 3 ++- src/websockets/legacy/exceptions.py | 9 ++------- src/websockets/legacy/server.py | 3 ++- src/websockets/server.py | 6 +++++- tests/asyncio/test_client.py | 27 ++++++++++++++++++++------- tests/asyncio/test_server.py | 4 ++++ tests/legacy/test_exceptions.py | 4 ---- tests/sync/test_client.py | 21 ++++++++++++++++++--- tests/sync/test_server.py | 4 ++++ tests/test_client.py | 28 ++++++++++++++++++++++++---- tests/test_exceptions.py | 4 ++++ tests/test_server.py | 24 ++++++++++++++++++++---- 18 files changed, 136 insertions(+), 40 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 9c594b65..b7f4f62f 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -32,6 +32,14 @@ notice. *In development* +Bug fixes +......... + +* Wrapped errors when reading the opening handshake request or response in + :exc:`~exceptions.InvalidMessage` so that :func:`~asyncio.client.connect` + raises :exc:`~exceptions.InvalidHandshake` or a subclass when the opening + handshake fails. + .. _14.1: 14.1 diff --git a/docs/reference/exceptions.rst b/docs/reference/exceptions.rst index 75934ef9..d6b7f0f5 100644 --- a/docs/reference/exceptions.rst +++ b/docs/reference/exceptions.rst @@ -30,6 +30,8 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs. .. autoexception:: InvalidHandshake +.. autoexception:: InvalidMessage + .. autoexception:: SecurityError .. autoexception:: InvalidStatus @@ -74,8 +76,6 @@ Legacy exceptions These exceptions are only used by the legacy :mod:`asyncio` implementation. -.. autoexception:: InvalidMessage - .. autoexception:: InvalidStatusCode .. autoexception:: AbortHandshake diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index 0c7e9b4c..c278b21d 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -31,6 +31,7 @@ "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", + "InvalidMessage", "InvalidOrigin", "InvalidParameterName", "InvalidParameterValue", @@ -71,6 +72,7 @@ InvalidHeader, InvalidHeaderFormat, InvalidHeaderValue, + InvalidMessage, InvalidOrigin, InvalidParameterName, InvalidParameterValue, @@ -122,6 +124,7 @@ "InvalidHeader": ".exceptions", "InvalidHeaderFormat": ".exceptions", "InvalidHeaderValue": ".exceptions", + "InvalidMessage": ".exceptions", "InvalidOrigin": ".exceptions", "InvalidParameterName": ".exceptions", "InvalidParameterValue": ".exceptions", @@ -159,7 +162,6 @@ "WebSocketClientProtocol": ".legacy.client", # .legacy.exceptions "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index cdd9bfac..8581c055 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -11,7 +11,7 @@ from ..client import ClientProtocol, backoff from ..datastructures import HeadersLike -from ..exceptions import InvalidStatus, SecurityError +from ..exceptions import InvalidMessage, InvalidStatus, SecurityError from ..extensions.base import ClientExtensionFactory from ..extensions.permessage_deflate import enable_client_permessage_deflate from ..headers import validate_subprotocols @@ -147,7 +147,9 @@ def process_exception(exc: Exception) -> Exception | None: That exception will be raised, breaking out of the retry loop. """ - if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)): + if isinstance(exc, (OSError, asyncio.TimeoutError)): + return None + if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError): return None if isinstance(exc, InvalidStatus) and exc.response.status_code in [ 500, # Internal Server Error diff --git a/src/websockets/client.py b/src/websockets/client.py index f6cbc9f6..5ced05c2 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -11,6 +11,7 @@ InvalidHandshake, InvalidHeader, InvalidHeaderValue, + InvalidMessage, InvalidStatus, InvalidUpgrade, NegotiationError, @@ -318,7 +319,10 @@ def parse(self) -> Generator[None]: self.reader.read_to_eof, ) except Exception as exc: - self.handshake_exc = exc + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP response" + ) + self.handshake_exc.__cause__ = exc self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index f3e75197..81fbb1ef 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -8,7 +8,7 @@ * :exc:`InvalidURI` * :exc:`InvalidHandshake` * :exc:`SecurityError` - * :exc:`InvalidMessage` (legacy) + * :exc:`InvalidMessage` * :exc:`InvalidStatus` * :exc:`InvalidStatusCode` (legacy) * :exc:`InvalidHeader` @@ -48,6 +48,7 @@ "InvalidHeader", "InvalidHeaderFormat", "InvalidHeaderValue", + "InvalidMessage", "InvalidOrigin", "InvalidUpgrade", "NegotiationError", @@ -185,6 +186,13 @@ class SecurityError(InvalidHandshake): """ +class InvalidMessage(InvalidHandshake): + """ + Raised when a handshake request or response is malformed. + + """ + + class InvalidStatus(InvalidHandshake): """ Raised when a handshake response rejects the WebSocket upgrade. @@ -410,7 +418,6 @@ class ConcurrencyError(WebSocketException, RuntimeError): deprecated_aliases={ # deprecated in 14.0 - 2024-11-09 "AbortHandshake": ".legacy.exceptions", - "InvalidMessage": ".legacy.exceptions", "InvalidStatusCode": ".legacy.exceptions", "RedirectHandshake": ".legacy.exceptions", "WebSocketProtocolError": ".legacy.exceptions", diff --git a/src/websockets/legacy/client.py b/src/websockets/legacy/client.py index a3856b47..29141f39 100644 --- a/src/websockets/legacy/client.py +++ b/src/websockets/legacy/client.py @@ -17,6 +17,7 @@ from ..exceptions import ( InvalidHeader, InvalidHeaderValue, + InvalidMessage, NegotiationError, SecurityError, ) @@ -34,7 +35,7 @@ from ..http11 import USER_AGENT from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol from ..uri import WebSocketURI, parse_uri -from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake +from .exceptions import InvalidStatusCode, RedirectHandshake from .handshake import build_request, check_response from .http import read_response from .protocol import WebSocketCommonProtocol diff --git a/src/websockets/legacy/exceptions.py b/src/websockets/legacy/exceptions.py index e2279c82..78fb696f 100644 --- a/src/websockets/legacy/exceptions.py +++ b/src/websockets/legacy/exceptions.py @@ -3,18 +3,13 @@ from .. import datastructures from ..exceptions import ( InvalidHandshake, + # InvalidMessage was incorrectly moved here in versions 14.0 and 14.1. + InvalidMessage, # noqa: F401 ProtocolError as WebSocketProtocolError, # noqa: F401 ) from ..typing import StatusLike -class InvalidMessage(InvalidHandshake): - """ - Raised when a handshake request or response is malformed. - - """ - - class InvalidStatusCode(InvalidHandshake): """ Raised when a handshake response status code is invalid. diff --git a/src/websockets/legacy/server.py b/src/websockets/legacy/server.py index 9326b610..f9d57cb9 100644 --- a/src/websockets/legacy/server.py +++ b/src/websockets/legacy/server.py @@ -17,6 +17,7 @@ from ..exceptions import ( InvalidHandshake, InvalidHeader, + InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -32,7 +33,7 @@ from ..http11 import SERVER from ..protocol import State from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol -from .exceptions import AbortHandshake, InvalidMessage +from .exceptions import AbortHandshake from .handshake import build_response, check_request from .http import read_request from .protocol import WebSocketCommonProtocol, broadcast diff --git a/src/websockets/server.py b/src/websockets/server.py index 607cc306..1b663a13 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -13,6 +13,7 @@ InvalidHandshake, InvalidHeader, InvalidHeaderValue, + InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -552,7 +553,10 @@ def parse(self) -> Generator[None]: self.reader.read_line, ) except Exception as exc: - self.handshake_exc = exc + self.handshake_exc = InvalidMessage( + "did not receive a valid HTTP request" + ) + self.handshake_exc.__cause__ = exc self.send_eof() self.parser = self.discard() next(self.parser) # start coroutine diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 231d6b8c..1773c08b 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -12,6 +12,7 @@ from websockets.client import backoff from websockets.exceptions import ( InvalidHandshake, + InvalidMessage, InvalidStatus, InvalidURI, SecurityError, @@ -151,22 +152,24 @@ async def test_reconnect(self): iterations = 0 successful = 0 - def process_request(connection, request): + async def process_request(connection, request): nonlocal iterations iterations += 1 # Retriable errors if iterations == 1: - connection.transport.close() + await asyncio.sleep(3 * MS) elif iterations == 2: + connection.transport.close() + elif iterations == 3: return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒") # Fatal error - elif iterations == 5: + elif iterations == 6: return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸") async with serve(*args, process_request=process_request) as server: with self.assertRaises(InvalidStatus) as raised: async with short_backoff_delay(): - async for client in connect(get_uri(server)): + async for client in connect(get_uri(server), open_timeout=3 * MS): self.assertEqual(client.protocol.state.name, "OPEN") successful += 1 @@ -174,7 +177,7 @@ def process_request(connection, request): str(raised.exception), "server rejected WebSocket connection: HTTP 402", ) - self.assertEqual(iterations, 5) + self.assertEqual(iterations, 6) self.assertEqual(successful, 2) async def test_reconnect_with_custom_process_exception(self): @@ -393,11 +396,16 @@ def close_connection(self, request): self.close_transport() async with serve(*args, process_request=close_connection) as server: - with self.assertRaises(EOFError) as raised: + with self.assertRaises(InvalidMessage) as raised: async with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, EOFError) + self.assertEqual( + str(raised.exception.__cause__), "connection closed while reading HTTP status line", ) @@ -443,11 +451,16 @@ async def junk(reader, writer): server = await asyncio.start_server(junk, "localhost", 0) host, port = get_host_port(server) async with server: - with self.assertRaises(ValueError) as raised: + with self.assertRaises(InvalidMessage) as raised: async with connect(f"ws://{host}:{port}"): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, ValueError) + self.assertEqual( + str(raised.exception.__cause__), "unsupported protocol; expected HTTP/1.1: " "220 smtp.invalid ESMTP Postfix", ) diff --git a/tests/asyncio/test_server.py b/tests/asyncio/test_server.py index 3e289e59..83885fab 100644 --- a/tests/asyncio/test_server.py +++ b/tests/asyncio/test_server.py @@ -473,6 +473,10 @@ async def test_junk_handshake(self): ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], + ["did not receive a valid HTTP request"], + ) + self.assertEqual( + [str(record.exc_info[1].__cause__) for record in logs.records], ["invalid HTTP request line: HELO relay.invalid"], ) diff --git a/tests/legacy/test_exceptions.py b/tests/legacy/test_exceptions.py index e5d22a91..4e6ff952 100644 --- a/tests/legacy/test_exceptions.py +++ b/tests/legacy/test_exceptions.py @@ -7,10 +7,6 @@ class ExceptionsTests(unittest.TestCase): def test_str(self): for exception, exception_str in [ - ( - InvalidMessage("malformed HTTP message"), - "malformed HTTP message", - ), ( InvalidStatusCode(403, Headers()), "server rejected WebSocket connection: HTTP 403", diff --git a/tests/sync/test_client.py b/tests/sync/test_client.py index 9d457a91..7d817051 100644 --- a/tests/sync/test_client.py +++ b/tests/sync/test_client.py @@ -7,7 +7,12 @@ import time import unittest -from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI +from websockets.exceptions import ( + InvalidHandshake, + InvalidMessage, + InvalidStatus, + InvalidURI, +) from websockets.extensions.permessage_deflate import PerMessageDeflate from websockets.sync.client import * @@ -149,11 +154,16 @@ def close_connection(self, request): self.close_socket() with run_server(process_request=close_connection) as server: - with self.assertRaises(EOFError) as raised: + with self.assertRaises(InvalidMessage) as raised: with connect(get_uri(server)): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, EOFError) + self.assertEqual( + str(raised.exception.__cause__), "connection closed while reading HTTP status line", ) @@ -203,11 +213,16 @@ def handle(self): thread = threading.Thread(target=server.serve_forever, args=(MS,)) thread.start() try: - with self.assertRaises(ValueError) as raised: + with self.assertRaises(InvalidMessage) as raised: with connect(f"ws://{host}:{port}"): self.fail("did not raise") self.assertEqual( str(raised.exception), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(raised.exception.__cause__, ValueError) + self.assertEqual( + str(raised.exception.__cause__), "unsupported protocol; expected HTTP/1.1: " "220 smtp.invalid ESMTP Postfix", ) diff --git a/tests/sync/test_server.py b/tests/sync/test_server.py index 54e49bf1..9a267643 100644 --- a/tests/sync/test_server.py +++ b/tests/sync/test_server.py @@ -311,6 +311,10 @@ def test_junk_handshake(self): ) self.assertEqual( [str(record.exc_info[1]) for record in logs.records], + ["did not receive a valid HTTP request"], + ) + self.assertEqual( + [str(record.exc_info[1].__cause__) for record in logs.records], ["invalid HTTP request line: HELO relay.invalid"], ) diff --git a/tests/test_client.py b/tests/test_client.py index 2468be85..1edbae57 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -8,7 +8,12 @@ from websockets.client import * from websockets.client import backoff from websockets.datastructures import Headers -from websockets.exceptions import InvalidHandshake, InvalidHeader, InvalidStatus +from websockets.exceptions import ( + InvalidHandshake, + InvalidHeader, + InvalidMessage, + InvalidStatus, +) from websockets.frames import OP_TEXT, Frame from websockets.http11 import Request, Response from websockets.protocol import CONNECTING, OPEN @@ -244,9 +249,14 @@ def test_receive_no_response(self, _generate_key): client.receive_eof() self.assertEqual(client.events_received(), []) - self.assertIsInstance(client.handshake_exc, EOFError) + self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(client.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(client.handshake_exc.__cause__), "connection closed while reading HTTP status line", ) @@ -257,9 +267,14 @@ def test_receive_truncated_response(self, _generate_key): client.receive_eof() self.assertEqual(client.events_received(), []) - self.assertIsInstance(client.handshake_exc, EOFError) + self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(client.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(client.handshake_exc.__cause__), "connection closed while reading HTTP headers", ) @@ -272,9 +287,14 @@ def test_receive_random_response(self, _generate_key): client.receive_data(b"250 Ok\r\n") self.assertEqual(client.events_received(), []) - self.assertIsInstance(client.handshake_exc, ValueError) + self.assertIsInstance(client.handshake_exc, InvalidMessage) self.assertEqual( str(client.handshake_exc), + "did not receive a valid HTTP response", + ) + self.assertIsInstance(client.handshake_exc.__cause__, ValueError) + self.assertEqual( + str(client.handshake_exc.__cause__), "invalid HTTP status line: 220 smtp.invalid", ) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index fef41d13..e0518b0e 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -91,6 +91,10 @@ def test_str(self): SecurityError("redirect from WSS to WS"), "redirect from WSS to WS", ), + ( + InvalidMessage("malformed HTTP message"), + "malformed HTTP message", + ), ( InvalidStatus(Response(401, "Unauthorized", Headers())), "server rejected WebSocket connection: HTTP 401", diff --git a/tests/test_server.py b/tests/test_server.py index 844ba64e..5efeca2d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -7,6 +7,7 @@ from websockets.datastructures import Headers from websockets.exceptions import ( InvalidHeader, + InvalidMessage, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -207,9 +208,15 @@ def test_receive_no_request(self): server.receive_eof() self.assertEqual(server.events_received(), []) - self.assertIsInstance(server.handshake_exc, EOFError) + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), + "did not receive a valid HTTP request", + ) + self.assertIsInstance(server.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(server.handshake_exc.__cause__), "connection closed while reading HTTP request line", ) @@ -220,9 +227,14 @@ def test_receive_truncated_request(self): server.receive_eof() self.assertEqual(server.events_received(), []) - self.assertIsInstance(server.handshake_exc, EOFError) + self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), + "did not receive a valid HTTP request", + ) + self.assertIsInstance(server.handshake_exc.__cause__, EOFError) + self.assertEqual( + str(server.handshake_exc.__cause__), "connection closed while reading HTTP headers", ) @@ -233,10 +245,14 @@ def test_receive_junk_request(self): server.receive_data(b"MAIL FROM: \r\n") server.receive_data(b"RCPT TO: \r\n") - self.assertEqual(server.events_received(), []) - self.assertIsInstance(server.handshake_exc, ValueError) + self.assertIsInstance(server.handshake_exc, InvalidMessage) self.assertEqual( str(server.handshake_exc), + "did not receive a valid HTTP request", + ) + self.assertIsInstance(server.handshake_exc.__cause__, ValueError) + self.assertEqual( + str(server.handshake_exc.__cause__), "invalid HTTP request line: HELO relay.invalid", )