From 40590a5fb0c5dca63aa96f3b706d7b74ce08dc0a Mon Sep 17 00:00:00 2001 From: jakekaplan <40362401+jakekaplan@users.noreply.github.com> Date: Thu, 12 Dec 2024 19:33:15 -0500 Subject: [PATCH] Support http proxies for websockets (#16326) --- .github/workflows/proxy-test.yaml | 74 +++++++++++++++++++++++ requirements-client.txt | 1 + scripts/proxy-test/Dockerfile | 11 ++++ scripts/proxy-test/README.md | 9 +++ scripts/proxy-test/client.py | 28 +++++++++ scripts/proxy-test/docker-compose.yml | 20 ++++++ scripts/proxy-test/requirements.txt | 6 ++ scripts/proxy-test/server.py | 10 +++ scripts/proxy-test/squid.conf | 5 ++ src/prefect/events/clients.py | 59 ++++++++++++++++-- tests/events/client/test_events_client.py | 2 +- tests/utilities/test_proxy.py | 42 +++++++++++++ 12 files changed, 262 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/proxy-test.yaml create mode 100644 scripts/proxy-test/Dockerfile create mode 100644 scripts/proxy-test/README.md create mode 100644 scripts/proxy-test/client.py create mode 100644 scripts/proxy-test/docker-compose.yml create mode 100644 scripts/proxy-test/requirements.txt create mode 100644 scripts/proxy-test/server.py create mode 100644 scripts/proxy-test/squid.conf create mode 100644 tests/utilities/test_proxy.py diff --git a/.github/workflows/proxy-test.yaml b/.github/workflows/proxy-test.yaml new file mode 100644 index 000000000000..3d8a1512a107 --- /dev/null +++ b/.github/workflows/proxy-test.yaml @@ -0,0 +1,74 @@ +# This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a +# simple server and a squid proxy server. The server is inaccessible from the host machine and reachable only through the proxy, +# so we can confirm the proxy is actually working. 
+ +name: Proxy Test +on: + pull_request: + paths: + - .github/workflows/proxy-test.yaml + - scripts/proxy-test/* + - "src/prefect/events/clients.py" + - requirements.txt + - requirements-client.txt + - requirements-dev.txt + push: + branches: + - main + paths: + - .github/workflows/proxy-test.yaml + - scripts/proxy-test/* + - "src/prefect/events/clients.py" + - requirements.txt + - requirements-client.txt + - requirements-dev.txt + +jobs: + proxy-test: + name: Proxy Test + timeout-minutes: 10 + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + fetch-depth: 0 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + id: setup_python + with: + python-version: "3.10" + + - name: Create Docker networks + run: | + docker network create internal_net --internal + docker network create external_net + + - name: Start API server container + working-directory: scripts/proxy-test + run: | + docker build -t api-server . + docker run -d --network internal_net --name server api-server + + - name: Start Squid Proxy container + run: | + docker run -d \ + --network internal_net \ + --network external_net \ + -p 3128:3128 \ + -v $(pwd)/scripts/proxy-test/squid.conf:/etc/squid/squid.conf \ + --name proxy \ + ubuntu/squid + + - name: Install Dependencies + run: | + python -m pip install -U uv + uv pip install --upgrade --system . 
+ + - name: Run Proxy Tests + env: + HTTP_PROXY: http://localhost:3128 + HTTPS_PROXY: http://localhost:3128 + run: python scripts/proxy-test/client.py diff --git a/requirements-client.txt b/requirements-client.txt index de5e2b5ab1e5..e5424a1c85c1 100644 --- a/requirements-client.txt +++ b/requirements-client.txt @@ -26,6 +26,7 @@ pydantic_extra_types >= 2.8.2, < 3.0.0 pydantic_settings > 2.2.1 python_dateutil >= 2.8.2, < 3.0.0 python-slugify >= 5.0, < 9.0 +python-socks[asyncio] >= 2.5.3, < 3.0 pyyaml >= 5.4.1, < 7.0.0 rfc3339-validator >= 0.1.4, < 0.2.0 rich >= 11.0, < 14.0 diff --git a/scripts/proxy-test/Dockerfile b/scripts/proxy-test/Dockerfile new file mode 100644 index 000000000000..93b6c4db9107 --- /dev/null +++ b/scripts/proxy-test/Dockerfile @@ -0,0 +1,11 @@ +FROM python:3.11-slim + +WORKDIR /app + +COPY requirements.txt . +RUN pip install uv +RUN uv pip install --no-cache-dir --system -r requirements.txt + +COPY server.py . + +CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/scripts/proxy-test/README.md b/scripts/proxy-test/README.md new file mode 100644 index 000000000000..76a8c8fdc55f --- /dev/null +++ b/scripts/proxy-test/README.md @@ -0,0 +1,9 @@ +This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a +simple server and a squid proxy server. The server is inaccessible from the host machine except through the proxy, so we +can confirm the proxy connection is working. 
+ +``` +$ uv pip install -r requirements.txt +$ docker compose up --build +$ python client.py +``` diff --git a/scripts/proxy-test/client.py b/scripts/proxy-test/client.py new file mode 100644 index 000000000000..e07e8450b29a --- /dev/null +++ b/scripts/proxy-test/client.py @@ -0,0 +1,28 @@ +import asyncio +import os + +from prefect.events.clients import websocket_connect + +PROXY_URL = "http://localhost:3128" +WS_SERVER_URL = "ws://server:8000/ws" + + +async def test_websocket_proxy_with_compat(): + """WebSocket through proxy with proxy compatibility code - should work""" + os.environ["HTTP_PROXY"] = PROXY_URL + + async with websocket_connect(WS_SERVER_URL) as websocket: + message = "Hello!" + await websocket.send(message) + response = await websocket.recv() + print("Response: ", response) + assert response == f"Server received: {message}" + + +async def main(): + print("Testing WebSocket through proxy with compatibility code") + await test_websocket_proxy_with_compat() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/proxy-test/docker-compose.yml b/scripts/proxy-test/docker-compose.yml new file mode 100644 index 000000000000..ab26fdaae98c --- /dev/null +++ b/scripts/proxy-test/docker-compose.yml @@ -0,0 +1,20 @@ +services: + server: + build: . 
+ networks: + - internal_net + + forward_proxy: + image: ubuntu/squid + ports: + - "3128:3128" + volumes: + - ./squid.conf:/etc/squid/squid.conf + networks: + - internal_net + - external_net + +networks: + internal_net: + internal: true + external_net: diff --git a/scripts/proxy-test/requirements.txt b/scripts/proxy-test/requirements.txt new file mode 100644 index 000000000000..9cfd79360c5a --- /dev/null +++ b/scripts/proxy-test/requirements.txt @@ -0,0 +1,6 @@ +fastapi==0.111.1 +uvicorn==0.28.1 +uv==0.5.7 +websockets==13.1 +python-socks==2.5.3 +httpx==0.28.1 diff --git a/scripts/proxy-test/server.py b/scripts/proxy-test/server.py new file mode 100644 index 000000000000..4f4498d04a7a --- /dev/null +++ b/scripts/proxy-test/server.py @@ -0,0 +1,10 @@ +from fastapi import FastAPI, WebSocket + +app = FastAPI() + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + async for data in websocket.iter_text(): + await websocket.send_text(f"Server received: {data}") diff --git a/scripts/proxy-test/squid.conf b/scripts/proxy-test/squid.conf new file mode 100644 index 000000000000..8c9261d02d61 --- /dev/null +++ b/scripts/proxy-test/squid.conf @@ -0,0 +1,5 @@ +http_port 3128 +acl CONNECT method CONNECT +acl SSL_ports port 443 8000 +http_access allow CONNECT SSL_ports +http_access allow all diff --git a/src/prefect/events/clients.py b/src/prefect/events/clients.py index bd09eb3ab20c..4efd7ba1107a 100644 --- a/src/prefect/events/clients.py +++ b/src/prefect/events/clients.py @@ -1,11 +1,13 @@ import abc import asyncio +import os from types import TracebackType from typing import ( TYPE_CHECKING, Any, ClassVar, Dict, + Generator, List, MutableMapping, Optional, @@ -13,20 +15,22 @@ Type, cast, ) +from urllib.parse import urlparse from uuid import UUID import orjson import pendulum from cachetools import TTLCache from prometheus_client import Counter +from python_socks.async_.asyncio import Proxy from typing_extensions import 
Self from websockets import Subprotocol -from websockets.client import WebSocketClientProtocol, connect from websockets.exceptions import ( ConnectionClosed, ConnectionClosedError, ConnectionClosedOK, ) +from websockets.legacy.client import Connect, WebSocketClientProtocol from prefect.events import Event from prefect.logging import get_logger @@ -80,6 +84,53 @@ def events_out_socket_from_api_url(url: str): return http_to_ws(url) + "/events/out" +class WebsocketProxyConnect(Connect): + def __init__(self: Self, uri: str, **kwargs: Any): + # super() is intentionally deferred to the _proxy_connect method + # to allow for the socket to be established first + + self.uri = uri + self._kwargs = kwargs + + u = urlparse(uri) + host = u.hostname + + if u.scheme == "ws": + port = u.port or 80 + proxy_url = os.environ.get("HTTP_PROXY") + elif u.scheme == "wss": + port = u.port or 443 + proxy_url = os.environ.get("HTTPS_PROXY") + kwargs["server_hostname"] = host + else: + raise ValueError( + "Unsupported scheme %s. Expected 'ws' or 'wss'. 
" % u.scheme + ) + + self._proxy = Proxy.from_url(proxy_url) if proxy_url else None + self._host = host + self._port = port + + async def _proxy_connect(self: Self) -> WebSocketClientProtocol: + if self._proxy: + sock = await self._proxy.connect( + dest_host=self._host, + dest_port=self._port, + ) + self._kwargs["sock"] = sock + + super().__init__(self.uri, **self._kwargs) + proto = await self.__await_impl__() + return proto + + def __await__(self: Self) -> Generator[Any, None, WebSocketClientProtocol]: + return self._proxy_connect().__await__() + + +def websocket_connect(uri: str, **kwargs: Any) -> WebsocketProxyConnect: + return WebsocketProxyConnect(uri, **kwargs) + + def get_events_client( reconnection_attempts: int = 10, checkpoint_every: int = 700, @@ -265,7 +316,7 @@ def __init__( ) self._events_socket_url = events_in_socket_from_api_url(api_url) - self._connect = connect(self._events_socket_url) + self._connect = websocket_connect(self._events_socket_url) self._websocket = None self._reconnection_attempts = reconnection_attempts self._unconfirmed_events = [] @@ -435,7 +486,7 @@ def __init__( reconnection_attempts=reconnection_attempts, checkpoint_every=checkpoint_every, ) - self._connect = connect( + self._connect = websocket_connect( self._events_socket_url, extra_headers={"Authorization": f"bearer {api_key}"}, ) @@ -494,7 +545,7 @@ def __init__( logger.debug("Connecting to %s", socket_url) - self._connect = connect( + self._connect = websocket_connect( socket_url, subprotocols=[Subprotocol("prefect")], ) diff --git a/tests/events/client/test_events_client.py b/tests/events/client/test_events_client.py index 9acf2ef82b2b..5b6970ae7e3b 100644 --- a/tests/events/client/test_events_client.py +++ b/tests/events/client/test_events_client.py @@ -359,7 +359,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def mock_connect(*args, **kwargs): return MockConnect() - monkeypatch.setattr("prefect.events.clients.connect", mock_connect) + 
monkeypatch.setattr("prefect.events.clients.websocket_connect", mock_connect) with caplog.at_level(logging.WARNING): with pytest.raises(Exception, match="Connection failed"): diff --git a/tests/utilities/test_proxy.py b/tests/utilities/test_proxy.py new file mode 100644 index 000000000000..49723aea11fb --- /dev/null +++ b/tests/utilities/test_proxy.py @@ -0,0 +1,42 @@ +from unittest.mock import Mock + +from prefect.events.clients import WebsocketProxyConnect + + +def test_init_ws_without_proxy(): + client = WebsocketProxyConnect("ws://example.com") + assert client.uri == "ws://example.com" + assert client._host == "example.com" + assert client._port == 80 + assert client._proxy is None + + +def test_init_wss_without_proxy(): + client = WebsocketProxyConnect("wss://example.com") + assert client.uri == "wss://example.com" + assert client._host == "example.com" + assert client._port == 443 + assert "server_hostname" in client._kwargs + assert client._proxy is None + + +def test_init_ws_with_proxy(monkeypatch): + monkeypatch.setenv("HTTP_PROXY", "http://proxy:3128") + mock_proxy = Mock() + monkeypatch.setattr("prefect.events.clients.Proxy", mock_proxy) + + client = WebsocketProxyConnect("ws://example.com") + + mock_proxy.from_url.assert_called_once_with("http://proxy:3128") + assert client._proxy is not None + + +def test_init_wss_with_proxy(monkeypatch): + monkeypatch.setenv("HTTPS_PROXY", "https://proxy:3128") + mock_proxy = Mock() + monkeypatch.setattr("prefect.events.clients.Proxy", mock_proxy) + + client = WebsocketProxyConnect("wss://example.com") + + mock_proxy.from_url.assert_called_once_with("https://proxy:3128") + assert client._proxy is not None