Skip to content

Commit

Permalink
Reverse proxy enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavsingh committed Apr 23, 2024
1 parent 380e0cc commit 1311b49
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 67 deletions.
19 changes: 12 additions & 7 deletions proxy/core/acceptor/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,17 @@ def setup(self) -> None:
"""Setup acceptors."""
self._start()
execution_mode = (
'threadless (local)'
if self.flags.local_executor
else 'threadless (remote)'
) if self.flags.threadless else 'threaded'
logger.info(
'Started %d acceptors in %s mode' % (
(
"threadless (local)"
if self.flags.local_executor
else "threadless (remote)"
)
if self.flags.threadless
else "threaded"
)
logger.debug(
"Started %d acceptors in %s mode"
% (
self.flags.num_acceptors,
execution_mode,
),
Expand All @@ -122,7 +127,7 @@ def setup(self) -> None:
self.fd_queues[index].close()

def shutdown(self) -> None:
logger.info('Shutting down %d acceptors' % self.flags.num_acceptors)
logger.debug("Shutting down %d acceptors" % self.flags.num_acceptors)
for acceptor in self.acceptors:
acceptor.running.set()
for acceptor in self.acceptors:
Expand Down
5 changes: 2 additions & 3 deletions proxy/core/listener/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def listen(self) -> socket.socket:
sock.listen(self.flags.backlog)
sock.setblocking(False)
self._port = sock.getsockname()[1]
logger.info(
'Listening on %s:%s' %
(self.hostname, self._port),
logger.debug(
"Listening on %s:%s" % (self.hostname, self._port),
)
return sock
41 changes: 37 additions & 4 deletions proxy/http/server/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
from ..descriptors import DescriptorsHandlerMixin
from ...common.types import RePattern
from ...common.utils import bytes_
from ...http.server.protocols import httpProtocolTypes


if TYPE_CHECKING: # pragma: no cover
from ...core.connection import UpstreamConnectionPool
from ...core.connection import TcpServerConnection, UpstreamConnectionPool


class HttpWebServerBasePlugin(DescriptorsHandlerMixin, ABC):
Expand Down Expand Up @@ -64,7 +65,7 @@ def serve_static_file(path: str, min_compression_length: int) -> memoryview:
# TODO: Should we really close or take advantage of keep-alive?
conn_close=True,
)
except FileNotFoundError:
except OSError:
return NOT_FOUND_RESPONSE_PKT

def name(self) -> str:
Expand All @@ -88,6 +89,17 @@ def on_client_connection_close(self) -> None:
"""Client has closed the connection, do any clean up task now."""
pass

def do_upgrade(self, request: HttpParser) -> bool:
return True

def on_client_data(
self,
request: HttpParser,
raw: memoryview,
) -> Optional[memoryview]:
"""Return None to avoid default webserver parsing of client data."""
return raw

# No longer abstract since v2.4.0
#
# @abstractmethod
Expand Down Expand Up @@ -125,7 +137,7 @@ def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
return context


class ReverseProxyBasePlugin(ABC):
class ReverseProxyBasePlugin(DescriptorsHandlerMixin, ABC):
"""ReverseProxy base plugin class."""

def __init__(
Expand Down Expand Up @@ -161,13 +173,24 @@ def routes(self) -> List[Union[str, Tuple[str, List[bytes]]]]:
must return the url to serve."""
raise NotImplementedError() # pragma: no cover

def protocols(self) -> List[int]:
return [
httpProtocolTypes.HTTP,
httpProtocolTypes.HTTPS,
httpProtocolTypes.WEBSOCKET,
]

def before_routing(self, request: HttpParser) -> Optional[HttpParser]:
"""Plugins can modify request, return response, close connection.
If None is returned, request will be dropped and closed."""
return request # pragma: no cover

def handle_route(self, request: HttpParser, pattern: RePattern) -> Url:
def handle_route(
self,
request: HttpParser,
pattern: RePattern,
) -> Union[memoryview, Url, 'TcpServerConnection']:
"""Implement this method if you have configured dynamic routes."""
raise NotImplementedError()

Expand All @@ -182,3 +205,13 @@ def regexes(self) -> List[str]:
else:
raise ValueError('Invalid route type')
return routes

def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Use this method to override default access log format (see
DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT) or to add/update/modify passed context
for usage by default access logger.
Return updated log context to use for default logging format, OR
Return None if plugin has logged the request.
"""
return context
132 changes: 101 additions & 31 deletions proxy/http/server/reverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
from proxy.http import Url
from proxy.core.base import TcpUpstreamConnectionHandler
from proxy.http.parser import HttpParser
from proxy.http.server import HttpWebServerBasePlugin, httpProtocolTypes
from proxy.http.server import HttpWebServerBasePlugin
from proxy.common.utils import text_
from proxy.http.exception import HttpProtocolException
from proxy.common.constants import (
HTTPS_PROTO, DEFAULT_HTTP_PORT, DEFAULT_HTTPS_PORT,
DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT,
)
from ...common.types import Readables, Writables, Descriptors


if TYPE_CHECKING: # pragma: no cover
Expand All @@ -44,6 +45,11 @@ def __init__(self, *args: Any, **kwargs: Any):
self.uid, self.flags, self.client, self.event_queue, self.upstream_conn_pool,
)
self.plugins.append(plugin)
self._upstream_proxy_pass: Optional[str] = None

def do_upgrade(self, request: HttpParser) -> bool:
"""Signal web protocol handler to not upgrade websocket requests by default."""
return False

def handle_upstream_data(self, raw: memoryview) -> None:
# TODO: Parse response and implement plugin hook per parsed response object
Expand All @@ -54,8 +60,8 @@ def routes(self) -> List[Tuple[int, str]]:
r = []
for plugin in self.plugins:
for route in plugin.regexes():
r.append((httpProtocolTypes.HTTP, route))
r.append((httpProtocolTypes.HTTPS, route))
for proto in plugin.protocols():
r.append((proto, route))
return r

def handle_request(self, request: HttpParser) -> None:
Expand All @@ -66,59 +72,123 @@ def handle_request(self, request: HttpParser) -> None:
raise HttpProtocolException('before_routing closed connection')
request = r

needs_upstream = False

# routes
for plugin in self.plugins:
for route in plugin.routes():
# Static routes
if isinstance(route, tuple):
pattern = re.compile(route[0])
if pattern.match(text_(request.path)):
self.choice = Url.from_bytes(
random.choice(route[1]),
)
break
# Dynamic routes
elif isinstance(route, str):
pattern = re.compile(route)
if pattern.match(text_(request.path)):
self.choice = plugin.handle_route(request, pattern)
choice = plugin.handle_route(request, pattern)
if isinstance(choice, Url):
self.choice = choice
needs_upstream = True
self._upstream_proxy_pass = str(self.choice)
elif isinstance(choice, memoryview):
self.client.queue(choice)
self._upstream_proxy_pass = '{0} bytes'.format(len(choice))
else:
self.upstream = choice
self._upstream_proxy_pass = '{0}:{1}'.format(
*self.upstream.addr,
)
break
else:
raise ValueError('Invalid route')

assert self.choice and self.choice.hostname
port = self.choice.port or \
DEFAULT_HTTP_PORT \
if self.choice.scheme == b'http' \
else DEFAULT_HTTPS_PORT
self.initialize_upstream(text_(self.choice.hostname), port)
assert self.upstream
try:
self.upstream.connect()
if self.choice.scheme == HTTPS_PROTO:
self.upstream.wrap(
text_(
self.choice.hostname,
if needs_upstream:
assert self.choice and self.choice.hostname
port = (
self.choice.port or DEFAULT_HTTP_PORT
if self.choice.scheme == b'http'
else DEFAULT_HTTPS_PORT
)
self.initialize_upstream(text_(self.choice.hostname), port)
assert self.upstream
try:
self.upstream.connect()
if self.choice.scheme == HTTPS_PROTO:
self.upstream.wrap(
text_(
self.choice.hostname,
),
as_non_blocking=True,
ca_file=self.flags.ca_file,
)
request.path = self.choice.remainder
self.upstream.queue(memoryview(request.build()))
except ConnectionRefusedError:
raise HttpProtocolException( # pragma: no cover
'Connection refused by upstream server {0}:{1}'.format(
text_(self.choice.hostname),
port,
),
as_non_blocking=True,
ca_file=self.flags.ca_file,
)
request.path = self.choice.remainder
self.upstream.queue(memoryview(request.build()))
except ConnectionRefusedError:
raise HttpProtocolException( # pragma: no cover
'Connection refused by upstream server {0}:{1}'.format(
text_(self.choice.hostname), port,
),
)

def on_client_connection_close(self) -> None:
if self.upstream and not self.upstream.closed:
logger.debug('Closing upstream server connection')
self.upstream.close()
self.upstream = None

def on_client_data(
self,
request: HttpParser,
raw: memoryview,
) -> Optional[memoryview]:
if request.is_websocket_upgrade:
assert self.upstream
self.upstream.queue(raw)
return raw

def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
context.update({
'upstream_proxy_pass': str(self.choice) if self.choice else None,
})
logger.info(DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT.format_map(context))
context.update(
{
'upstream_proxy_pass': self._upstream_proxy_pass,
},
)
log_handled = False
for plugin in self.plugins:
ctx = plugin.on_access_log(context)
if ctx is None:
log_handled = True
break
context = ctx
if not log_handled:
logger.info(DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT.format_map(context))
return None

async def get_descriptors(self) -> Descriptors:
r, w = await super().get_descriptors()
# TODO(abhinavsingh): We need to keep a mapping of plugin and
# descriptors registered by them, so that within write/read blocks
# we can invoke the right plugin callbacks.
for plugin in self.plugins:
plugin_read_desc, plugin_write_desc = await plugin.get_descriptors()
r.extend(plugin_read_desc)
w.extend(plugin_write_desc)
return r, w

async def read_from_descriptors(self, r: Readables) -> bool:
for plugin in self.plugins:
teardown = await plugin.read_from_descriptors(r)
if teardown:
return True
return await super().read_from_descriptors(r)

async def write_to_descriptors(self, w: Writables) -> bool:
for plugin in self.plugins:
teardown = await plugin.write_to_descriptors(w)
if teardown:
return True
return await super().write_to_descriptors(w)
Loading

0 comments on commit 1311b49

Please sign in to comment.