diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f6be2ad..8781869 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -37,4 +37,4 @@ jobs: flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | - pytest + pytest -v --capture=tee-sys diff --git a/README.md b/README.md index e345817..1fb33fd 100644 --- a/README.md +++ b/README.md @@ -153,6 +153,26 @@ from fastapi_websocket_rpc.logger import logging_config, LoggingModes logging_config.set_mode(LoggingModes.UVICORN) ``` +## HTTP(S) Proxy +By default, fastapi-websocket-rpc uses websockets module as websocket client handler. This does not support HTTP(S) Proxy, see https://github.com/python-websockets/websockets/issues/364 . If the ability to use a proxy is important to, another websocket client implementation can be used, e.g. websocket-client (https://websocket-client.readthedocs.io). Here is how to use it. Installation: + +``` +pip install websocket-client +``` + +Then use websocket_client_handler_cls parameter: + +```python +import asyncio +from fastapi_websocket_rpc import RpcMethodsBase, WebSocketRpcClient, ProxyEnabledWebSocketClientHandler + +async def run_client(uri): + async with WebSocketRpcClient(uri, RpcMethodsBase(), websocket_client_handler_cls = ProxyEnabledWebSocketClientHandler) as client: +``` + +Just set standard environment variables (lowercase and uppercase works): http_proxy, https_proxy, and no_proxy before running python script. + + ## Pull requests - welcome! - Please include tests for new features diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/fastapi_websocket_rpc/__init__.py b/fastapi_websocket_rpc/__init__.py index 08c7dba..662fe50 100644 --- a/fastapi_websocket_rpc/__init__.py +++ b/fastapi_websocket_rpc/__init__.py @@ -1,5 +1,7 @@ from .rpc_methods import RpcMethodsBase, RpcUtilityMethods from .websocket_rpc_client import WebSocketRpcClient +from .websocket_rpc_client import ProxyEnabledWebSocketClientHandler +from .websocket_rpc_client import WebSocketsClientHandler from .websocket_rpc_endpoint import WebsocketRPCEndpoint from .rpc_channel import RpcChannel from .schemas import WebSocketFrameType diff --git a/fastapi_websocket_rpc/simplewebsocket.py b/fastapi_websocket_rpc/simplewebsocket.py index ec62ee6..f66b3d9 100644 --- a/fastapi_websocket_rpc/simplewebsocket.py +++ b/fastapi_websocket_rpc/simplewebsocket.py @@ -9,10 +9,15 @@ class SimpleWebSocket(ABC): Abstract base class for all websocket related wrappers. """ + @abstractmethod + def connect(self, uri: str, **connect_kwargs): + pass + @abstractmethod def send(self, msg): pass + # If return None, then it means Connection is closed, and we stop receiving and close. @abstractmethod def recv(self): pass @@ -26,6 +31,9 @@ class JsonSerializingWebSocket(SimpleWebSocket): def __init__(self, websocket: SimpleWebSocket): self._websocket = websocket + async def connect(self, uri: str, **connect_kwargs): + await self._websocket.connect(uri, **connect_kwargs) + def _serialize(self, msg): return pydantic_serialize(msg) @@ -37,8 +45,10 @@ async def send(self, msg): async def recv(self): msg = await self._websocket.recv() - + if msg is None: + return None return self._deserialize(msg) async def close(self, code: int = 1000): await self._websocket.close(code) + diff --git a/fastapi_websocket_rpc/websocket_rpc_client.py b/fastapi_websocket_rpc/websocket_rpc_client.py index 2cbf472..7b3079b 100644 --- a/fastapi_websocket_rpc/websocket_rpc_client.py +++ b/fastapi_websocket_rpc/websocket_rpc_client.py @@ -15,6 +15,146 @@ logger = get_logger("RPC_CLIENT") +try: + import websocket +except ImportError: + # Websocket-client optional module not installed. + pass + +class ProxyEnabledWebSocketClientHandler (SimpleWebSocket): + """ + Handler that use https://websocket-client.readthedocs.io/en/latest module. + This implementation supports HTTP proxy, though HTTP_PROXY and HTTPS_PROXY environment variable. + This is not documented, but in code, see https://github.com/websocket-client/websocket-client/blob/master/websocket/_url.py#L163 + The module is not written as coroutine: https://websocket-client.readthedocs.io/en/latest/threading.html#asyncio-library-usage, so + as a workaround, the send/recv are called in "run_in_executor" method. + TODO: remove this implementation after https://github.com/python-websockets/websockets/issues/364 is fixed and use WebSocketsClientHandler instead. + + Note: the connect timeout, if not specified, is the default socket connect timeout, which could be around 2min, so a bit longer than WebSocketsClientHandler. + """ + def __init__(self): + self._websocket = None + + """ + Args: + **kwargs: Additional args passed to connect + https://websocket-client.readthedocs.io/en/latest/examples.html#connection-options + https://websocket-client.readthedocs.io/en/latest/core.html#websocket._core.WebSocket.connect + """ + async def connect(self, uri: str, **connect_kwargs): + try: + self._websocket = await asyncio.get_event_loop().run_in_executor(None, websocket.create_connection, uri, **connect_kwargs) + # See https://websocket-client.readthedocs.io/en/latest/exceptions.html + except websocket._exceptions.WebSocketAddressException: + logger.info("websocket address info cannot be found") + raise + except websocket._exceptions.WebSocketBadStatusException: + logger.info("bad handshake status code") + raise + except websocket._exceptions.WebSocketConnectionClosedException: + logger.info("remote host closed the connection or some network error happened") + raise + except websocket._exceptions.WebSocketPayloadException: + logger.info( + f"WebSocket payload is invalid") + raise + except websocket._exceptions.WebSocketProtocolException: + logger.info(f"WebSocket protocol is invalid") + raise + except websocket._exceptions.WebSocketProxyException: + logger.info(f"proxy error occurred") + raise + except OSError as err: + logger.info("RPC Connection failed - %s", err) + raise + except Exception as err: + logger.exception("RPC Error") + raise + + async def send(self, msg): + if self._websocket is None: + # connect must be called before. + logging.error("Websocket connect() must be called before.") + await asyncio.get_event_loop().run_in_executor(None, self._websocket.send, msg) + + async def recv(self): + if self._websocket is None: + # connect must be called before. + logging.error("Websocket connect() must be called before.") + try: + msg = await asyncio.get_event_loop().run_in_executor(None, self._websocket.recv) + except websocket._exceptions.WebSocketConnectionClosedException as err: + logger.debug("Connection closed.", exc_info=True) + # websocket.WebSocketConnectionClosedException means remote host closed the connection or some network error happened + # Returning None to ensure we get out of the loop, with no Exception. + return None + return msg + + async def close(self, code: int = 1000): + if self._websocket is not None: + # Case opened, we have something to close. + self._websocket.close(code) + +class WebSocketsClientHandler(SimpleWebSocket): + """ + Handler that use https://websockets.readthedocs.io/en/stable module. + This implementation does not support HTTP proxy (see https://github.com/python-websockets/websockets/issues/364). + """ + def __init__(self): + self._websocket = None + + """ + Args: + **kwargs: Additional args passed to connect + https://websockets.readthedocs.io/en/stable/reference/asyncio/client.html#opening-a-connection + """ + async def connect(self, uri: str, **connect_kwargs): + try: + self._websocket = await websockets.connect(uri, **connect_kwargs) + except ConnectionRefusedError: + logger.info("RPC connection was refused by server") + raise + except ConnectionClosedError: + logger.info("RPC connection lost") + raise + except ConnectionClosedOK: + logger.info("RPC connection closed") + raise + except InvalidStatusCode as err: + logger.info( + f"RPC Websocket failed - with invalid status code {err.status_code}") + raise + except WebSocketException as err: + logger.info(f"RPC Websocket failed - with {err}") + raise + except OSError as err: + logger.info("RPC Connection failed - %s", err) + raise + except Exception as err: + logger.exception("RPC Error") + raise + + async def send(self, msg): + if self._websocket is None: + # connect must be called before. + logging.error("Websocket connect() must be called before.") + await self._websocket.send(msg) + + async def recv(self): + if self._websocket is None: + # connect must be called before. + logging.error("Websocket connect() must be called before.") + try: + msg = await self._websocket.recv() + except websockets.exceptions.ConnectionClosed: + logger.debug("Connection closed.", exc_info=True) + return None + return msg + + async def close(self, code: int = 1000): + if self._websocket is not None: + # Case opened, we have something to close. + self._websocket.close(code) def isNotInvalidStatusCode(value): return not isinstance(value, InvalidStatusCode) @@ -59,6 +199,7 @@ def __init__(self, uri: str, methods: RpcMethodsBase = None, on_disconnect: List[OnDisconnectCallback] = None, keep_alive: float = 0, serializing_socket_cls: Type[SimpleWebSocket] = JsonSerializingWebSocket, + websocket_client_handler_cls: Type[SimpleWebSocket] = WebSocketsClientHandler, **kwargs): """ Args: @@ -71,8 +212,7 @@ def __init__(self, uri: str, methods: RpcMethodsBase = None, on_disconnect (List[Coroutine]): callbacks on connection termination (each callback is called with the channel) keep_alive(float): interval in seconds to send a keep-alive ping, Defaults to 0, which means keep alive is disabled. - **kwargs: Additional args passed to connect (@see class Connect at websockets/client.py) - https://websockets.readthedocs.io/en/stable/api.html#websockets.client.connect + **kwargs: Additional args passed to connect, depends on websocket_client_handler_cls usage: @@ -105,15 +245,24 @@ def __init__(self, uri: str, methods: RpcMethodsBase = None, self._on_connect = on_connect # serialization self._serializing_socket_cls = serializing_socket_cls + # websocket client implementation + self._websocket_client_handler_cls = websocket_client_handler_cls async def __connect__(self): + logger.info(f"Trying server - {self.uri}") + try: + raw_ws = self._websocket_client_handler_cls() + # Wrap socket in our serialization class + self.ws = self._serializing_socket_cls(raw_ws) + except: + logger.exception("Class instantiation error.") + raise + # No try/catch for connect() to avoid double error logging. Any exception from the method must be handled by + # itself for logging, then raised and caught outside of connect() (e.g.: for retry purpose). + # Start connection + await self.ws.connect(self.uri, **self.connect_kwargs) try: try: - logger.info(f"Trying server - {self.uri}") - # Start connection - raw_ws = await websockets.connect(self.uri, **self.connect_kwargs) - # Wrap socket in our serialization class - self.ws = self._serializing_socket_cls(raw_ws) # Init an RPC channel to work on-top of the connection self.channel = RpcChannel( self.methods, self.ws, default_response_timeout=self.default_response_timeout) @@ -137,25 +286,6 @@ async def __connect__(self): await self.channel.close() self.cancel_tasks() raise - except ConnectionRefusedError: - logger.info("RPC connection was refused by server") - raise - except ConnectionClosedError: - logger.info("RPC connection lost") - raise - except ConnectionClosedOK: - logger.info("RPC connection closed") - raise - except InvalidStatusCode as err: - logger.info( - f"RPC Websocket failed - with invalid status code {err.status_code}") - raise - except WebSocketException as err: - logger.info(f"RPC Websocket failed - with {err}") - raise - except OSError as err: - logger.info("RPC Connection failed - %s", err) - raise except Exception as err: logger.exception("RPC Error") raise @@ -200,15 +330,18 @@ async def reader(self): try: while True: raw_message = await self.ws.recv() - await self.channel.on_message(raw_message) + if raw_message is None: + # None is a special case where connection is closed. + logger.info("Connection was terminated.") + await self.close() + break + else: + await self.channel.on_message(raw_message) # Graceful external termination options # task was canceled except asyncio.CancelledError: pass - except websockets.exceptions.ConnectionClosed: - logger.info("Connection was terminated.") - await self.close() - except: + except Exception as err: logger.exception("RPC Reader task failed") raise diff --git a/fastapi_websocket_rpc/websocket_rpc_endpoint.py b/fastapi_websocket_rpc/websocket_rpc_endpoint.py index 5a887e6..4197728 100644 --- a/fastapi_websocket_rpc/websocket_rpc_endpoint.py +++ b/fastapi_websocket_rpc/websocket_rpc_endpoint.py @@ -21,6 +21,10 @@ def __init__(self, websocket: WebSocket, frame_type: WebSocketFrameType = WebSoc self.websocket = websocket self.frame_type = frame_type + # This method is only useful on websocket_rpc_client. Here on endpoint file, it has nothing to connect to. + async def connect(self, uri: str, **connect_kwargs): + pass + @property def send(self): if self.frame_type == WebSocketFrameType.Binary: diff --git a/pytest.ini b/pytest.ini index 16c88ba..359a939 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,9 @@ # Handling DeprecationWarning 'asyncio_mode' default value [pytest] asyncio_mode = strict +pythonpath = . +log_cli = 1 +log_cli_level = DEBUG +log_cli_format = %(asctime)s [%(levelname)s] (%(filename)s:%(lineno)s) %(message)s +log_date_format = %Y-%m-%d %H:%M:%S + diff --git a/setup.py b/setup.py index 30cd797..9af6d4c 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,15 @@ from setuptools import setup, find_packages - +import os def get_requirements(env=""): if env: env = "-{}".format(env) with open("requirements{}.txt".format(env)) as fp: - return [x.strip() for x in fp.read().split("\n") if not x.startswith("#")] - + requirements = [x.strip() for x in fp.read().split("\n") if not x.startswith("#")] + withWebsocketClient = os.environ.get("WITH_WEBSOCKET_CLIENT", "False") + if bool(withWebsocketClient): + requirements.append("websocket-client>=1.1.0") + return requirements with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() diff --git a/tests/advanced_rpc_test.py b/tests/advanced_rpc_test.py index 654e970..d82a324 100644 --- a/tests/advanced_rpc_test.py +++ b/tests/advanced_rpc_test.py @@ -1,10 +1,6 @@ import os import sys -# Add parent path to use local src as package for tests -sys.path.append(os.path.abspath(os.path.join( - os.path.dirname(__file__), os.path.pardir))) - import time import asyncio from multiprocessing import Process diff --git a/tests/basic_rpc_test.py b/tests/basic_rpc_test.py index 68b1a87..f637547 100644 --- a/tests/basic_rpc_test.py +++ b/tests/basic_rpc_test.py @@ -2,10 +2,6 @@ import os import sys -# Add parent path to use local src as package for tests -sys.path.append(os.path.abspath(os.path.join( - os.path.dirname(__file__), os.path.pardir))) - import asyncio from multiprocessing import Process @@ -14,7 +10,7 @@ from fastapi import FastAPI from fastapi_websocket_rpc.rpc_methods import RpcUtilityMethods -from fastapi_websocket_rpc.logger import logging_config, LoggingModes +from fastapi_websocket_rpc.logger import logging_config, LoggingModes, get_logger from fastapi_websocket_rpc.websocket_rpc_client import WebSocketRpcClient from fastapi_websocket_rpc.websocket_rpc_endpoint import WebsocketRPCEndpoint from fastapi_websocket_rpc.utils import gen_uid diff --git a/tests/binary_rpc_test.py b/tests/binary_rpc_test.py index e34c7f9..b9858ed 100644 --- a/tests/binary_rpc_test.py +++ b/tests/binary_rpc_test.py @@ -2,11 +2,6 @@ import os import sys -# Add parent path to use local src as package for tests -sys.path.append( - os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) -) - import json from multiprocessing import Process @@ -15,7 +10,7 @@ from fastapi import FastAPI from fastapi_websocket_rpc import WebSocketFrameType -from fastapi_websocket_rpc.logger import LoggingModes, logging_config +from fastapi_websocket_rpc.logger import LoggingModes, logging_config, get_logger from fastapi_websocket_rpc.rpc_methods import RpcUtilityMethods from fastapi_websocket_rpc.simplewebsocket import SimpleWebSocket from fastapi_websocket_rpc.utils import pydantic_serialize @@ -25,6 +20,8 @@ # Set debug logs (and direct all logs to UVICORN format) logging_config.set_mode(LoggingModes.UVICORN, logging.DEBUG) +logger = get_logger(__name__) + # Configurable PORT = int(os.environ.get("PORT") or "9000") uri = f"ws://localhost:{PORT}/ws" @@ -34,6 +31,9 @@ class BinarySerializingWebSocket(SimpleWebSocket): def __init__(self, websocket: SimpleWebSocket): self._websocket = websocket + async def connect(self, uri: str, **connect_kwargs): + await self._websocket.connect(uri, **connect_kwargs) + def _serialize(self, msg): return pydantic_serialize(msg).encode() @@ -77,14 +77,18 @@ async def test_echo(server): """ Test basic RPC with a simple echo """ + logger.debug("before test_echo") async with WebSocketRpcClient( uri, RpcUtilityMethods(), default_response_timeout=4, serializing_socket_cls=BinarySerializingWebSocket, ) as client: + logger.debug("Initialized WebSocketRpcClient") text = "Hello World!" + logger.debug("Waiting for response...") response = await client.other.echo(text=text) + logger.debug("Response: %s", str(response)) assert response.result == text diff --git a/tests/custom_methods_test.py b/tests/custom_methods_test.py index c8937b9..013cb05 100644 --- a/tests/custom_methods_test.py +++ b/tests/custom_methods_test.py @@ -12,10 +12,6 @@ import os import sys -# Add parent path to use local src as package for tests -sys.path.append(os.path.abspath(os.path.join( - os.path.dirname(__file__), os.path.pardir))) - # Configurable PORT = int(os.environ.get("PORT") or "9000") diff --git a/tests/fast_api_depends_test.py b/tests/fast_api_depends_test.py index 07baf0a..8b6868d 100644 --- a/tests/fast_api_depends_test.py +++ b/tests/fast_api_depends_test.py @@ -3,10 +3,6 @@ from websockets.exceptions import InvalidStatusCode -# Add parent path to use local src as package for tests -sys.path.append(os.path.abspath(os.path.join( - os.path.dirname(__file__), os.path.pardir))) - from multiprocessing import Process import pytest diff --git a/tests/trigger_flow_test.py b/tests/trigger_flow_test.py index 4eb1fe7..6c0cfd7 100644 --- a/tests/trigger_flow_test.py +++ b/tests/trigger_flow_test.py @@ -5,10 +5,6 @@ import os import sys -# Add parent path to use local src as package for tests -sys.path.append(os.path.abspath(os.path.join( - os.path.dirname(__file__), os.path.pardir))) - from fastapi_websocket_rpc.utils import gen_uid from fastapi_websocket_rpc.websocket_rpc_endpoint import WebsocketRPCEndpoint from fastapi_websocket_rpc.websocket_rpc_client import WebSocketRpcClient