Skip to content

Commit

Permalink
Merge pull request #27 from TypeError/fix/performance-improvement-set…
Browse files Browse the repository at this point in the history
…-headers-v1.0.1

Fix/performance improvement set headers v1.0.1
  • Loading branch information
cak authored Oct 18, 2024
2 parents 5a5d847 + c28882b commit 813fb2d
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 64 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Placeholder for upcoming changes.

## [1.0.1] - 2024-10-18

### Fixed

- Improved performance of `Secure.set_headers` by reducing redundant type checks. ([#26](https://github.com/TypeError/secure/issues/26))

## [1.0.0] - 2024-09-27

### Breaking Changes
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "secure"
version = "1.0.0"
version = "1.0.1"
description = "A lightweight package that adds security headers for Python web frameworks."
readme = { file = "README.md", "content-type" = "text/markdown" }
license = { text = "MIT" }
Expand Down
55 changes: 29 additions & 26 deletions secure/secure.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,23 @@ def set_headers(self, response: ResponseProtocol) -> None:
RuntimeError: If an asynchronous 'set_header' method is used in a synchronous context.
AttributeError: If the response object does not support setting headers.
"""
for header_name, header_value in self.headers.items():
if isinstance(response, SetHeaderProtocol):
# If response has set_header method, use it
set_header = response.set_header
if inspect.iscoroutinefunction(set_header):
raise RuntimeError(
"Encountered asynchronous 'set_header' in synchronous context."
)
if isinstance(response, SetHeaderProtocol):
# Use the set_header method if available
set_header = response.set_header
if inspect.iscoroutinefunction(set_header):
raise RuntimeError(
"Encountered asynchronous 'set_header' in synchronous context."
)
for header_name, header_value in self.headers.items():
set_header(header_name, header_value)
elif isinstance(response, HeadersProtocol): # type: ignore
# If response has headers dictionary, use it
elif isinstance(response, HeadersProtocol): # type: ignore
# Use the headers dictionary if available
for header_name, header_value in self.headers.items():
response.headers[header_name] = header_value
else:
raise AttributeError(
f"Response object of type '{type(response).__name__}' does not support setting headers."
)
else:
raise AttributeError(
f"Response object of type '{type(response).__name__}' does not support setting headers."
)

async def set_headers_async(self, response: ResponseProtocol) -> None:
"""
Expand All @@ -266,18 +267,20 @@ async def set_headers_async(self, response: ResponseProtocol) -> None:
Raises:
AttributeError: If the response object does not support setting headers.
"""
for header_name, header_value in self.headers.items():
if isinstance(response, SetHeaderProtocol):
# If response has set_header method, use it
set_header = response.set_header
if inspect.iscoroutinefunction(set_header):
if isinstance(response, SetHeaderProtocol):
# Use the set_header method if available
set_header = response.set_header
if inspect.iscoroutinefunction(set_header):
for header_name, header_value in self.headers.items():
await set_header(header_name, header_value)
else:
else:
for header_name, header_value in self.headers.items():
set_header(header_name, header_value)
elif isinstance(response, HeadersProtocol): # type: ignore
# If response has headers dictionary, use it
elif isinstance(response, HeadersProtocol): # type: ignore
# Use the headers dictionary if available
for header_name, header_value in self.headers.items():
response.headers[header_name] = header_value
else:
raise AttributeError(
f"Response object of type '{type(response).__name__}' does not support setting headers."
)
else:
raise AttributeError(
f"Response object of type '{type(response).__name__}' does not support setting headers."
)
87 changes: 50 additions & 37 deletions tests/secure/test_secure.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import unittest

from secure import (
Expand Down Expand Up @@ -44,6 +45,20 @@ class MockResponseNoHeaders:


class TestSecure(unittest.TestCase):
def setUp(self):
# Initialize Secure with some test headers
self.secure = Secure(
custom=[
CustomHeader("X-Test-Header-1", "Value1"),
CustomHeader("X-Test-Header-2", "Value2"),
]
)
# Precompute headers dictionary
self.secure.headers = {
header.header_name: header.header_value
for header in self.secure.headers_list
}

def test_with_default_headers(self):
"""Test that default headers are correctly applied."""
secure_headers = Secure.with_default_headers()
Expand Down Expand Up @@ -210,8 +225,6 @@ def test_async_set_headers(self):
async def mock_set_headers():
await secure_headers.set_headers_async(response)

import asyncio

asyncio.run(mock_set_headers())

# Verify that headers are set asynchronously
Expand All @@ -235,43 +248,43 @@ async def mock_set_headers():

def test_set_headers_with_set_header_method(self):
"""Test setting headers on a response object with set_header method."""
secure_headers = Secure.with_default_headers()
response = MockResponseWithSetHeader()

# Apply the headers to the response object
secure_headers.set_headers(response)
self.secure.set_headers(response)

# Verify that headers are set using set_header method
self.assertIn("Strict-Transport-Security", response.header_storage)
self.assertEqual(
response.header_storage["Strict-Transport-Security"],
"max-age=31536000",
)
self.assertEqual(response.header_storage, self.secure.headers)
# Ensure set_header was called correct number of times
self.assertEqual(len(response.header_storage), len(self.secure.headers))

self.assertIn("X-Content-Type-Options", response.header_storage)
self.assertEqual(response.header_storage["X-Content-Type-Options"], "nosniff")
def test_set_headers_with_headers_dict(self):
"""Test set_headers with a response object that has a headers dictionary."""
response = MockResponse()
self.secure.set_headers(response)

def test_async_set_headers_with_async_set_header_method(self):
"""Test async setting headers on a response object with async set_header method."""
secure_headers = Secure.with_default_headers()
response = MockResponseAsyncSetHeader()
# Verify that headers are set
self.assertEqual(response.headers, self.secure.headers)

async def mock_set_headers():
await secure_headers.set_headers_async(response)
def test_set_headers_async_with_async_set_header(self):
"""Test set_headers_async with a response object that has an asynchronous set_header method."""
response = MockResponseAsyncSetHeader()

import asyncio
async def test_async():
await self.secure.set_headers_async(response)

asyncio.run(mock_set_headers())
asyncio.run(test_async())

# Verify that headers are set using async set_header method
self.assertIn("Strict-Transport-Security", response.header_storage)
self.assertEqual(
response.header_storage["Strict-Transport-Security"],
"max-age=31536000",
)
self.assertEqual(response.header_storage, self.secure.headers)
# Ensure set_header was called correct number of times
self.assertEqual(len(response.header_storage), len(self.secure.headers))

def test_set_headers_async_with_headers_dict(self):
"""Test set_headers_async with a response object that has a headers dictionary."""
response = MockResponse()
asyncio.run(self.secure.set_headers_async(response))

self.assertIn("X-Content-Type-Options", response.header_storage)
self.assertEqual(response.header_storage["X-Content-Type-Options"], "nosniff")
# Verify that headers are set
self.assertEqual(response.headers, self.secure.headers)

def test_set_headers_missing_interface(self):
"""Test that an error is raised when response object lacks required methods."""
Expand All @@ -286,6 +299,12 @@ def test_set_headers_missing_interface(self):
str(context.exception),
)

def test_set_headers_with_async_set_header_in_sync_context(self):
"""Test set_headers raises RuntimeError when encountering async set_header in sync context."""
response = MockResponseAsyncSetHeader()
with self.assertRaises(RuntimeError):
self.secure.set_headers(response)

def test_set_headers_overwrites_existing_headers(self):
"""Test that existing headers are overwritten by Secure."""
secure_headers = Secure.with_default_headers()
Expand Down Expand Up @@ -347,10 +366,10 @@ def test_invalid_preset(self):

def test_empty_secure_instance(self):
"""Test that an empty Secure instance does not set any headers."""
secure_headers = Secure()
self.secure = Secure()
response = MockResponse()

secure_headers.set_headers(response)
self.secure.set_headers(response)
self.assertEqual(len(response.headers), 0)

def test_multiple_custom_headers(self):
Expand Down Expand Up @@ -430,16 +449,10 @@ def test_set_headers_async_with_sync_set_header(self):
async def mock_set_headers():
await secure_headers.set_headers_async(response)

import asyncio

asyncio.run(mock_set_headers())

# Verify that headers are set using set_header method
self.assertIn("Strict-Transport-Security", response.header_storage)
self.assertEqual(
response.header_storage["Strict-Transport-Security"],
"max-age=31536000",
)
self.assertEqual(response.header_storage, secure_headers.headers)

def test_set_headers_with_no_headers_or_set_header(self):
"""Test that an error is raised when response lacks both headers and set_header."""
Expand Down

0 comments on commit 813fb2d

Please sign in to comment.