From 1311b494edc7066e3de013b818b94783c5bff369 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Tue, 23 Apr 2024 19:04:17 +0530 Subject: [PATCH 1/2] Reverse proxy enhancements --- proxy/core/acceptor/pool.py | 19 +++-- proxy/core/listener/tcp.py | 5 +- proxy/http/server/plugin.py | 41 +++++++++-- proxy/http/server/reverse.py | 132 ++++++++++++++++++++++++++-------- proxy/http/server/web.py | 52 ++++++++------ proxy/plugin/reverse_proxy.py | 12 +++- 6 files changed, 194 insertions(+), 67 deletions(-) diff --git a/proxy/core/acceptor/pool.py b/proxy/core/acceptor/pool.py index 09fb9f447f..2c58dccf2f 100644 --- a/proxy/core/acceptor/pool.py +++ b/proxy/core/acceptor/pool.py @@ -98,12 +98,17 @@ def setup(self) -> None: """Setup acceptors.""" self._start() execution_mode = ( - 'threadless (local)' - if self.flags.local_executor - else 'threadless (remote)' - ) if self.flags.threadless else 'threaded' - logger.info( - 'Started %d acceptors in %s mode' % ( + ( + "threadless (local)" + if self.flags.local_executor + else "threadless (remote)" + ) + if self.flags.threadless + else "threaded" + ) + logger.debug( + "Started %d acceptors in %s mode" + % ( self.flags.num_acceptors, execution_mode, ), @@ -122,7 +127,7 @@ def setup(self) -> None: self.fd_queues[index].close() def shutdown(self) -> None: - logger.info('Shutting down %d acceptors' % self.flags.num_acceptors) + logger.debug("Shutting down %d acceptors" % self.flags.num_acceptors) for acceptor in self.acceptors: acceptor.running.set() for acceptor in self.acceptors: diff --git a/proxy/core/listener/tcp.py b/proxy/core/listener/tcp.py index b6dc15e8ef..37bf6164a3 100644 --- a/proxy/core/listener/tcp.py +++ b/proxy/core/listener/tcp.py @@ -92,8 +92,7 @@ def listen(self) -> socket.socket: sock.listen(self.flags.backlog) sock.setblocking(False) self._port = sock.getsockname()[1] - logger.info( - 'Listening on %s:%s' % - (self.hostname, self._port), + logger.debug( + "Listening on %s:%s" % (self.hostname, self._port), ) return sock diff --git a/proxy/http/server/plugin.py b/proxy/http/server/plugin.py index 48cc5eb2a7..720f6415b1 100644 --- a/proxy/http/server/plugin.py +++ b/proxy/http/server/plugin.py @@ -22,10 +22,11 @@ from ..descriptors import DescriptorsHandlerMixin from ...common.types import RePattern from ...common.utils import bytes_ +from ...http.server.protocols import httpProtocolTypes if TYPE_CHECKING: # pragma: no cover - from ...core.connection import UpstreamConnectionPool + from ...core.connection import TcpServerConnection, UpstreamConnectionPool class HttpWebServerBasePlugin(DescriptorsHandlerMixin, ABC): @@ -64,7 +65,7 @@ def serve_static_file(path: str, min_compression_length: int) -> memoryview: # TODO: Should we really close or take advantage of keep-alive? conn_close=True, ) - except FileNotFoundError: + except OSError: return NOT_FOUND_RESPONSE_PKT def name(self) -> str: @@ -88,6 +89,17 @@ def on_client_connection_close(self) -> None: """Client has closed the connection, do any clean up task now.""" pass + def do_upgrade(self, request: HttpParser) -> bool: + return True + + def on_client_data( + self, + request: HttpParser, + raw: memoryview, + ) -> Optional[memoryview]: + """Return None to avoid default webserver parsing of client data.""" + return raw + # No longer abstract since v2.4.0 # # @abstractmethod @@ -125,7 +137,7 @@ def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]: return context -class ReverseProxyBasePlugin(ABC): +class ReverseProxyBasePlugin(DescriptorsHandlerMixin, ABC): """ReverseProxy base plugin class.""" def __init__( @@ -161,13 +173,24 @@ def routes(self) -> List[Union[str, Tuple[str, List[bytes]]]]: must return the url to serve.""" raise NotImplementedError() # pragma: no cover + def protocols(self) -> List[int]: + return [ + httpProtocolTypes.HTTP, + httpProtocolTypes.HTTPS, + httpProtocolTypes.WEBSOCKET, + ] + def before_routing(self, request: HttpParser) -> Optional[HttpParser]: """Plugins can modify request, return response, close connection. If None is returned, request will be dropped and closed.""" return request # pragma: no cover - def handle_route(self, request: HttpParser, pattern: RePattern) -> Url: + def handle_route( + self, + request: HttpParser, + pattern: RePattern, + ) -> Union[memoryview, Url, 'TcpServerConnection']: """Implement this method if you have configured dynamic routes.""" raise NotImplementedError() @@ -182,3 +205,13 @@ def regexes(self) -> List[str]: else: raise ValueError('Invalid route type') return routes + + def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Use this method to override default access log format (see + DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT) or to add/update/modify passed context + for usage by default access logger. + + Return updated log context to use for default logging format, OR + Return None if plugin has logged the request. + """ + return context diff --git a/proxy/http/server/reverse.py b/proxy/http/server/reverse.py index c4cadf3e00..4d91bf3a0a 100644 --- a/proxy/http/server/reverse.py +++ b/proxy/http/server/reverse.py @@ -16,13 +16,14 @@ from proxy.http import Url from proxy.core.base import TcpUpstreamConnectionHandler from proxy.http.parser import HttpParser -from proxy.http.server import HttpWebServerBasePlugin, httpProtocolTypes +from proxy.http.server import HttpWebServerBasePlugin from proxy.common.utils import text_ from proxy.http.exception import HttpProtocolException from proxy.common.constants import ( HTTPS_PROTO, DEFAULT_HTTP_PORT, DEFAULT_HTTPS_PORT, DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT, ) +from ...common.types import Readables, Writables, Descriptors if TYPE_CHECKING: # pragma: no cover @@ -44,6 +45,11 @@ def __init__(self, *args: Any, **kwargs: Any): self.uid, self.flags, self.client, self.event_queue, self.upstream_conn_pool, ) self.plugins.append(plugin) + self._upstream_proxy_pass: Optional[str] = None + + def do_upgrade(self, request: HttpParser) -> bool: + """Signal web protocol handler to not upgrade websocket requests by default.""" + return False def handle_upstream_data(self, raw: memoryview) -> None: # TODO: Parse response and implement plugin hook per parsed response object @@ -54,8 +60,8 @@ def routes(self) -> List[Tuple[int, str]]: r = [] for plugin in self.plugins: for route in plugin.regexes(): - r.append((httpProtocolTypes.HTTP, route)) - r.append((httpProtocolTypes.HTTPS, route)) + for proto in plugin.protocols(): + r.append((proto, route)) return r def handle_request(self, request: HttpParser) -> None: @@ -66,9 +72,12 @@ def handle_request(self, request: HttpParser) -> None: raise HttpProtocolException('before_routing closed connection') request = r + needs_upstream = False + # routes for plugin in self.plugins: for route in plugin.routes(): + # Static routes if isinstance(route, tuple): pattern = re.compile(route[0]) if pattern.match(text_(request.path)): @@ -76,39 +85,55 @@ def handle_request(self, request: HttpParser) -> None: random.choice(route[1]), ) break + # Dynamic routes elif isinstance(route, str): pattern = re.compile(route) if pattern.match(text_(request.path)): - self.choice = plugin.handle_route(request, pattern) + choice = plugin.handle_route(request, pattern) + if isinstance(choice, Url): + self.choice = choice + needs_upstream = True + self._upstream_proxy_pass = str(self.choice) + elif isinstance(choice, memoryview): + self.client.queue(choice) + self._upstream_proxy_pass = '{0} bytes'.format(len(choice)) + else: + self.upstream = choice + self._upstream_proxy_pass = '{0}:{1}'.format( + *self.upstream.addr, + ) break else: raise ValueError('Invalid route') - assert self.choice and self.choice.hostname - port = self.choice.port or \ - DEFAULT_HTTP_PORT \ - if self.choice.scheme == b'http' \ - else DEFAULT_HTTPS_PORT - self.initialize_upstream(text_(self.choice.hostname), port) - assert self.upstream - try: - self.upstream.connect() - if self.choice.scheme == HTTPS_PROTO: - self.upstream.wrap( - text_( - self.choice.hostname, + if needs_upstream: + assert self.choice and self.choice.hostname + port = ( + self.choice.port or DEFAULT_HTTP_PORT + if self.choice.scheme == b'http' + else DEFAULT_HTTPS_PORT + ) + self.initialize_upstream(text_(self.choice.hostname), port) + assert self.upstream + try: + self.upstream.connect() + if self.choice.scheme == HTTPS_PROTO: + self.upstream.wrap( + text_( + self.choice.hostname, + ), + as_non_blocking=True, + ca_file=self.flags.ca_file, + ) + request.path = self.choice.remainder + self.upstream.queue(memoryview(request.build())) + except ConnectionRefusedError: + raise HttpProtocolException( # pragma: no cover + 'Connection refused by upstream server {0}:{1}'.format( + text_(self.choice.hostname), + port, ), - as_non_blocking=True, - ca_file=self.flags.ca_file, ) - request.path = self.choice.remainder - self.upstream.queue(memoryview(request.build())) - except ConnectionRefusedError: - raise HttpProtocolException( # pragma: no cover - 'Connection refused by upstream server {0}:{1}'.format( - text_(self.choice.hostname), port, - ), - ) def on_client_connection_close(self) -> None: if self.upstream and not self.upstream.closed: @@ -116,9 +141,54 @@ def on_client_connection_close(self) -> None: self.upstream.close() self.upstream = None + def on_client_data( + self, + request: HttpParser, + raw: memoryview, + ) -> Optional[memoryview]: + if request.is_websocket_upgrade: + assert self.upstream + self.upstream.queue(raw) + return raw + def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]: - context.update({ - 'upstream_proxy_pass': str(self.choice) if self.choice else None, - }) - logger.info(DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT.format_map(context)) + context.update( + { + 'upstream_proxy_pass': self._upstream_proxy_pass, + }, + ) + log_handled = False + for plugin in self.plugins: + ctx = plugin.on_access_log(context) + if ctx is None: + log_handled = True + break + context = ctx + if not log_handled: + logger.info(DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT.format_map(context)) return None + + async def get_descriptors(self) -> Descriptors: + r, w = await super().get_descriptors() + # TODO(abhinavsingh): We need to keep a mapping of plugin and + # descriptors registered by them, so that within write/read blocks + # we can invoke the right plugin callbacks. + for plugin in self.plugins: + plugin_read_desc, plugin_write_desc = await plugin.get_descriptors() + r.extend(plugin_read_desc) + w.extend(plugin_write_desc) + return r, w + + async def read_from_descriptors(self, r: Readables) -> bool: + for plugin in self.plugins: + teardown = await plugin.read_from_descriptors(r) + if teardown: + return True + return await super().read_from_descriptors(r) + + async def write_to_descriptors(self, w: Writables) -> bool: + for plugin in self.plugins: + teardown = await plugin.write_to_descriptors(w) + if teardown: + return True + return await super().write_to_descriptors(w) diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py index 06072493b2..f756494380 100644 --- a/proxy/http/server/web.py +++ b/proxy/http/server/web.py @@ -101,6 +101,9 @@ def __init__( if b'HttpWebServerBasePlugin' in self.flags.plugins: self._initialize_web_plugins() + self._response_size = 0 + self._post_request_data_size = 0 + @staticmethod def protocols() -> List[int]: return [httpProtocols.WEB_SERVER] @@ -138,17 +141,17 @@ def switch_to_websocket(self) -> None: def on_request_complete(self) -> Union[socket.socket, bool]: path = self.request.path or b'/' teardown = self._try_route(path) - # Try route signaled to teardown - # or if it did find a valid route - if teardown or self.route is not None: + if teardown: return teardown # No-route found, try static serving if enabled - if self.flags.enable_static_server: - self._try_static_or_404(path) + if self.route is None: + if self.flags.enable_static_server: + self._try_static_or_404(path) + return True + # Catch all unhandled web server requests, return 404 + self.client.queue(NOT_FOUND_RESPONSE_PKT) return True - # Catch all unhandled web server requests, return 404 - self.client.queue(NOT_FOUND_RESPONSE_PKT) - return True + return False async def get_descriptors(self) -> Descriptors: r, w = [], [] @@ -173,6 +176,9 @@ async def read_from_descriptors(self, r: Readables) -> bool: return False def on_client_data(self, raw: memoryview) -> None: + self._post_request_data_size += len(raw) + if self.route and self.route.on_client_data(self.request, raw) is None: + return if self.switched_protocol == httpProtocolTypes.WEBSOCKET: # TODO(abhinavsingh): Do we really tobytes() here? # Websocket parser currently doesn't depend on internal @@ -211,6 +217,7 @@ def on_client_data(self, raw: memoryview) -> None: self.pipeline_request = None def on_response_chunk(self, chunk: List[memoryview]) -> List[memoryview]: + self._response_size += sum([len(c) for c in chunk]) return chunk def on_client_connection_close(self) -> None: @@ -221,11 +228,15 @@ def on_client_connection_close(self) -> None: # Request 'request_method': text_(self.request.method), 'request_path': text_(self.request.path), - 'request_bytes': self.request.total_size, - 'request_ua': text_(self.request.header(b'user-agent')) - if self.request.has_header(b'user-agent') - else None, - 'request_version': None if not self.request.version else text_(self.request.version), + 'request_bytes': self.request.total_size + self._post_request_data_size, + 'request_ua': ( + text_(self.request.header(b'user-agent')) + if self.request.has_header(b'user-agent') + else None + ), + 'request_version': ( + None if not self.request.version else text_(self.request.version) + ), # Response # # TODO: Track and inject web server specific response attributes @@ -234,7 +245,7 @@ def on_client_connection_close(self) -> None: # several attributes required below. At least for code and # reason attributes. # - # 'response_bytes': self.response.total_size, + 'response_bytes': self._response_size, # 'response_code': text_(self.response.code), # 'response_reason': text_(self.response.reason), } @@ -256,8 +267,7 @@ def access_log(self, context: Dict[str, Any]) -> None: @property def _protocol(self) -> Tuple[bool, int]: - do_ws_upgrade = self.request.is_connection_upgrade and \ - self.request.header(b'upgrade').lower() == b'websocket' + do_ws_upgrade = self.request.is_websocket_upgrade return do_ws_upgrade, httpProtocolTypes.WEBSOCKET \ if do_ws_upgrade \ else httpProtocolTypes.HTTPS \ @@ -271,7 +281,7 @@ def _try_route(self, path: bytes) -> bool: self.route = self.routes[protocol][route] assert self.route # Optionally, upgrade protocol - if do_ws_upgrade: + if do_ws_upgrade and self.route.do_upgrade(self.request): self.switch_to_websocket() assert self.route # Invoke plugin.on_websocket_open @@ -279,9 +289,11 @@ def _try_route(self, path: bytes) -> bool: else: # Invoke plugin.handle_request self.route.handle_request(self.request) - if self.request.has_header(b'connection') and \ - self.request.header(b'connection').lower() == b'close': - return True + # if self.request.has_header(b'connection') and \ + # self.request.header(b'connection').lower() == b'close': + # return True + # Bailout on first match + break return False def _try_static_or_404(self, path: bytes) -> None: diff --git a/proxy/plugin/reverse_proxy.py b/proxy/plugin/reverse_proxy.py index fb96e15486..7b7a5a4b38 100644 --- a/proxy/plugin/reverse_proxy.py +++ b/proxy/plugin/reverse_proxy.py @@ -13,7 +13,7 @@ Lua """ import re -from typing import List, Tuple, Union +from typing import TYPE_CHECKING, List, Tuple, Union from ..http import Url from ..http.parser import HttpParser @@ -22,6 +22,10 @@ from ..http.exception.base import HttpProtocolException +if TYPE_CHECKING: + from ..core.connection import TcpServerConnection + + class ReverseProxyPlugin(ReverseProxyBasePlugin): """This example plugin is equivalent to following Nginx configuration:: @@ -49,7 +53,11 @@ def routes(self) -> List[Union[str, Tuple[str, List[bytes]]]]: r'/get/(\d+)$', ] - def handle_route(self, request: HttpParser, pattern: RePattern) -> Url: + def handle_route( + self, + request: HttpParser, + pattern: RePattern, + ) -> Union[memoryview, Url, 'TcpServerConnection']: """For our example dynamic route, we want to simply convert any incoming request to "/get/1" into "/get?id=1" when serving from upstream. """ From cc5d5119b8cdcd1dcbcffe49a2306663d798dbf9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Apr 2024 13:35:25 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- proxy/core/acceptor/pool.py | 10 +++++----- proxy/core/listener/tcp.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/proxy/core/acceptor/pool.py b/proxy/core/acceptor/pool.py index 2c58dccf2f..5fbae26a3d 100644 --- a/proxy/core/acceptor/pool.py +++ b/proxy/core/acceptor/pool.py @@ -99,15 +99,15 @@ def setup(self) -> None: self._start() execution_mode = ( ( - "threadless (local)" + 'threadless (local)' if self.flags.local_executor - else "threadless (remote)" + else 'threadless (remote)' ) if self.flags.threadless - else "threaded" + else 'threaded' ) logger.debug( - "Started %d acceptors in %s mode" + 'Started %d acceptors in %s mode' % ( self.flags.num_acceptors, execution_mode, @@ -127,7 +127,7 @@ def setup(self) -> None: self.fd_queues[index].close() def shutdown(self) -> None: - logger.debug("Shutting down %d acceptors" % self.flags.num_acceptors) + logger.debug('Shutting down %d acceptors' % self.flags.num_acceptors) for acceptor in self.acceptors: acceptor.running.set() for acceptor in self.acceptors: diff --git a/proxy/core/listener/tcp.py b/proxy/core/listener/tcp.py index 37bf6164a3..ed041373e3 100644 --- a/proxy/core/listener/tcp.py +++ b/proxy/core/listener/tcp.py @@ -93,6 +93,6 @@ def listen(self) -> socket.socket: sock.setblocking(False) self._port = sock.getsockname()[1] logger.debug( - "Listening on %s:%s" % (self.hostname, self._port), + 'Listening on %s:%s' % (self.hostname, self._port), ) return sock