diff --git a/daphne/testing.py b/daphne/testing.py index 2e604554..ab5729e2 100644 --- a/daphne/testing.py +++ b/daphne/testing.py @@ -186,7 +186,7 @@ def run(self): application=application, endpoints=endpoints, signal_handlers=False, - **self.kwargs + **self.kwargs, ) # Set up a poller to look for the port reactor.callLater(0.1, self.resolve_port) diff --git a/setup.cfg b/setup.cfg index 9c0374ba..e8f4d5e8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,7 @@ console_scripts = [options.extras_require] tests = django + httpunixsocketconnection hypothesis pytest pytest-asyncio diff --git a/tests/http_base.py b/tests/http_base.py index e5a80c21..8d483b8c 100644 --- a/tests/http_base.py +++ b/tests/http_base.py @@ -17,6 +17,20 @@ class DaphneTestCase(unittest.TestCase): to store/retrieve the request/response messages. """ + _instance_endpoint_args = {} + + @staticmethod + def _get_instance_raw_socket_connection(test_app, *, timeout): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(timeout) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.connect((test_app.host, test_app.port)) + return s + + @staticmethod + def _get_instance_http_connection(test_app, *, timeout): + return HTTPConnection(test_app.host, test_app.port, timeout=timeout) + ### Plain HTTP helpers def run_daphne_http( @@ -36,13 +50,15 @@ def run_daphne_http( and response messages. """ with DaphneTestingInstance( - xff=xff, request_buffer_size=request_buffer_size + xff=xff, + request_buffer_size=request_buffer_size, + **self._instance_endpoint_args, ) as test_app: # Add the response messages test_app.add_send_messages(responses) # Send it the request. We have to do this the long way to allow # duplicate headers. - conn = HTTPConnection(test_app.host, test_app.port, timeout=timeout) + conn = self._get_instance_http_connection(test_app, timeout=timeout) if params: path += "?" + parse.urlencode(params, doseq=True) conn.putrequest(method, path, skip_accept_encoding=True, skip_host=True) @@ -74,13 +90,10 @@ def run_daphne_raw(self, data, *, responses=None, timeout=1): Returns what Daphne sends back. """ assert isinstance(data, bytes) - with DaphneTestingInstance() as test_app: + with DaphneTestingInstance(**self._instance_endpoint_args) as test_app: if responses is not None: test_app.add_send_messages(responses) - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.settimeout(timeout) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.connect((test_app.host, test_app.port)) + s = self._get_instance_raw_socket_connection(test_app, timeout=timeout) s.send(data) try: return s.recv(1000000) diff --git a/tests/test_unixsocket.py b/tests/test_unixsocket.py new file mode 100644 index 00000000..4f991fde --- /dev/null +++ b/tests/test_unixsocket.py @@ -0,0 +1,52 @@ +import os +import socket +import weakref +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import skipUnless + +import test_http_response +from http_base import DaphneTestCase + +from httpunixsocketconnection import HTTPUnixSocketConnection + + +__all__ = ["UnixSocketFDDaphneTestCase", "TestInheritedUnixSocket"] + + +class UnixSocketFDDaphneTestCase(DaphneTestCase): + @property + def _instance_endpoint_args(self): + tmp_dir = TemporaryDirectory() + weakref.finalize(self, tmp_dir.cleanup) + sock_path = str(Path(tmp_dir.name, "test.sock")) + listen_sock = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) + listen_sock.bind(sock_path) + listen_sock.listen() + listen_sock_fileno = os.dup(listen_sock.fileno()) + os.set_inheritable(listen_sock_fileno, True) + listen_sock.close() + return {"host": None, "file_descriptor": listen_sock_fileno} + + @staticmethod + def _get_instance_socket_path(test_app): + with socket.socket(fileno=os.dup(test_app.file_descriptor)) as sock: + return sock.getsockname() + + @classmethod + def _get_instance_raw_socket_connection(cls, test_app, *, timeout): + socket_name = cls._get_instance_socket_path(test_app) + s = socket.socket(socket.AF_UNIX, type=socket.SOCK_STREAM) + s.settimeout(timeout) + s.connect(socket_name) + return s + + @classmethod + def _get_instance_http_connection(cls, test_app, *, timeout): + socket_name = cls._get_instance_socket_path(test_app) + return HTTPUnixSocketConnection(unix_socket=socket_name, timeout=timeout) + + +@skipUnless(hasattr(socket, "AF_UNIX"), "AF_UNIX support not present.") +class TestInheritedUnixSocket(UnixSocketFDDaphneTestCase): + test_minimal_response = test_http_response.TestHTTPResponse.test_minimal_response