Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

websockets==14.1 #16313

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compat-tests
2 changes: 1 addition & 1 deletion requirements-client.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ toml >= 0.10.0
typing_extensions >= 4.5.0, < 5.0.0
ujson >= 5.8.0, < 6.0.0
uvicorn >=0.14.0, !=0.29.0
websockets >= 10.4, < 14.0
websockets >= 14.1, < 15.0
5 changes: 3 additions & 2 deletions src/prefect/client/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import websockets.exceptions
from starlette.status import WS_1008_POLICY_VIOLATION
from typing_extensions import Self
from websockets.asyncio.client import ClientConnection

from prefect._internal.schemas.bases import IDBaseModel
from prefect.logging import get_logger
Expand Down Expand Up @@ -44,7 +45,7 @@ def __aiter__(self) -> Self:
return self

@property
def websocket(self) -> websockets.WebSocketClientProtocol:
def websocket(self) -> ClientConnection:
if not self._websocket:
raise RuntimeError("Subscription is not connected")
return self._websocket
Expand Down Expand Up @@ -81,7 +82,7 @@ async def _ensure_connected(self):
)

auth: dict[str, Any] = orjson.loads(await websocket.recv())
assert auth["type"] == "auth_success", auth.get("message")
assert auth["type"] == "auth_success", auth.get("reason")

message: dict[str, Any] = {"type": "subscribe", "keys": self.keys}
if self.client_id:
Expand Down
10 changes: 5 additions & 5 deletions src/prefect/events/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from cachetools import TTLCache
from prometheus_client import Counter
from typing_extensions import Self
from websockets import Subprotocol
from websockets.client import WebSocketClientProtocol, connect
from websockets import Subprotocol, connect
from websockets.asyncio.client import ClientConnection
from websockets.exceptions import (
ConnectionClosed,
ConnectionClosedError,
Expand Down Expand Up @@ -241,7 +241,7 @@ def _get_api_url_and_key(
class PrefectEventsClient(EventsClient):
"""A Prefect Events client that streams events to a Prefect server"""

_websocket: Optional[WebSocketClientProtocol]
_websocket: Optional[ClientConnection]
_unconfirmed_events: List[Event]

def __init__(
Expand Down Expand Up @@ -437,7 +437,7 @@ def __init__(
)
self._connect = connect(
self._events_socket_url,
extra_headers={"Authorization": f"bearer {api_key}"},
additional_headers={"Authorization": f"bearer {api_key}"},
)


Expand All @@ -462,7 +462,7 @@ class PrefectEventSubscriber:

"""

_websocket: Optional[WebSocketClientProtocol]
_websocket: Optional[ClientConnection]
_filter: "EventFilter"
_seen_events: MutableMapping[UUID, bool]

Expand Down
6 changes: 3 additions & 3 deletions src/prefect/task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import uvicorn
from exceptiongroup import BaseExceptionGroup # novermin
from fastapi import FastAPI
from websockets.exceptions import InvalidStatusCode
from websockets.exceptions import InvalidStatus

from prefect import Task
from prefect._internal.concurrency.api import create_call, from_sync
Expand Down Expand Up @@ -163,8 +163,8 @@ async def start(self) -> None:
logger.info("Starting task worker...")
try:
await self._subscribe_to_task_scheduling()
except InvalidStatusCode as exc:
if exc.status_code == 403:
except InvalidStatus as exc:
if exc.response.status_code == 403:
logger.error(
"403: Could not establish a connection to the `/task_runs/subscriptions/scheduled`"
f" endpoint found at:\n\n {PREFECT_API_URL.value()}"
Expand Down
14 changes: 9 additions & 5 deletions src/prefect/testing/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import pendulum
import pytest
from starlette.status import WS_1008_POLICY_VIOLATION
from websockets.asyncio.server import Server as WebSocketServer
from websockets.asyncio.server import ServerConnection as WebSocketServerConnection
from websockets.asyncio.server import serve as websocket_serve
from websockets.exceptions import ConnectionClosed
from websockets.legacy.server import WebSocketServer, WebSocketServerProtocol, serve

from prefect.events import Event
from prefect.events.clients import (
Expand Down Expand Up @@ -291,7 +293,7 @@ async def events_server(
) -> AsyncGenerator[WebSocketServer, None]:
server: WebSocketServer

async def handler(socket: WebSocketServerProtocol) -> None:
async def handler(socket: WebSocketServerConnection) -> None:
path = socket.path
recorder.connections += 1
if puppeteer.refuse_any_further_connections:
Expand All @@ -304,7 +306,7 @@ async def handler(socket: WebSocketServerProtocol) -> None:
elif path.endswith("/events/out"):
await outgoing_events(socket)

async def incoming_events(socket: WebSocketServerProtocol):
async def incoming_events(socket: WebSocketServerConnection):
while True:
try:
message = await socket.recv()
Expand All @@ -317,7 +319,7 @@ async def incoming_events(socket: WebSocketServerProtocol):
if puppeteer.hard_disconnect_after == event.id:
raise ValueError("zonk")

async def outgoing_events(socket: WebSocketServerProtocol):
async def outgoing_events(socket: WebSocketServerConnection):
# 1. authentication
auth_message = json.loads(await socket.recv())

Expand Down Expand Up @@ -352,7 +354,9 @@ async def outgoing_events(socket: WebSocketServerProtocol):
puppeteer.hard_disconnect_after = None
raise ValueError("zonk")

async with serve(handler, host="localhost", port=unused_tcp_port) as server:
async with websocket_serve(
handler, host="localhost", port=unused_tcp_port
) as server:
yield server


Expand Down
Loading