diff --git a/.gitignore b/.gitignore index f4848296..6d39acd8 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ dist/ .coverage .tox +# Tox tests leftovers +build + # Translations *.mo diff --git a/pylxd/client.py b/pylxd/client.py index b8b37dd8..804f15ab 100644 --- a/pylxd/client.py +++ b/pylxd/client.py @@ -14,20 +14,21 @@ import json import os import re +import socket from enum import Enum from typing import NamedTuple from urllib import parse import requests -import requests_unixsocket +import requests.adapters +import urllib3 +import urllib3.connection from cryptography import x509 from cryptography.hazmat.primitives import hashes from ws4py.client import WebSocketBaseClient from pylxd import exceptions, managers -requests_unixsocket.monkeypatch() - SNAP_ROOT = os.path.expanduser("~/snap/lxd/common/config/") APT_ROOT = os.path.expanduser("~/.config/lxc/") CERT_FILE_NAME = "client.crt" @@ -51,6 +52,9 @@ class Cert(NamedTuple): key=os.path.join(CERTS_PATH, KEY_FILE_NAME), ) # pragma: no cover +DEFAULT_SCHEME = "http+unix://" +SOCKET_CONNECTION_TIMEOUT = 60 + class EventType(Enum): All = "all" @@ -59,6 +63,65 @@ class EventType(Enum): Lifecycle = "lifecycle" +class _UnixSocketHTTPConnection(urllib3.connection.HTTPConnection, object): + def __init__(self, unix_socket_url): + super(_UnixSocketHTTPConnection, self).__init__( + "localhost", timeout=SOCKET_CONNECTION_TIMEOUT + ) + self.unix_socket_url = unix_socket_url + self.timeout = SOCKET_CONNECTION_TIMEOUT + self.sock = None + + def __del__(self): + if self.sock: + self.sock.close() + + def connect(self): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(self.timeout) + socket_path = parse.unquote(parse.urlparse(self.unix_socket_url).netloc) + sock.connect(socket_path) + self.sock = sock + + +class _UnixSocketHTTPConnectionPool(urllib3.HTTPConnectionPool): + def __init__(self, socket_path): + super(_UnixSocketHTTPConnectionPool, self).__init__("localhost") + self.socket_path = socket_path + + def _new_conn(self): + return _UnixSocketHTTPConnection(self.socket_path) + + +class _UnixAdapter(requests.adapters.HTTPAdapter): + def __init__(self, pool_connections=25, *args, **kwargs): + super(_UnixAdapter, self).__init__(*args, **kwargs) + self.pools = urllib3._collections.RecentlyUsedContainer( + pool_connections, dispose_func=lambda p: p.close() + ) + + def get_connection(self, url, proxies): + with self.pools.lock: + conn = self.pools.get(url) + if conn: + return conn + + conn = _UnixSocketHTTPConnectionPool(url) + self.pools[url] = conn + + return conn + + # This method is needed fo compatibility with later requests versions. + def get_connection_with_tls_context(self, request, verify, proxies=None, cert=None): + return self.get_connection(request.url, None) + + def request_url(self, request, proxies): + return request.path_url + + def close(self): + self.pools.clear() + + class LXDSSLAdapter(requests.adapters.HTTPAdapter): def cert_verify(self, conn, url, verify, cert): with open(verify, "rb") as fd: @@ -74,11 +137,10 @@ def get_session_for_url(url: str, verify=None, cert=None) -> requests.Session: Call sites can use this to customise the session before passing into a Client. """ - session: requests.Session - if url.startswith("http+unix://"): - session = requests_unixsocket.Session() + session = requests.Session() + if url.startswith(DEFAULT_SCHEME): + session.mount(DEFAULT_SCHEME, _UnixAdapter()) else: - session = requests.Session() session.cert = cert session.verify = verify diff --git a/pylxd/tests/test_client.py b/pylxd/tests/test_client.py index 9d572cca..02d677d7 100644 --- a/pylxd/tests/test_client.py +++ b/pylxd/tests/test_client.py @@ -18,7 +18,7 @@ import pytest import requests -import requests_unixsocket +import requests.adapters from pylxd import client, exceptions @@ -631,13 +631,17 @@ class TestGetSessionForUrl(TestCase): def test_session_unix_socket(self): """http+unix URL return a requests_unixsocket session.""" session = client.get_session_for_url("http+unix://test.com") - self.assertIsInstance(session, requests_unixsocket.Session) + self.assertIsInstance( + session.get_adapter("http+unix://"), requests.adapters.HTTPAdapter + ) def test_session_http(self): """HTTP nodes return the default requests session.""" session = client.get_session_for_url("http://test.com") self.assertIsInstance(session, requests.Session) - self.assertNotIsInstance(session, requests_unixsocket.Session) + self.assertRaises( + requests.exceptions.InvalidSchema, session.get_adapter, "http+unix://" + ) def test_session_cert(self): """If certs are given, they're set on the Session.""" diff --git a/setup.cfg b/setup.cfg index 1b0aa295..08c68eef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,10 +20,8 @@ packages = find: install_requires = cryptography >= 3.2 python-dateutil >= 2.4.2 - requests >= 2.20.0, < 2.32.0 + requests >= 2.20.0 requests-toolbelt >= 0.8.0 - requests-unixsocket >= 0.1.5 - urllib3 < 2 ws4py != 0.3.5, >= 0.3.4 # 0.3.5 is broken for websocket support [options.extras_require]