From afe2e62e0274c8d29a4f0d81d91d5148fccaf6b4 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Mon, 14 Oct 2024 19:33:19 +0530 Subject: [PATCH] Rewrite Host header during reverse proxy --- proxy/core/connection/connection.py | 6 +++--- proxy/http/parser/parser.py | 28 ++++++++++++++++++++++------ proxy/http/server/reverse.py | 23 ++++++++++++++++------- 3 files changed, 41 insertions(+), 16 deletions(-) diff --git a/proxy/core/connection/connection.py b/proxy/core/connection/connection.py index 63cb62e316..2200f93e6c 100644 --- a/proxy/core/connection/connection.py +++ b/proxy/core/connection/connection.py @@ -49,7 +49,6 @@ def connection(self) -> TcpOrTlsSocket: def send(self, data: Union[memoryview, bytes]) -> int: """Users must handle BrokenPipeError exceptions""" - # logger.info(data.tobytes()) return self.connection.send(data) def recv( @@ -67,7 +66,7 @@ def recv( return memoryview(data) def close(self) -> bool: - if not self.closed: + if not self.closed and self.connection: self.connection.close() self.closed = True return self.closed @@ -97,8 +96,9 @@ def flush(self, max_send_size: Optional[int] = None) -> int: self._num_buffer -= 1 else: self.buffer[0] = mv[sent:] - del mv logger.debug('flushed %d bytes to %s' % (sent, self.tag)) + # logger.info(mv[:sent].tobytes()) + del mv return sent def is_reusable(self) -> bool: diff --git a/proxy/http/parser/parser.py b/proxy/http/parser/parser.py index c16f74e7c3..6e1e885962 100644 --- a/proxy/http/parser/parser.py +++ b/proxy/http/parser/parser.py @@ -283,7 +283,12 @@ def parse( self.state = httpParserStates.COMPLETE self.buffer = None if raw == b'' else raw - def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool = False) -> bytes: + def build( + self, + disable_headers: Optional[List[bytes]] = None, + for_proxy: bool = False, + host: Optional[bytes] = None, + ) -> bytes: """Rebuild the request object.""" assert self.method and self.version and self.type == httpParserTypes.REQUEST_PARSER if disable_headers is None: @@ -301,11 +306,22 @@ def build(self, disable_headers: Optional[List[bytes]] = None, for_proxy: bool = path ) if not self._is_https_tunnel else (self.host + COLON + str(self.port).encode()) return build_http_request( - self.method, path, self.version, - headers={} if not self.headers else { - self.headers[k][0]: self.headers[k][1] for k in self.headers if - k.lower() not in disable_headers - }, + self.method, + path, + self.version, + headers=( + {} + if not self.headers + else { + self.headers[k][0]: ( + self.headers[k][1] + if host is None or self.headers[k][0].lower() != b'host' + else host + ) + for k in self.headers + if k.lower() not in disable_headers + } + ), body=body, no_ua=True, ) diff --git a/proxy/http/server/reverse.py b/proxy/http/server/reverse.py index 303b627f20..c349ebabdb 100644 --- a/proxy/http/server/reverse.py +++ b/proxy/http/server/reverse.py @@ -20,7 +20,7 @@ from proxy.common.utils import text_ from proxy.http.exception import HttpProtocolException from proxy.common.constants import ( - HTTPS_PROTO, DEFAULT_HTTP_PORT, DEFAULT_HTTPS_PORT, + COLON, HTTP_PROTO, HTTPS_PROTO, DEFAULT_HTTP_PORT, DEFAULT_HTTPS_PORT, DEFAULT_REVERSE_PROXY_ACCESS_LOG_FORMAT, ) from ...common.types import Readables, Writables, Descriptors @@ -111,8 +111,8 @@ def handle_request(self, request: HttpParser) -> None: assert self.choice and self.choice.hostname port = ( self.choice.port or DEFAULT_HTTP_PORT - if self.choice.scheme == b'http' - else DEFAULT_HTTPS_PORT + if self.choice.scheme == HTTP_PROTO + else self.choice.port or DEFAULT_HTTPS_PORT ) self.initialize_upstream(text_(self.choice.hostname), port) assert self.upstream @@ -120,14 +120,23 @@ def handle_request(self, request: HttpParser) -> None: self.upstream.connect() if self.choice.scheme == HTTPS_PROTO: self.upstream.wrap( - text_( - self.choice.hostname, - ), + text_(self.choice.hostname), as_non_blocking=True, ca_file=self.flags.ca_file, ) request.path = self.choice.remainder - self.upstream.queue(memoryview(request.build())) + self.upstream.queue( + memoryview( + request.build( + host=self.choice.hostname + + ( + COLON + self.choice.port.to_bytes() + if self.choice.port is not None + else b'' + ), + ), + ), + ) except ConnectionRefusedError: raise HttpProtocolException( # pragma: no cover 'Connection refused by upstream server {0}:{1}'.format(