From 5ab3adf19f86cf06b41078b3046661cf34dc6194 Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Tue, 10 Dec 2024 10:16:53 -0500 Subject: [PATCH 1/5] bump websockets version --- compat-tests | 2 +- requirements-client.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/compat-tests b/compat-tests index 9b5fc44426b6..3c5ec0111e2a 160000 --- a/compat-tests +++ b/compat-tests @@ -1 +1 @@ -Subproject commit 9b5fc44426b6a98a05408106fd6b5453ae9a0c76 +Subproject commit 3c5ec0111e2aa7b160f2b21cfd383d19448dfe13 diff --git a/requirements-client.txt b/requirements-client.txt index de5e2b5ab1e5..94d912a834a5 100644 --- a/requirements-client.txt +++ b/requirements-client.txt @@ -35,4 +35,4 @@ toml >= 0.10.0 typing_extensions >= 4.5.0, < 5.0.0 ujson >= 5.8.0, < 6.0.0 uvicorn >=0.14.0, !=0.29.0 -websockets >= 10.4, < 14.0 \ No newline at end of file +websockets >= 14.1, < 15.0 \ No newline at end of file From 4a542f84d4a75463d574bd35b5cd32eddce99a46 Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Tue, 10 Dec 2024 10:21:42 -0500 Subject: [PATCH 2/5] bump websockets to >= 14.1 --- src/prefect/events/clients.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/prefect/events/clients.py b/src/prefect/events/clients.py index bd09eb3ab20c..18aa704bbff0 100644 --- a/src/prefect/events/clients.py +++ b/src/prefect/events/clients.py @@ -20,8 +20,8 @@ from cachetools import TTLCache from prometheus_client import Counter from typing_extensions import Self -from websockets import Subprotocol -from websockets.client import WebSocketClientProtocol, connect +from websockets import Subprotocol, connect +from websockets.asyncio.client import ClientConnection from websockets.exceptions import ( ConnectionClosed, ConnectionClosedError, @@ -241,7 +241,7 @@ def _get_api_url_and_key( class PrefectEventsClient(EventsClient): """A Prefect Events client that streams events to a Prefect server""" - _websocket: Optional[WebSocketClientProtocol] + _websocket: Optional[ClientConnection] _unconfirmed_events: List[Event] def __init__( @@ -437,7 +437,7 @@ def __init__( ) self._connect = connect( self._events_socket_url, - extra_headers={"Authorization": f"bearer {api_key}"}, + additional_headers={"Authorization": f"bearer {api_key}"}, ) @@ -462,7 +462,7 @@ class PrefectEventSubscriber: """ - _websocket: Optional[WebSocketClientProtocol] + _websocket: Optional[ClientConnection] _filter: "EventFilter" _seen_events: MutableMapping[UUID, bool] From d8cb37c05bca4874c8fa54cd853754838cfcc5bb Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Tue, 10 Dec 2024 10:39:24 -0500 Subject: [PATCH 3/5] fix subscriptions --- src/prefect/client/subscriptions.py | 2 +- src/prefect/task_worker.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/prefect/client/subscriptions.py b/src/prefect/client/subscriptions.py index 8e04b3735e8a..9506b3ac720a 100644 --- a/src/prefect/client/subscriptions.py +++ b/src/prefect/client/subscriptions.py @@ -81,7 +81,7 @@ async def _ensure_connected(self): ) auth: dict[str, Any] = orjson.loads(await websocket.recv()) - assert auth["type"] == "auth_success", auth.get("message") + assert auth["type"] == "auth_success", auth.get("reason") message: dict[str, Any] = {"type": "subscribe", "keys": self.keys} if self.client_id: diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index 6c75bd31f98f..69dc26cf0b8f 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -16,7 +16,7 @@ import uvicorn from exceptiongroup import BaseExceptionGroup # novermin from fastapi import FastAPI -from websockets.exceptions import InvalidStatusCode +from websockets.exceptions import InvalidStatus from prefect import Task from prefect._internal.concurrency.api import create_call, from_sync @@ -163,8 +163,8 @@ async def start(self) -> None: logger.info("Starting task worker...") try: await self._subscribe_to_task_scheduling() - except InvalidStatusCode as exc: - if exc.status_code == 403: + except InvalidStatus as exc: + if exc.response.status_code == 403: logger.error( "403: Could not establish a connection to the `/task_runs/subscriptions/scheduled`" f" endpoint found at:\n\n {PREFECT_API_URL.value()}" From 42d015949e01f0fbcf386ffb2af2928d4295d61d Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Tue, 10 Dec 2024 10:43:14 -0500 Subject: [PATCH 4/5] fix one more import --- src/prefect/client/subscriptions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/prefect/client/subscriptions.py b/src/prefect/client/subscriptions.py index 9506b3ac720a..b7c194f902df 100644 --- a/src/prefect/client/subscriptions.py +++ b/src/prefect/client/subscriptions.py @@ -8,6 +8,7 @@ import websockets.exceptions from starlette.status import WS_1008_POLICY_VIOLATION from typing_extensions import Self +from websockets.asyncio.client import ClientConnection from prefect._internal.schemas.bases import IDBaseModel from prefect.logging import get_logger @@ -44,7 +45,7 @@ def __aiter__(self) -> Self: return self @property - def websocket(self) -> websockets.WebSocketClientProtocol: + def websocket(self) -> ClientConnection: if not self._websocket: raise RuntimeError("Subscription is not connected") return self._websocket From df7c3d9734edad5a988b07edc292bf60287a573e Mon Sep 17 00:00:00 2001 From: "jake@prefect.io" Date: Tue, 10 Dec 2024 10:47:01 -0500 Subject: [PATCH 5/5] fix more tests --- src/prefect/testing/fixtures.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/prefect/testing/fixtures.py b/src/prefect/testing/fixtures.py index 545778427ac1..c94645a1f812 100644 --- a/src/prefect/testing/fixtures.py +++ b/src/prefect/testing/fixtures.py @@ -13,8 +13,10 @@ import pendulum import pytest from starlette.status import WS_1008_POLICY_VIOLATION +from websockets.asyncio.server import Server as WebSocketServer +from websockets.asyncio.server import ServerConnection as WebSocketServerConnection +from websockets.asyncio.server import serve as websocket_serve from websockets.exceptions import ConnectionClosed -from websockets.legacy.server import WebSocketServer, WebSocketServerProtocol, serve from prefect.events import Event from prefect.events.clients import ( @@ -291,7 +293,7 @@ async def events_server( ) -> AsyncGenerator[WebSocketServer, None]: server: WebSocketServer - async def handler(socket: WebSocketServerProtocol) -> None: + async def handler(socket: WebSocketServerConnection) -> None: path = socket.path recorder.connections += 1 if puppeteer.refuse_any_further_connections: @@ -304,7 +306,7 @@ async def handler(socket: WebSocketServerProtocol) -> None: elif path.endswith("/events/out"): await outgoing_events(socket) - async def incoming_events(socket: WebSocketServerProtocol): + async def incoming_events(socket: WebSocketServerConnection): while True: try: message = await socket.recv() @@ -317,7 +319,7 @@ async def incoming_events(socket: WebSocketServerProtocol): if puppeteer.hard_disconnect_after == event.id: raise ValueError("zonk") - async def outgoing_events(socket: WebSocketServerProtocol): + async def outgoing_events(socket: WebSocketServerConnection): # 1. authentication auth_message = json.loads(await socket.recv()) @@ -352,7 +354,9 @@ async def outgoing_events(socket: WebSocketServerProtocol): puppeteer.hard_disconnect_after = None raise ValueError("zonk") - async with serve(handler, host="localhost", port=unused_tcp_port) as server: + async with websocket_serve( + handler, host="localhost", port=unused_tcp_port + ) as server: yield server