From 201d02ce72d4a491e10d82e022c1e6bc93c6d68d Mon Sep 17 00:00:00 2001 From: Abhinav Singh <126065+abhinavsingh@users.noreply.github.com> Date: Fri, 9 Aug 2024 23:45:17 +0530 Subject: [PATCH] Deprecate usage of `ssl.wrap_socket` in favour of `SSLContext.wrap_socket` (#1443) * Remove use of ssl.wrap_socket ssl.wrap_socket() has been deprecated since Python 3.7, and isn't recommended for use, and further, has been removed in Python 3.12. ssl.SSLContext().wrap_socket() is the new path forward, so switch the one callsite and the two test cases to use it instead. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix `SSLContext.wrap_socket` params and reusable `DEFAULT_SSL_CONTEXT_OPTIONS` * Fix test cases * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix e2e tests??? --------- Co-authored-by: Steve Kowalik Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- proxy/common/constants.py | 4 ++++ proxy/common/utils.py | 19 ++++----------- proxy/core/connection/client.py | 23 +++++++++++++------ .../proxy/test_http_proxy_tls_interception.py | 17 +++++++++----- ...ttp_proxy_plugins_with_tls_interception.py | 4 ++-- 5 files changed, 38 insertions(+), 29 deletions(-) diff --git a/proxy/common/constants.py b/proxy/common/constants.py index 673f9a903c..46558227de 100644 --- a/proxy/common/constants.py +++ b/proxy/common/constants.py @@ -9,6 +9,7 @@ :license: BSD, see LICENSE for more details. """ import os +import ssl import sys import time import pathlib @@ -156,6 +157,9 @@ def _env_threadless_compliant() -> bool: DEFAULT_SELECTOR_SELECT_TIMEOUT = 25 / 1000 DEFAULT_WAIT_FOR_TASKS_TIMEOUT = 1 / 1000 DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT = 1 # in seconds +DEFAULT_SSL_CONTEXT_OPTIONS = ( + ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 +) DEFAULT_DEVTOOLS_DOC_URL = 'http://proxy' DEFAULT_DEVTOOLS_FRAME_ID = secrets.token_hex(8) diff --git a/proxy/common/utils.py b/proxy/common/utils.py index 4e2eed8144..9ac8883d07 100644 --- a/proxy/common/utils.py +++ b/proxy/common/utils.py @@ -26,7 +26,7 @@ from .types import HostPort from .constants import ( CRLF, COLON, HTTP_1_1, IS_WINDOWS, WHITESPACE, DEFAULT_TIMEOUT, - DEFAULT_THREADLESS, PROXY_AGENT_HEADER_VALUE, + DEFAULT_THREADLESS, PROXY_AGENT_HEADER_VALUE, DEFAULT_SSL_CONTEXT_OPTIONS, ) @@ -219,20 +219,11 @@ def wrap_socket( cafile: Optional[str] = None, ) -> ssl.SSLSocket: """Use this to upgrade server_side socket to TLS.""" - ctx = ssl.create_default_context( - ssl.Purpose.CLIENT_AUTH, - cafile=cafile, - ) - ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 + ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH, cafile=cafile) + ctx.options |= DEFAULT_SSL_CONTEXT_OPTIONS ctx.verify_mode = ssl.CERT_NONE - ctx.load_cert_chain( - certfile=certfile, - keyfile=keyfile, - ) - return ctx.wrap_socket( - conn, - server_side=True, - ) + ctx.load_cert_chain(certfile=certfile, keyfile=keyfile) + return ctx.wrap_socket(conn, server_side=True) def new_socket_connection( diff --git a/proxy/core/connection/client.py b/proxy/core/connection/client.py index f241c56a06..4c8ad97cdd 100644 --- a/proxy/core/connection/client.py +++ b/proxy/core/connection/client.py @@ -9,11 +9,12 @@ :license: BSD, see LICENSE for more details. """ import ssl -from typing import Optional +from typing import Any, Dict, Optional from .types import tcpConnectionTypes from .connection import TcpConnection, TcpConnectionUninitializedException from ...common.types import HostPort, TcpOrTlsSocket +from ...common.constants import DEFAULT_SSL_CONTEXT_OPTIONS class TcpClientConnection(TcpConnection): @@ -42,11 +43,19 @@ def connection(self) -> TcpOrTlsSocket: def wrap(self, keyfile: str, certfile: str) -> None: self.connection.setblocking(True) self.flush() - self._conn = ssl.wrap_socket( - self.connection, - server_side=True, - certfile=certfile, - keyfile=keyfile, - ssl_version=ssl.PROTOCOL_TLS, + ctx = ssl.SSLContext( + protocol=( + ssl.PROTOCOL_TLS_CLIENT + if self.tag == 'server' + else ssl.PROTOCOL_TLS_SERVER + ), ) + ctx.options |= DEFAULT_SSL_CONTEXT_OPTIONS + ctx.load_cert_chain(certfile=certfile, keyfile=keyfile) + assert self.addr + kwargs: Dict[str, Any] = {'server_side': True} + if self.tag == 'server': + assert self.addr + kwargs['server_hostname'] = self.addr[0] + self._conn = ctx.wrap_socket(self.connection, **kwargs) self.connection.setblocking(False) diff --git a/tests/http/proxy/test_http_proxy_tls_interception.py b/tests/http/proxy/test_http_proxy_tls_interception.py index 654bbc5fcd..2fbdaef9a0 100644 --- a/tests/http/proxy/test_http_proxy_tls_interception.py +++ b/tests/http/proxy/test_http_proxy_tls_interception.py @@ -62,9 +62,9 @@ async def test_e2e(self, mocker: MockerFixture) -> None: self.mock_ssl_context.return_value.wrap_socket.return_value = upstream_tls_sock # Used for client wrapping - self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket') + self.mock_ssl_wrap = mocker.patch('ssl.SSLContext') client_tls_sock = mock.MagicMock(spec=ssl.SSLSocket) - self.mock_ssl_wrap.return_value = client_tls_sock + self.mock_ssl_wrap.return_value.wrap_socket.return_value = client_tls_sock plain_connection = mock.MagicMock(spec=socket.socket) @@ -251,13 +251,18 @@ async def asyncReturn(val: T) -> T: ) assert self.flags.ca_cert_dir is not None self.mock_ssl_wrap.assert_called_with( - self._conn, - server_side=True, + protocol=ssl.PROTOCOL_TLS_SERVER, + ) + self.mock_ssl_wrap.return_value.load_cert_chain( keyfile=self.flags.ca_signing_key_file, certfile=HttpProxyPlugin.generated_cert_file_path( - self.flags.ca_cert_dir, host, + self.flags.ca_cert_dir, + host, ), - ssl_version=ssl.PROTOCOL_TLS, + ) + self.mock_ssl_wrap.return_value.wrap_socket.assert_called_with( + self._conn, + server_side=True, ) self.assertEqual(self._conn.setblocking.call_count, 2) self.assertEqual( diff --git a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py index 3d8d6a28f4..a0a05b61f8 100644 --- a/tests/plugin/test_http_proxy_plugins_with_tls_interception.py +++ b/tests/plugin/test_http_proxy_plugins_with_tls_interception.py @@ -46,7 +46,7 @@ def _setUp(self, request: Any, mocker: MockerFixture) -> None: 'proxy.http.proxy.server.TcpServerConnection', ) self.mock_ssl_context = mocker.patch('ssl.create_default_context') - self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket') + self.mock_ssl_wrap = mocker.patch('ssl.SSLContext') self.mock_sign_csr.return_value = True self.mock_gen_csr.return_value = True @@ -82,7 +82,7 @@ def _setUp(self, request: Any, mocker: MockerFixture) -> None: self.server_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket) self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection self.client_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket) - self.mock_ssl_wrap.return_value = self.client_ssl_connection + self.mock_ssl_wrap.return_value.wrap_socket.return_value = self.client_ssl_connection def has_buffer() -> bool: return cast(bool, self.server.queue.called)