diff --git a/CHANGELOG.md b/CHANGELOG.md index ecc2ec0..f2a9112 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Placeholder for upcoming changes. +## [1.0.1] - 2024-10-18 + +### Fixed + +- Improved performance of `Secure.set_headers` by reducing redundant type checks. ([#26](https://github.com/TypeError/secure/issues/26)) + ## [1.0.0] - 2024-09-27 ### Breaking Changes diff --git a/pyproject.toml b/pyproject.toml index 303ca0e..2efe9ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "secure" -version = "1.0.0" +version = "1.0.1" description = "A lightweight package that adds security headers for Python web frameworks." readme = { file = "README.md", "content-type" = "text/markdown" } license = { text = "MIT" } diff --git a/secure/secure.py b/secure/secure.py index a6d5787..2288bde 100644 --- a/secure/secure.py +++ b/secure/secure.py @@ -236,22 +236,23 @@ def set_headers(self, response: ResponseProtocol) -> None: RuntimeError: If an asynchronous 'set_header' method is used in a synchronous context. AttributeError: If the response object does not support setting headers. """ - for header_name, header_value in self.headers.items(): - if isinstance(response, SetHeaderProtocol): - # If response has set_header method, use it - set_header = response.set_header - if inspect.iscoroutinefunction(set_header): - raise RuntimeError( - "Encountered asynchronous 'set_header' in synchronous context." - ) + if isinstance(response, SetHeaderProtocol): + # Use the set_header method if available + set_header = response.set_header + if inspect.iscoroutinefunction(set_header): + raise RuntimeError( + "Encountered asynchronous 'set_header' in synchronous context." + ) + for header_name, header_value in self.headers.items(): set_header(header_name, header_value) - elif isinstance(response, HeadersProtocol): # type: ignore - # If response has headers dictionary, use it + elif isinstance(response, HeadersProtocol): # type: ignore + # Use the headers dictionary if available + for header_name, header_value in self.headers.items(): response.headers[header_name] = header_value - else: - raise AttributeError( - f"Response object of type '{type(response).__name__}' does not support setting headers." - ) + else: + raise AttributeError( + f"Response object of type '{type(response).__name__}' does not support setting headers." + ) async def set_headers_async(self, response: ResponseProtocol) -> None: """ @@ -266,18 +267,20 @@ async def set_headers_async(self, response: ResponseProtocol) -> None: Raises: AttributeError: If the response object does not support setting headers. """ - for header_name, header_value in self.headers.items(): - if isinstance(response, SetHeaderProtocol): - # If response has set_header method, use it - set_header = response.set_header - if inspect.iscoroutinefunction(set_header): + if isinstance(response, SetHeaderProtocol): + # Use the set_header method if available + set_header = response.set_header + if inspect.iscoroutinefunction(set_header): + for header_name, header_value in self.headers.items(): await set_header(header_name, header_value) - else: + else: + for header_name, header_value in self.headers.items(): set_header(header_name, header_value) - elif isinstance(response, HeadersProtocol): # type: ignore - # If response has headers dictionary, use it + elif isinstance(response, HeadersProtocol): # type: ignore + # Use the headers dictionary if available + for header_name, header_value in self.headers.items(): response.headers[header_name] = header_value - else: - raise AttributeError( - f"Response object of type '{type(response).__name__}' does not support setting headers." - ) + else: + raise AttributeError( + f"Response object of type '{type(response).__name__}' does not support setting headers." + ) diff --git a/tests/secure/test_secure.py b/tests/secure/test_secure.py index 4b77933..2184b76 100644 --- a/tests/secure/test_secure.py +++ b/tests/secure/test_secure.py @@ -1,3 +1,4 @@ +import asyncio import unittest from secure import ( @@ -44,6 +45,20 @@ class MockResponseNoHeaders: class TestSecure(unittest.TestCase): + def setUp(self): + # Initialize Secure with some test headers + self.secure = Secure( + custom=[ + CustomHeader("X-Test-Header-1", "Value1"), + CustomHeader("X-Test-Header-2", "Value2"), + ] + ) + # Precompute headers dictionary + self.secure.headers = { + header.header_name: header.header_value + for header in self.secure.headers_list + } + def test_with_default_headers(self): """Test that default headers are correctly applied.""" secure_headers = Secure.with_default_headers() @@ -210,8 +225,6 @@ def test_async_set_headers(self): async def mock_set_headers(): await secure_headers.set_headers_async(response) - import asyncio - asyncio.run(mock_set_headers()) # Verify that headers are set asynchronously @@ -235,43 +248,43 @@ async def mock_set_headers(): def test_set_headers_with_set_header_method(self): """Test setting headers on a response object with set_header method.""" - secure_headers = Secure.with_default_headers() response = MockResponseWithSetHeader() - - # Apply the headers to the response object - secure_headers.set_headers(response) + self.secure.set_headers(response) # Verify that headers are set using set_header method - self.assertIn("Strict-Transport-Security", response.header_storage) - self.assertEqual( - response.header_storage["Strict-Transport-Security"], - "max-age=31536000", - ) + self.assertEqual(response.header_storage, self.secure.headers) + # Ensure set_header was called correct number of times + self.assertEqual(len(response.header_storage), len(self.secure.headers)) - self.assertIn("X-Content-Type-Options", response.header_storage) - self.assertEqual(response.header_storage["X-Content-Type-Options"], "nosniff") + def test_set_headers_with_headers_dict(self): + """Test set_headers with a response object that has a headers dictionary.""" + response = MockResponse() + self.secure.set_headers(response) - def test_async_set_headers_with_async_set_header_method(self): - """Test async setting headers on a response object with async set_header method.""" - secure_headers = Secure.with_default_headers() - response = MockResponseAsyncSetHeader() + # Verify that headers are set + self.assertEqual(response.headers, self.secure.headers) - async def mock_set_headers(): - await secure_headers.set_headers_async(response) + def test_set_headers_async_with_async_set_header(self): + """Test set_headers_async with a response object that has an asynchronous set_header method.""" + response = MockResponseAsyncSetHeader() - import asyncio + async def test_async(): + await self.secure.set_headers_async(response) - asyncio.run(mock_set_headers()) + asyncio.run(test_async()) # Verify that headers are set using async set_header method - self.assertIn("Strict-Transport-Security", response.header_storage) - self.assertEqual( - response.header_storage["Strict-Transport-Security"], - "max-age=31536000", - ) + self.assertEqual(response.header_storage, self.secure.headers) + # Ensure set_header was called correct number of times + self.assertEqual(len(response.header_storage), len(self.secure.headers)) + + def test_set_headers_async_with_headers_dict(self): + """Test set_headers_async with a response object that has a headers dictionary.""" + response = MockResponse() + asyncio.run(self.secure.set_headers_async(response)) - self.assertIn("X-Content-Type-Options", response.header_storage) - self.assertEqual(response.header_storage["X-Content-Type-Options"], "nosniff") + # Verify that headers are set + self.assertEqual(response.headers, self.secure.headers) def test_set_headers_missing_interface(self): """Test that an error is raised when response object lacks required methods.""" @@ -286,6 +299,12 @@ def test_set_headers_missing_interface(self): str(context.exception), ) + def test_set_headers_with_async_set_header_in_sync_context(self): + """Test set_headers raises RuntimeError when encountering async set_header in sync context.""" + response = MockResponseAsyncSetHeader() + with self.assertRaises(RuntimeError): + self.secure.set_headers(response) + def test_set_headers_overwrites_existing_headers(self): """Test that existing headers are overwritten by Secure.""" secure_headers = Secure.with_default_headers() @@ -347,10 +366,10 @@ def test_invalid_preset(self): def test_empty_secure_instance(self): """Test that an empty Secure instance does not set any headers.""" - secure_headers = Secure() + self.secure = Secure() response = MockResponse() - secure_headers.set_headers(response) + self.secure.set_headers(response) self.assertEqual(len(response.headers), 0) def test_multiple_custom_headers(self): @@ -430,16 +449,10 @@ def test_set_headers_async_with_sync_set_header(self): async def mock_set_headers(): await secure_headers.set_headers_async(response) - import asyncio - asyncio.run(mock_set_headers()) # Verify that headers are set using set_header method - self.assertIn("Strict-Transport-Security", response.header_storage) - self.assertEqual( - response.header_storage["Strict-Transport-Security"], - "max-age=31536000", - ) + self.assertEqual(response.header_storage, secure_headers.headers) def test_set_headers_with_no_headers_or_set_header(self): """Test that an error is raised when response lacks both headers and set_header."""