From c9c06e6b81e93e5d899ada3d4c1c24b8fcacdee4 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Sat, 19 Oct 2024 18:12:42 +0530 Subject: [PATCH] Test cases for buffer flush --- proxy/http/handler.py | 2 +- proxy/plugin/reverse_proxy.py | 5 +- tests/http/web/test_web_server.py | 90 ++++++++++++++++++++++++++++++- 3 files changed, 94 insertions(+), 3 deletions(-) diff --git a/proxy/http/handler.py b/proxy/http/handler.py index 8a515efef8..d8f892da78 100644 --- a/proxy/http/handler.py +++ b/proxy/http/handler.py @@ -89,7 +89,7 @@ def shutdown(self) -> None: if self.plugin: self.plugin.on_client_connection_close() logger.debug( - "Closing client connection %s has buffer %s" + 'Closing client connection %s has buffer %s' % (self.work.address, self.work.has_buffer()), ) conn = self.work.connection diff --git a/proxy/plugin/reverse_proxy.py b/proxy/plugin/reverse_proxy.py index 7b7a5a4b38..d81eabf0d1 100644 --- a/proxy/plugin/reverse_proxy.py +++ b/proxy/plugin/reverse_proxy.py @@ -45,7 +45,10 @@ def routes(self) -> List[Union[str, Tuple[str, List[bytes]]]]: # A static route ( r'/get$', - [b'http://httpbingo.org/get', b'https://httpbingo.org/get'], + [ + b'http://httpbingo.org/get', + b'https://httpbingo.org/get', + ], ), # A dynamic route to catch requests on "/get/"" # See "handle_route" method below for what we do when diff --git a/tests/http/web/test_web_server.py b/tests/http/web/test_web_server.py index 8100d995bd..11a025deb0 100644 --- a/tests/http/web/test_web_server.py +++ b/tests/http/web/test_web_server.py @@ -12,21 +12,25 @@ import gzip import tempfile import selectors -from typing import Any +from typing import Any, cast import pytest +from unittest import mock from pytest_mock import MockerFixture from proxy.http import HttpProtocolHandler, HttpClientConnection +from proxy.http.url import Url from proxy.common.flag import FlagParser from proxy.http.parser import HttpParser, httpParserTypes, httpParserStates from proxy.common.utils import bytes_, build_http_request, build_http_response from proxy.common.plugins import Plugins from proxy.http.responses import NOT_FOUND_RESPONSE_PKT +from proxy.http.server.web import HttpWebServerPlugin from proxy.common.constants import ( CRLF, PROXY_PY_DIR, PLUGIN_PAC_FILE, PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER, ) +from proxy.http.server.reverse import ReverseProxy from ...test_assertions import Assertions @@ -384,3 +388,87 @@ async def test_default_web_server_returns_404(self) -> None: self.protocol_handler.work.buffer[0], NOT_FOUND_RESPONSE_PKT, ) + + +class TestThreadedReverseProxyPlugin(Assertions): + + @pytest.fixture(autouse=True) # type: ignore[misc] + def _setUp(self, mocker: MockerFixture) -> None: + self.mock_socket = mocker.patch('socket.socket') + self.mock_socket_dup = mocker.patch('socket.dup', side_effect=lambda fd: fd) + self.mock_selector = mocker.patch('selectors.DefaultSelector') + self.fileno = 10 + self._addr = ('127.0.0.1', 54382) + self._conn = self.mock_socket.return_value + self.flags = FlagParser.initialize( + [ + '--enable-reverse-proxy', + ], + threaded=True, + plugins=[ + b'proxy.plugin.ReverseProxyPlugin', + ], + ) + self.protocol_handler = HttpProtocolHandler( + HttpClientConnection(self._conn, self._addr), + flags=self.flags, + ) + self.protocol_handler.initialize() + # Assert reverse proxy has loaded successfully + self.assertEqual( + self.protocol_handler.flags.plugins[b'HttpWebServerBasePlugin'][0].__name__, + 'ReverseProxy', + ) + # Assert reverse proxy plugins have loaded successfully + self.assertEqual( + self.protocol_handler.flags.plugins[b'ReverseProxyBasePlugin'][0].__name__, + 'ReverseProxyPlugin', + ) + + @pytest.mark.asyncio # type: ignore[misc] + @mock.patch('proxy.core.connection.server.ssl.create_default_context') + async def test_reverse_proxy_works( + self, + mock_create_default_context: mock.MagicMock, + ) -> None: + self.mock_selector.return_value.select.return_value = [ + ( + selectors.SelectorKey( + fileobj=self._conn.fileno(), + fd=self._conn.fileno(), + events=selectors.EVENT_READ, + data=None, + ), + selectors.EVENT_READ, + ), + ] + self._conn.recv.return_value = CRLF.join( + [ + b'GET /get HTTP/1.1', + CRLF, + ], + ) + await self.protocol_handler._run_once() + self.assertEqual( + self.protocol_handler.request.state, + httpParserStates.COMPLETE, + ) + assert ( + self.protocol_handler.plugin is not None + and self.protocol_handler.plugin.__class__.__name__ == 'HttpWebServerPlugin' + ) + rproxy = cast( + ReverseProxy, + cast(HttpWebServerPlugin, self.protocol_handler.plugin).route, + ) + choice = str(cast(Url, rproxy.choice)) + options = ('http://httpbingo.org/get', 'https://httpbingo.org/get') + is_https = choice == options[1] + if is_https: + mock_create_default_context.assert_called_once() + self.assertEqual(choice in options, True) + upstream = rproxy.upstream + self.assertEqual(upstream.__class__.__name__, 'TcpServerConnection') + assert upstream + self.assertEqual(upstream.addr, ('httpbingo.org', 80 if not is_https else 443)) + self.assertEqual(upstream.has_buffer(), True)