Skip to content

Commit

Permalink
Support http proxies for websockets (#16326)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakekaplan authored Dec 13, 2024
1 parent 547d022 commit 40590a5
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 5 deletions.
74 changes: 74 additions & 0 deletions .github/workflows/proxy-test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a
# simple server and a squid proxy server. The simple server is inaccessible from the host machine and can only
# be reached through the proxy, so we can confirm the proxy is actually working.

name: Proxy Test
# Runs on pull requests, and on pushes to main, but only when one of the listed
# paths changes: this workflow, the proxy test scripts, the events client, or a
# requirements file.
on:
pull_request:
paths:
- .github/workflows/proxy-test.yaml
- scripts/proxy-test/*
- "src/prefect/events/clients.py"
- requirements.txt
- requirements-client.txt
- requirements-dev.txt
push:
branches:
- main
paths:
- .github/workflows/proxy-test.yaml
- scripts/proxy-test/*
- "src/prefect/events/clients.py"
- requirements.txt
- requirements-client.txt
- requirements-dev.txt

jobs:
proxy-test:
name: Proxy Test
timeout-minutes: 10
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
fetch-depth: 0

- name: Set up Python 3.10
uses: actions/setup-python@v5
id: setup_python
with:
python-version: "3.10"

# Two networks: internal_net is created with --internal, external_net is a
# regular network.
- name: Create Docker networks
run: |
docker network create internal_net --internal
docker network create external_net
# The API server image is built from scripts/proxy-test and the container is
# attached to internal_net only; no port is published to the host.
- name: Start API server container
working-directory: scripts/proxy-test
run: |
docker build -t api-server .
docker run -d --network internal_net --name server api-server
# Squid is attached to both networks and published on host port 3128, using
# the squid.conf from scripts/proxy-test.
- name: Start Squid Proxy container
run: |
docker run -d \
--network internal_net \
--network external_net \
-p 3128:3128 \
-v $(pwd)/scripts/proxy-test/squid.conf:/etc/squid/squid.conf \
--name proxy \
ubuntu/squid
# Installs the checked-out package itself so the client script can import it.
- name: Install Dependencies
run: |
python -m pip install -U uv
uv pip install --upgrade --system .
# The client script is run with both proxy variables pointing at the Squid
# container published on localhost:3128.
- name: Run Proxy Tests
env:
HTTP_PROXY: http://localhost:3128
HTTPS_PROXY: http://localhost:3128
run: python scripts/proxy-test/client.py
1 change: 1 addition & 0 deletions requirements-client.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pydantic_extra_types >= 2.8.2, < 3.0.0
pydantic_settings > 2.2.1
python_dateutil >= 2.8.2, < 3.0.0
python-slugify >= 5.0, < 9.0
python-socks[asyncio] >= 2.5.3, < 3.0
pyyaml >= 5.4.1, < 7.0.0
rfc3339-validator >= 0.1.4, < 0.2.0
rich >= 11.0, < 14.0
Expand Down
11 changes: 11 additions & 0 deletions scripts/proxy-test/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Image for the websocket echo server used by the proxy test.
FROM python:3.11-slim

WORKDIR /app

# Install dependencies before copying the server code so this layer can be
# reused when only server.py changes.
COPY requirements.txt .
RUN pip install uv
RUN uv pip install --no-cache-dir --system -r requirements.txt

COPY server.py .

# Serve the `app` object from server.py on all interfaces, port 8000.
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
9 changes: 9 additions & 0 deletions scripts/proxy-test/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
This is a simple test to ensure we can make a websocket connection through a proxy server. It sets up a
simple server and a squid proxy server. The simple server is inaccessible from the host machine and can only
be reached through the proxy, so we can confirm the proxy connection is working.

```
$ uv pip install -r requirements.txt
$ docker compose up --build
$ python client.py
```
28 changes: 28 additions & 0 deletions scripts/proxy-test/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import asyncio
import os

from prefect.events.clients import websocket_connect

PROXY_URL = "http://localhost:3128"
WS_SERVER_URL = "ws://server:8000/ws"


async def test_websocket_proxy_with_compat():
    """Send one message over a proxied websocket and verify the echoed reply.

    ``HTTP_PROXY`` is set to ``PROXY_URL`` before connecting so that
    ``websocket_connect`` picks the proxy up from the environment.
    """
    os.environ["HTTP_PROXY"] = PROXY_URL

    outgoing = "Hello!"
    expected_reply = f"Server received: {outgoing}"

    async with websocket_connect(WS_SERVER_URL) as websocket:
        await websocket.send(outgoing)
        reply = await websocket.recv()
        print("Response: ", reply)
        assert reply == expected_reply


async def main():
    """Print a banner, then run the proxied websocket round-trip check."""
    banner = "Testing WebSocket through proxy with compatibility code"
    print(banner)
    await test_websocket_proxy_with_compat()


if __name__ == "__main__":
    # Script entry point: drive the async check to completion.
    asyncio.run(main())
20 changes: 20 additions & 0 deletions scripts/proxy-test/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Local equivalent of the CI setup: an echo server plus a Squid forward proxy.
services:
# Echo server built from the Dockerfile in this directory; attached to
# internal_net only and publishes no ports.
server:
build: .
networks:
- internal_net

# Squid proxy published on host port 3128 and attached to both networks.
forward_proxy:
image: ubuntu/squid
ports:
- "3128:3128"
volumes:
- ./squid.conf:/etc/squid/squid.conf
networks:
- internal_net
- external_net

networks:
# Marked internal; external_net below uses the default settings.
internal_net:
internal: true
external_net:
6 changes: 6 additions & 0 deletions scripts/proxy-test/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fastapi==0.111.1
uvicorn==0.28.1
uv==0.5.7
websockets==13.1
python-socks==2.5.3
httpx==0.28.1
10 changes: 10 additions & 0 deletions scripts/proxy-test/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from fastapi import FastAPI, WebSocket

app = FastAPI()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept the connection and echo every text frame back with a prefix."""
    await websocket.accept()
    async for incoming in websocket.iter_text():
        reply = f"Server received: {incoming}"
        await websocket.send_text(reply)
5 changes: 5 additions & 0 deletions scripts/proxy-test/squid.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Minimal Squid configuration for the websocket proxy test.
http_port 3128
# ACL matching requests that use the CONNECT method.
acl CONNECT method CONNECT
# ACL matching destination ports 443 and 8000.
acl SSL_ports port 443 8000
http_access allow CONNECT SSL_ports
# NOTE(review): this allows every request, which appears to make the rule above
# redundant. Presumably acceptable only because this proxy is a throwaway test
# fixture - confirm before reusing this file elsewhere.
http_access allow all
59 changes: 55 additions & 4 deletions src/prefect/events/clients.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
import abc
import asyncio
import os
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generator,
List,
MutableMapping,
Optional,
Tuple,
Type,
cast,
)
from urllib.parse import urlparse
from uuid import UUID

import orjson
import pendulum
from cachetools import TTLCache
from prometheus_client import Counter
from python_socks.async_.asyncio import Proxy
from typing_extensions import Self
from websockets import Subprotocol
from websockets.client import WebSocketClientProtocol, connect
from websockets.exceptions import (
ConnectionClosed,
ConnectionClosedError,
ConnectionClosedOK,
)
from websockets.legacy.client import Connect, WebSocketClientProtocol

from prefect.events import Event
from prefect.logging import get_logger
Expand Down Expand Up @@ -80,6 +84,53 @@ def events_out_socket_from_api_url(url: str):
return http_to_ws(url) + "/events/out"


class WebsocketProxyConnect(Connect):
    """Websocket connector that can tunnel through a proxy.

    The proxy URL is read from the environment when the instance is created:
    ``HTTP_PROXY`` for ``ws://`` URIs and ``HTTPS_PROXY`` for ``wss://`` URIs.
    When the relevant variable is unset or empty, no proxy is used and the
    connection is delegated to ``Connect`` unchanged.
    """

    def __init__(self: Self, uri: str, **kwargs: Any):
        """Record the target and resolve the proxy; do not connect yet.

        Args:
            uri: The ``ws://`` or ``wss://`` URI to connect to.
            **kwargs: Extra keyword arguments, passed on to
                ``Connect.__init__`` when the connection is awaited.

        Raises:
            ValueError: If the URI scheme is not ``ws`` or ``wss``, or if the
                URI does not contain a hostname.
        """
        # super() is intentionally deferred to the _proxy_connect method
        # to allow for the socket to be established first

        self.uri = uri
        self._kwargs = kwargs

        u = urlparse(uri)

        if u.scheme not in ("ws", "wss"):
            raise ValueError(
                "Unsupported scheme %s. Expected 'ws' or 'wss'. " % u.scheme
            )

        host = u.hostname
        if not host:
            # Fail here rather than handing ``None`` to the proxy or to TLS
            # as the destination host / server hostname later on.
            raise ValueError("Invalid URI %s. No hostname found." % uri)

        if u.scheme == "ws":
            port = u.port or 80
            proxy_url = os.environ.get("HTTP_PROXY")
        else:
            port = u.port or 443
            proxy_url = os.environ.get("HTTPS_PROXY")
            # Always name the real target host for TLS, since the connection
            # may be made over a socket that was opened to the proxy.
            kwargs["server_hostname"] = host

        self._proxy = Proxy.from_url(proxy_url) if proxy_url else None
        self._host = host
        self._port = port

    async def _proxy_connect(self: Self) -> WebSocketClientProtocol:
        """Open the proxy socket if a proxy is configured, then connect.

        Returns:
            The protocol object produced by ``Connect.__await_impl__``.
        """
        if self._proxy:
            sock = await self._proxy.connect(
                dest_host=self._host,
                dest_port=self._port,
            )
            # Hand the already-connected socket to the parent connector.
            self._kwargs["sock"] = sock

        # Deferred parent initialisation (see __init__): by now the optional
        # proxy socket is part of the keyword arguments.
        super().__init__(self.uri, **self._kwargs)
        proto = await self.__await_impl__()
        return proto

    def __await__(self: Self) -> Generator[Any, None, WebSocketClientProtocol]:
        """Make ``await connector`` go through the proxy-aware connect path."""
        return self._proxy_connect().__await__()


def websocket_connect(uri: str, **kwargs: Any) -> WebsocketProxyConnect:
    """Create a ``WebsocketProxyConnect`` for ``uri``.

    All keyword arguments are forwarded to the connector unchanged.
    """
    connector = WebsocketProxyConnect(uri, **kwargs)
    return connector


def get_events_client(
reconnection_attempts: int = 10,
checkpoint_every: int = 700,
Expand Down Expand Up @@ -265,7 +316,7 @@ def __init__(
)

self._events_socket_url = events_in_socket_from_api_url(api_url)
self._connect = connect(self._events_socket_url)
self._connect = websocket_connect(self._events_socket_url)
self._websocket = None
self._reconnection_attempts = reconnection_attempts
self._unconfirmed_events = []
Expand Down Expand Up @@ -435,7 +486,7 @@ def __init__(
reconnection_attempts=reconnection_attempts,
checkpoint_every=checkpoint_every,
)
self._connect = connect(
self._connect = websocket_connect(
self._events_socket_url,
extra_headers={"Authorization": f"bearer {api_key}"},
)
Expand Down Expand Up @@ -494,7 +545,7 @@ def __init__(

logger.debug("Connecting to %s", socket_url)

self._connect = connect(
self._connect = websocket_connect(
socket_url,
subprotocols=[Subprotocol("prefect")],
)
Expand Down
2 changes: 1 addition & 1 deletion tests/events/client/test_events_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
def mock_connect(*args, **kwargs):
return MockConnect()

monkeypatch.setattr("prefect.events.clients.connect", mock_connect)
monkeypatch.setattr("prefect.events.clients.websocket_connect", mock_connect)

with caplog.at_level(logging.WARNING):
with pytest.raises(Exception, match="Connection failed"):
Expand Down
42 changes: 42 additions & 0 deletions tests/utilities/test_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from unittest.mock import Mock

from prefect.events.clients import WebsocketProxyConnect


def test_init_ws_without_proxy(monkeypatch):
    """A ``ws://`` URI with no ``HTTP_PROXY`` set resolves to a direct connection."""
    # Clear any proxy configured on the machine running the tests; otherwise
    # the ``_proxy is None`` assertion depends on the ambient environment.
    monkeypatch.delenv("HTTP_PROXY", raising=False)

    client = WebsocketProxyConnect("ws://example.com")
    assert client.uri == "ws://example.com"
    assert client._host == "example.com"
    assert client._port == 80
    assert client._proxy is None


def test_init_wss_without_proxy(monkeypatch):
    """A ``wss://`` URI with no ``HTTPS_PROXY`` set resolves to a direct connection."""
    # Clear any proxy configured on the machine running the tests; otherwise
    # the ``_proxy is None`` assertion depends on the ambient environment.
    monkeypatch.delenv("HTTPS_PROXY", raising=False)

    client = WebsocketProxyConnect("wss://example.com")
    assert client.uri == "wss://example.com"
    assert client._host == "example.com"
    assert client._port == 443
    # Check the value, not just the key: TLS must be told the real target host.
    assert client._kwargs["server_hostname"] == "example.com"
    assert client._proxy is None


def test_init_ws_with_proxy(monkeypatch):
    """A ``ws://`` URI builds its proxy from the ``HTTP_PROXY`` variable."""
    proxy_cls = Mock()
    monkeypatch.setattr("prefect.events.clients.Proxy", proxy_cls)
    monkeypatch.setenv("HTTP_PROXY", "http://proxy:3128")

    connector = WebsocketProxyConnect("ws://example.com")

    proxy_cls.from_url.assert_called_once_with("http://proxy:3128")
    assert connector._proxy is not None


def test_init_wss_with_proxy(monkeypatch):
    """A ``wss://`` URI builds its proxy from the ``HTTPS_PROXY`` variable."""
    proxy_cls = Mock()
    monkeypatch.setattr("prefect.events.clients.Proxy", proxy_cls)
    monkeypatch.setenv("HTTPS_PROXY", "https://proxy:3128")

    connector = WebsocketProxyConnect("wss://example.com")

    proxy_cls.from_url.assert_called_once_with("https://proxy:3128")
    assert connector._proxy is not None

0 comments on commit 40590a5

Please sign in to comment.