From 50b6d20d7a652d39cffc7aea9f8c0abc88fb8f37 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 11 Feb 2024 21:19:24 +0100 Subject: [PATCH] Various cleanups in sync implementation. --- src/websockets/sync/client.py | 9 +++-- src/websockets/sync/connection.py | 58 +++++++++++++++---------------- src/websockets/sync/server.py | 27 +++++++------- 3 files changed, 45 insertions(+), 49 deletions(-) diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index 6faca778..0bb7a76f 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -25,7 +25,7 @@ class ClientConnection(Connection): """ - Threaded implementation of a WebSocket client connection. + :mod:`threading` implementation of a WebSocket client connection. :class:`ClientConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. @@ -157,7 +157,7 @@ def connect( :func:`connect` may be used as a context manager:: - async with websockets.sync.client.connect(...) as websocket: + with websockets.sync.client.connect(...) as websocket: ... The connection is closed automatically when exiting the context. @@ -273,19 +273,18 @@ def connect( sock = ssl.wrap_socket(sock, server_hostname=server_hostname) sock.settimeout(None) - # Initialize WebSocket connection + # Initialize WebSocket protocol protocol = ClientProtocol( wsuri, origin=origin, extensions=extensions, subprotocols=subprotocols, - state=CONNECTING, max_size=max_size, logger=logger, ) - # Initialize WebSocket protocol + # Initialize WebSocket connection connection = create_connection( sock, diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 62aa17ff..6ac40cd7 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -21,12 +21,10 @@ __all__ = ["Connection"] -logger = logging.getLogger(__name__) - class Connection: """ - Threaded implementation of a WebSocket connection. + :mod:`threading` implementation of a WebSocket connection. :class:`Connection` provides APIs shared between WebSocket servers and clients. @@ -82,7 +80,7 @@ def __init__( self.close_deadline: Optional[Deadline] = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pings: Dict[bytes, threading.Event] = {} + self.ping_waiters: Dict[bytes, threading.Event] = {} # Receiving events from the socket. self.recv_events_thread = threading.Thread(target=self.recv_events) @@ -90,7 +88,7 @@ def __init__( # Exception raised in recv_events, to be chained to ConnectionClosed # in the user thread in order to show why the TCP connection dropped. - self.recv_events_exc: Optional[BaseException] = None + self.recv_exc: Optional[BaseException] = None # Public attributes @@ -198,7 +196,7 @@ def recv(self, timeout: Optional[float] = None) -> Data: try: return self.recv_messages.get(timeout) except EOFError: - raise self.protocol.close_exc from self.recv_events_exc + raise self.protocol.close_exc from self.recv_exc except RuntimeError: raise RuntimeError( "cannot call recv while another thread " @@ -229,9 +227,10 @@ def recv_streaming(self) -> Iterator[Data]: """ try: - yield from self.recv_messages.get_iter() + for frame in self.recv_messages.get_iter(): + yield frame except EOFError: - raise self.protocol.close_exc from self.recv_events_exc + raise self.protocol.close_exc from self.recv_exc except RuntimeError: raise RuntimeError( "cannot call recv_streaming while another thread " @@ -273,7 +272,7 @@ def send(self, message: Union[Data, Iterable[Data]]) -> None: Raises: ConnectionClosed: When the connection is closed. - RuntimeError: If a connection is busy sending a fragmented message. + RuntimeError: If the connection is sending a fragmented message. TypeError: If ``message`` doesn't have a supported type. """ @@ -449,15 +448,15 @@ def ping(self, data: Optional[Data] = None) -> threading.Event: with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pings: + if data in self.ping_waiters: raise RuntimeError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pings: + while data is None or data in self.ping_waiters: data = struct.pack("!I", random.getrandbits(32)) pong_waiter = threading.Event() - self.pings[data] = pong_waiter + self.ping_waiters[data] = pong_waiter self.protocol.send_ping(data) return pong_waiter @@ -504,22 +503,22 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.pings: + if data not in self.ping_waiters: return # Sending a pong for only the most recent ping is legal. # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, ping in self.pings.items(): + for ping_id, ping in self.ping_waiters.items(): ping_ids.append(ping_id) ping.set() if ping_id == data: break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pings. + # Remove acknowledged pings from self.ping_waiters. for ping_id in ping_ids: - del self.pings[ping_id] + del self.ping_waiters[ping_id] def recv_events(self) -> None: """ @@ -541,10 +540,10 @@ def recv_events(self) -> None: self.logger.debug("error while receiving data", exc_info=True) # When the closing handshake is initiated by our side, # recv() may block until send_context() closes the socket. - # In that case, send_context() already set recv_events_exc. - # Calling set_recv_events_exc() avoids overwriting it. + # In that case, send_context() already set recv_exc. + # Calling set_recv_exc() avoids overwriting it. with self.protocol_mutex: - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) break if data == b"": @@ -552,7 +551,7 @@ def recv_events(self) -> None: # Acquire the connection lock. with self.protocol_mutex: - # Feed incoming data to the connection. + # Feed incoming data to the protocol. self.protocol.receive_data(data) # This isn't expected to raise an exception. @@ -568,7 +567,7 @@ def recv_events(self) -> None: # set by send_context(), in case of a race condition # i.e. send_context() closes the socket after recv() # returns above but before send_data() calls send(). - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) break if self.protocol.close_expected(): @@ -595,7 +594,7 @@ def recv_events(self) -> None: # Breaking out of the while True: ... loop means that we believe # that the socket doesn't work anymore. with self.protocol_mutex: - # Feed the end of the data stream to the connection. + # Feed the end of the data stream to the protocol. self.protocol.receive_eof() # This isn't expected to generate events. @@ -609,7 +608,7 @@ def recv_events(self) -> None: # This branch should never run. It's a safety net in case of bugs. self.logger.error("unexpected internal error", exc_info=True) with self.protocol_mutex: - self.set_recv_events_exc(exc) + self.set_recv_exc(exc) # We don't know where we crashed. Force protocol state to CLOSED. self.protocol.state = CLOSED finally: @@ -668,7 +667,6 @@ def send_context( wait_for_close = True # If the connection is expected to close soon, set the # close deadline based on the close timeout. - # Since we tested earlier that protocol.state was OPEN # (or CONNECTING) and we didn't release protocol_mutex, # it is certain that self.close_deadline is still None. @@ -710,11 +708,11 @@ def send_context( # original_exc is never set when wait_for_close is True. assert original_exc is None original_exc = TimeoutError("timed out while closing connection") - # Set recv_events_exc before closing the socket in order to get + # Set recv_exc before closing the socket in order to get # proper exception reporting. raise_close_exc = True with self.protocol_mutex: - self.set_recv_events_exc(original_exc) + self.set_recv_exc(original_exc) # If an error occurred, close the socket to terminate the connection and # raise an exception. @@ -745,16 +743,16 @@ def send_data(self) -> None: except OSError: # socket already closed pass - def set_recv_events_exc(self, exc: Optional[BaseException]) -> None: + def set_recv_exc(self, exc: Optional[BaseException]) -> None: """ - Set recv_events_exc, if not set yet. + Set recv_exc, if not set yet. This method requires holding protocol_mutex. """ assert self.protocol_mutex.locked() - if self.recv_events_exc is None: - self.recv_events_exc = exc + if self.recv_exc is None: + self.recv_exc = exc def close_socket(self) -> None: """ diff --git a/src/websockets/sync/server.py b/src/websockets/sync/server.py index fa6087d5..a070edf1 100644 --- a/src/websockets/sync/server.py +++ b/src/websockets/sync/server.py @@ -30,7 +30,7 @@ class ServerConnection(Connection): """ - Threaded implementation of a WebSocket server connection. + :mod:`threading` implementation of a WebSocket server connection. :class:`ServerConnection` provides :meth:`recv` and :meth:`send` methods for receiving and sending messages. @@ -188,6 +188,8 @@ class WebSocketServer: handler: Handler for one connection. Receives the socket and address returned by :meth:`~socket.socket.accept`. logger: Logger for this server. + It defaults to ``logging.getLogger("websockets.server")``. + See the :doc:`logging guide <../../topics/logging>` for details. """ @@ -311,16 +313,16 @@ def serve( Whenever a client connects, the server creates a :class:`ServerConnection`, performs the opening handshake, and delegates to the ``handler``. - The handler receives a :class:`ServerConnection` instance, which you can use - to send and receive messages. + The handler receives the :class:`ServerConnection` instance, which you can + use to send and receive messages. Once the handler completes, either normally or with an exception, the server performs the closing handshake and closes the connection. - :class:`WebSocketServer` mirrors the API of + This function returns a :class:`WebSocketServer` whose API mirrors :class:`~socketserver.BaseServer`. Treat it as a context manager to ensure - that it will be closed and call the :meth:`~WebSocketServer.serve_forever` - method to serve requests:: + that it will be closed and call :meth:`~WebSocketServer.serve_forever` to + serve requests:: def handler(websocket): ... @@ -454,15 +456,13 @@ def conn_handler(sock: socket.socket, addr: Any) -> None: sock.do_handshake() sock.settimeout(None) - # Create a closure so that select_subprotocol has access to self. - + # Create a closure to give select_subprotocol access to connection. protocol_select_subprotocol: Optional[ Callable[ [ServerProtocol, Sequence[Subprotocol]], Optional[Subprotocol], ] ] = None - if select_subprotocol is not None: def protocol_select_subprotocol( @@ -475,19 +475,18 @@ def protocol_select_subprotocol( assert protocol is connection.protocol return select_subprotocol(connection, subprotocols) - # Initialize WebSocket connection + # Initialize WebSocket protocol protocol = ServerProtocol( origins=origins, extensions=extensions, subprotocols=subprotocols, select_subprotocol=protocol_select_subprotocol, - state=CONNECTING, max_size=max_size, logger=logger, ) - # Initialize WebSocket protocol + # Initialize WebSocket connection assert create_connection is not None # help mypy connection = create_connection( @@ -522,7 +521,7 @@ def protocol_select_subprotocol( def unix_serve( - handler: Callable[[ServerConnection], Any], + handler: Callable[[ServerConnection], None], path: Optional[str] = None, **kwargs: Any, ) -> WebSocketServer: @@ -541,4 +540,4 @@ def unix_serve( path: File system path to the Unix socket. """ - return serve(handler, path=path, unix=True, **kwargs) + return serve(handler, unix=True, path=path, **kwargs)