Skip to content

Commit

Permalink
Deprecate usage of ssl.wrap_socket in favour of `SSLContext.wrap_so…
Browse files Browse the repository at this point in the history
…cket` (#1443)

* Remove use of ssl.wrap_socket

ssl.wrap_socket() has been deprecated since Python 3.7, and isn't
recommended for use, and further, has been removed in Python 3.12.
ssl.SSLContext().wrap_socket() is the new path forward, so switch the
one callsite and the two test cases to use it instead.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix `SSLContext.wrap_socket` params and reusable `DEFAULT_SSL_CONTEXT_OPTIONS`

* Fix test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix e2e tests???

---------

Co-authored-by: Steve Kowalik <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 9, 2024
1 parent 6ac5aed commit 201d02c
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 29 deletions.
4 changes: 4 additions & 0 deletions proxy/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
:license: BSD, see LICENSE for more details.
"""
import os
import ssl
import sys
import time
import pathlib
Expand Down Expand Up @@ -156,6 +157,9 @@ def _env_threadless_compliant() -> bool:
DEFAULT_SELECTOR_SELECT_TIMEOUT = 25 / 1000
DEFAULT_WAIT_FOR_TASKS_TIMEOUT = 1 / 1000
DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT = 1 # in seconds
DEFAULT_SSL_CONTEXT_OPTIONS = (
ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
)

DEFAULT_DEVTOOLS_DOC_URL = 'http://proxy'
DEFAULT_DEVTOOLS_FRAME_ID = secrets.token_hex(8)
Expand Down
19 changes: 5 additions & 14 deletions proxy/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .types import HostPort
from .constants import (
CRLF, COLON, HTTP_1_1, IS_WINDOWS, WHITESPACE, DEFAULT_TIMEOUT,
DEFAULT_THREADLESS, PROXY_AGENT_HEADER_VALUE,
DEFAULT_THREADLESS, PROXY_AGENT_HEADER_VALUE, DEFAULT_SSL_CONTEXT_OPTIONS,
)


Expand Down Expand Up @@ -219,20 +219,11 @@ def wrap_socket(
cafile: Optional[str] = None,
) -> ssl.SSLSocket:
"""Use this to upgrade server_side socket to TLS."""
ctx = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH,
cafile=cafile,
)
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH, cafile=cafile)
ctx.options |= DEFAULT_SSL_CONTEXT_OPTIONS
ctx.verify_mode = ssl.CERT_NONE
ctx.load_cert_chain(
certfile=certfile,
keyfile=keyfile,
)
return ctx.wrap_socket(
conn,
server_side=True,
)
ctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
return ctx.wrap_socket(conn, server_side=True)


def new_socket_connection(
Expand Down
23 changes: 16 additions & 7 deletions proxy/core/connection/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
:license: BSD, see LICENSE for more details.
"""
import ssl
from typing import Optional
from typing import Any, Dict, Optional

from .types import tcpConnectionTypes
from .connection import TcpConnection, TcpConnectionUninitializedException
from ...common.types import HostPort, TcpOrTlsSocket
from ...common.constants import DEFAULT_SSL_CONTEXT_OPTIONS


class TcpClientConnection(TcpConnection):
Expand Down Expand Up @@ -42,11 +43,19 @@ def connection(self) -> TcpOrTlsSocket:
def wrap(self, keyfile: str, certfile: str) -> None:
self.connection.setblocking(True)
self.flush()
self._conn = ssl.wrap_socket(
self.connection,
server_side=True,
certfile=certfile,
keyfile=keyfile,
ssl_version=ssl.PROTOCOL_TLS,
ctx = ssl.SSLContext(
protocol=(
ssl.PROTOCOL_TLS_CLIENT
if self.tag == 'server'
else ssl.PROTOCOL_TLS_SERVER
),
)
ctx.options |= DEFAULT_SSL_CONTEXT_OPTIONS
ctx.load_cert_chain(certfile=certfile, keyfile=keyfile)
assert self.addr
kwargs: Dict[str, Any] = {'server_side': True}
if self.tag == 'server':
assert self.addr
kwargs['server_hostname'] = self.addr[0]
self._conn = ctx.wrap_socket(self.connection, **kwargs)
self.connection.setblocking(False)
17 changes: 11 additions & 6 deletions tests/http/proxy/test_http_proxy_tls_interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ async def test_e2e(self, mocker: MockerFixture) -> None:
self.mock_ssl_context.return_value.wrap_socket.return_value = upstream_tls_sock

# Used for client wrapping
self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')
self.mock_ssl_wrap = mocker.patch('ssl.SSLContext')
client_tls_sock = mock.MagicMock(spec=ssl.SSLSocket)
self.mock_ssl_wrap.return_value = client_tls_sock
self.mock_ssl_wrap.return_value.wrap_socket.return_value = client_tls_sock

plain_connection = mock.MagicMock(spec=socket.socket)

Expand Down Expand Up @@ -251,13 +251,18 @@ async def asyncReturn(val: T) -> T:
)
assert self.flags.ca_cert_dir is not None
self.mock_ssl_wrap.assert_called_with(
self._conn,
server_side=True,
protocol=ssl.PROTOCOL_TLS_SERVER,
)
self.mock_ssl_wrap.return_value.load_cert_chain(
keyfile=self.flags.ca_signing_key_file,
certfile=HttpProxyPlugin.generated_cert_file_path(
self.flags.ca_cert_dir, host,
self.flags.ca_cert_dir,
host,
),
ssl_version=ssl.PROTOCOL_TLS,
)
self.mock_ssl_wrap.return_value.wrap_socket.assert_called_with(
self._conn,
server_side=True,
)
self.assertEqual(self._conn.setblocking.call_count, 2)
self.assertEqual(
Expand Down
4 changes: 2 additions & 2 deletions tests/plugin/test_http_proxy_plugins_with_tls_interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _setUp(self, request: Any, mocker: MockerFixture) -> None:
'proxy.http.proxy.server.TcpServerConnection',
)
self.mock_ssl_context = mocker.patch('ssl.create_default_context')
self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')
self.mock_ssl_wrap = mocker.patch('ssl.SSLContext')

self.mock_sign_csr.return_value = True
self.mock_gen_csr.return_value = True
Expand Down Expand Up @@ -82,7 +82,7 @@ def _setUp(self, request: Any, mocker: MockerFixture) -> None:
self.server_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection
self.client_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
self.mock_ssl_wrap.return_value = self.client_ssl_connection
self.mock_ssl_wrap.return_value.wrap_socket.return_value = self.client_ssl_connection

def has_buffer() -> bool:
return cast(bool, self.server.queue.called)
Expand Down

0 comments on commit 201d02c

Please sign in to comment.