diff --git a/asynction/exceptions.py b/asynction/exceptions.py index 80cfde0..66f8a2f 100644 --- a/asynction/exceptions.py +++ b/asynction/exceptions.py @@ -42,7 +42,8 @@ class MessageAckValidationException(ValidationException): class SecurityException(AsynctionException, ConnectionRefusedError): """ - Base Security Exception type. + Raised when an incoming connection fails to meet the requirements of + any of the specified security schemes. """ pass diff --git a/asynction/mock_server.py b/asynction/mock_server.py index c680f5a..84ef837 100644 --- a/asynction/mock_server.py +++ b/asynction/mock_server.py @@ -210,7 +210,7 @@ def from_spec( def _register_handlers( self, - server_security: Sequence[SecurityRequirement], + server_security: Sequence[SecurityRequirement] = (), default_error_handler: Optional[ErrorHandler] = None, ) -> None: for namespace, channel in self.spec.channels.items(): diff --git a/asynction/security.py b/asynction/security.py index 8bb5602..ff20245 100644 --- a/asynction/security.py +++ b/asynction/security.py @@ -1,13 +1,12 @@ import base64 -import http.cookies from functools import wraps +from http.cookies import SimpleCookie from typing import Callable from typing import List from typing import Mapping from typing import Optional from typing import Sequence from typing import Tuple -from typing import Union from flask import Request from flask import request as current_flask_request @@ -17,6 +16,7 @@ from asynction.types import SecurityRequirement from asynction.types import SecurityScheme from asynction.types import SecuritySchemesType +from asynction.utils import load_handler TokenInfoFunc = Callable[[str], Mapping] BasicInfoFunc = Callable[[str, str, Optional[Sequence[str]]], Mapping] @@ -58,7 +58,7 @@ def extract_auth_header(request: Request) -> Optional[Tuple[str, str]]: def validate_basic( request: Request, basic_info_func: BasicInfoFunc, required_scopes: Sequence[str] -) -> Union[Mapping, None]: +) -> Optional[Mapping]: auth = extract_auth_header(request) if not auth: return None @@ -86,7 +86,7 @@ def validate_basic( def validate_authorization_header( request: Request, token_info_func: TokenInfoFunc -) -> Union[Mapping, None]: +) -> Optional[Mapping]: """Check that the provided request contains a properly formatted Authorization header and invokes the token_info_func on the token inside of the header. """ @@ -112,7 +112,7 @@ def validate_api_key( api_key_info_func: APIKeyInfoFunc, required_scopes: Sequence[str], bearer_format: Optional[str] = None, -) -> Union[Mapping, None]: +) -> Optional[Mapping]: """ Adapted from: https://github.com/zalando/connexion/blob/main/connexion/security/security_handler_factory.py#L221 # noqa: 501 """ @@ -146,9 +146,6 @@ def validate_scopes( def load_scope_validate_func(scheme: SecurityScheme) -> ScopeValidateFunc: - # importing here because doing it at the top leads to a circular import - from asynction.server import load_handler - scope_validate_func = None if scheme.x_scope_validate_func: scope_validate_func = load_handler(scheme.x_scope_validate_func) @@ -160,9 +157,6 @@ def load_scope_validate_func(scheme: SecurityScheme) -> ScopeValidateFunc: def load_basic_info_func(scheme: SecurityScheme) -> BasicInfoFunc: - # importing here because doing it at the top leads to a circular import - from asynction.server import load_handler - if scheme.x_basic_info_func is not None: basic_info_func = load_handler(scheme.x_basic_info_func) if not basic_info_func: @@ -173,9 +167,6 @@ def load_basic_info_func(scheme: SecurityScheme) -> BasicInfoFunc: def load_token_info_func(scheme: SecurityScheme) -> TokenInfoFunc: - # importing here because doing it at the top leads to a circular import - from asynction.server import load_handler - if scheme.x_token_info_func is not None: token_info_func = load_handler(scheme.x_token_info_func) if not token_info_func: @@ -186,9 +177,6 @@ def load_token_info_func(scheme: SecurityScheme) -> TokenInfoFunc: def load_api_key_info_func(scheme: SecurityScheme) -> APIKeyInfoFunc: - # importing here because doing it at the top leads to a circular import - from asynction.server import load_handler - if scheme.x_api_key_info_func is not None: token_info_func = load_handler(scheme.x_api_key_info_func) if not token_info_func: @@ -252,7 +240,7 @@ def http_bearer_security_check( return None -def get_cookie_value(cookies, name): +def get_cookie_value(cookies: str, name: str) -> Optional[str]: """ Returns cookie value by its name. None if no such value. :param cookies: str: cookies raw data @@ -260,7 +248,7 @@ def get_cookie_value(cookies, name): Borrowed from https://github.com/zalando/connexion/blob/main/connexion/security/security_handler_factory.py#L206 # noqa: 501 """ - cookie_parser = http.cookies.SimpleCookie() + cookie_parser: SimpleCookie = SimpleCookie() cookie_parser.load(str(cookies)) try: return cookie_parser[name].value @@ -293,7 +281,8 @@ def http_api_key_security_check(request: Request) -> InternalSecurityCheckRespon api_key = request.headers.get(api_key_name) elif api_key_in == "cookie": cookies_list = request.headers.get("Cookie") - api_key = get_cookie_value(cookies_list, api_key_name) + if cookies_list and api_key_name is not None: + api_key = get_cookie_value(cookies_list, api_key_name) else: return None, None diff --git a/asynction/server.py b/asynction/server.py index aca8359..ffc4270 100644 --- a/asynction/server.py +++ b/asynction/server.py @@ -3,10 +3,8 @@ server with an additional factory classmethod. """ from functools import singledispatch -from importlib import import_module from pathlib import Path from typing import Any -from typing import Callable from typing import Optional from typing import Sequence from urllib.parse import urlparse @@ -26,6 +24,7 @@ from asynction.types import ErrorHandler from asynction.types import JSONMapping from asynction.types import SecurityRequirement +from asynction.utils import load_handler from asynction.validation import bindings_validator_factory from asynction.validation import callback_validator_factory from asynction.validation import publish_message_validator_factory @@ -78,13 +77,6 @@ def load_spec(spec_path: Path) -> AsyncApiSpec: return AsyncApiSpec.from_dict(raw_resolved) -def load_handler(handler_id: str) -> Callable: - *module_path_elements, object_name = handler_id.split(".") - module = import_module(".".join(module_path_elements)) - - return getattr(module, object_name) - - def _noop_handler(*args, **kwargs) -> None: return None @@ -228,7 +220,7 @@ def _register_namespace_handlers( def _register_handlers( self, - server_security: Sequence[SecurityRequirement], + server_security: Sequence[SecurityRequirement] = (), default_error_handler: Optional[ErrorHandler] = None, ) -> None: for namespace, channel in self.spec.channels.items(): diff --git a/asynction/types.py b/asynction/types.py index 519d4fb..a64f648 100644 --- a/asynction/types.py +++ b/asynction/types.py @@ -114,9 +114,7 @@ def __post_init__(self): raise ValueError("Authorization code OAuth flow is missing Token URL") def supported_scopes(self) -> Iterator[str]: - # note Cannot lru_cache this - # TypeError: unhashable type: 'OAuth2Flows' - for f in fields(self): # dataclasses.fields + for f in fields(self): flow = getattr(self, f.name) if flow: for scope in flow.scopes: @@ -184,7 +182,8 @@ def __post_init__(self): options = ["query", "header", "cookie"] if not self.in_ or self.in_ not in options: raise ValueError( - f'"in" field must be one of {options} for {self.type} security schemes' # noqa: 501 + f'"in" field must be one of {options} ' + f"for {self.type} security schemes" ) if not self.name: raise ValueError(f'"name" is required for {self.type} security schemes') @@ -472,7 +471,7 @@ def __post_init__(self): if scope not in supported_scopes: raise ValueError( f"OAuth {scope} is not defined within " - "the {security_scheme_name} security scheme" + f"the {security_scheme_name} security scheme" ) @staticmethod diff --git a/asynction/utils.py b/asynction/utils.py new file mode 100644 index 0000000..d28850d --- /dev/null +++ b/asynction/utils.py @@ -0,0 +1,9 @@ +from importlib import import_module +from typing import Callable + + +def load_handler(handler_id: str) -> Callable: + *module_path_elements, object_name = handler_id.split(".") + module = import_module(".".join(module_path_elements)) + + return getattr(module, object_name) diff --git a/docs/index.rst b/docs/index.rst index fa763bd..9fe070e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -39,6 +39,7 @@ Exceptions .. autoexception:: asynction.PayloadValidationException .. autoexception:: asynction.BindingsValidationException .. autoexception:: asynction.MessageAckValidationException +.. autoexception:: asynction.SecurityException Indices and tables ================== diff --git a/tests/unit/test_mock_server.py b/tests/unit/test_mock_server.py index 7c013a2..73708d9 100644 --- a/tests/unit/test_mock_server.py +++ b/tests/unit/test_mock_server.py @@ -173,7 +173,7 @@ def test_register_handlers_registers_noop_handler_for_message_with_no_ack( ) server = new_mock_asynction_socket_io(spec) - server._register_handlers([]) + server._register_handlers() assert len(server.handlers) == 2 # connect handler included as well registered_event, registered_handler, registered_namespace = server.handlers[0] assert registered_event == event_name @@ -224,7 +224,7 @@ def test_register_handlers_registers_valid_handler_for_message_with_ack( ) server = new_mock_asynction_socket_io(spec) - server._register_handlers([]) + server._register_handlers() assert len(server.handlers) == 2 # connect handler included as well registered_event, registered_handler, registered_namespace = server.handlers[0] assert registered_event == event_name @@ -262,7 +262,7 @@ def test_register_handlers_adds_payload_validator_if_validation_is_enabled( ) server = new_mock_asynction_socket_io(spec) - server._register_handlers([]) + server._register_handlers() _, registered_handler, _ = server.handlers[0] handler_with_validation = deep_unwrap(registered_handler, depth=1) actual_handler = deep_unwrap(handler_with_validation) @@ -282,7 +282,7 @@ def test_register_handlers_registers_connection_handler( ) server = new_mock_asynction_socket_io(spec) - server._register_handlers([]) + server._register_handlers() assert len(server.handlers) == 1 registered_event, registered_handler, registered_namespace = server.handlers[0] @@ -312,7 +312,7 @@ def test_register_handlers_registers_connection_handler_with_bindings_validation server = new_mock_asynction_socket_io(spec) flask_app = Flask(__name__) - server._register_handlers([]) + server._register_handlers() _, registered_handler, _ = server.handlers[0] handler_with_validation = deep_unwrap(registered_handler, depth=1) @@ -338,7 +338,7 @@ def test_register_handlers_registers_default_error_handler( AsyncApiSpec(asyncapi=faker.pystr(), info=server_info, channels={}) ) - server._register_handlers([], optional_error_handler) + server._register_handlers(default_error_handler=optional_error_handler) assert server.default_exception_handler == optional_error_handler @@ -373,7 +373,7 @@ def test_run_spawns_background_tasks_and_calls_super_run( ) flask_app = Flask(__name__) server = new_mock_asynction_socket_io(spec, flask_app) - server._register_handlers([]) + server._register_handlers() background_tasks: MutableSequence[MockThread] = [] @@ -426,7 +426,7 @@ def start_background_task_mock(target, *args, **kwargs): flask_app = Flask(__name__) server = new_mock_asynction_socket_io(spec, flask_app) - server._register_handlers([]) + server._register_handlers() with patch.object(SocketIO, "run"): with patch.object(server, "start_background_task", start_background_task_mock): @@ -469,7 +469,7 @@ def start_background_task_mock(target, *args, **kwargs): flask_app = Flask(__name__) server = new_mock_asynction_socket_io(spec, flask_app) - server._register_handlers([]) + server._register_handlers() with patch.object(SocketIO, "run"): with patch.object(server, "start_background_task", start_background_task_mock): diff --git a/tests/unit/test_security.py b/tests/unit/test_security.py index 22c448f..fb42da1 100644 --- a/tests/unit/test_security.py +++ b/tests/unit/test_security.py @@ -11,6 +11,7 @@ from asynction.types import SecurityScheme from asynction.types import SecuritySchemesType from tests.fixtures import FixturePaths +from tests.fixtures import handlers def test_load_basic_info_func(): @@ -19,9 +20,7 @@ def test_load_basic_info_func(): scheme=HTTPAuthenticationScheme.BASIC, x_basic_info_func="tests.fixtures.handlers.basic_info", ) - basic_info = load_basic_info_func(scheme) - assert basic_info - assert callable(basic_info) + assert load_basic_info_func(scheme) == handlers.basic_info def test_load_api_key_info_func(fixture_paths: FixturePaths): @@ -31,9 +30,7 @@ def test_load_api_key_info_func(fixture_paths: FixturePaths): in_="query", x_api_key_info_func="tests.fixtures.handlers.api_key_info", ) - api_key_info = load_api_key_info_func(scheme) - assert api_key_info - assert callable(api_key_info) + assert load_api_key_info_func(scheme) == handlers.api_key_info def test_load_token_info_func(fixture_paths: FixturePaths): @@ -42,9 +39,7 @@ def test_load_token_info_func(fixture_paths: FixturePaths): flows=OAuth2Flows(implicit=OAuth2Flow(authorization_url="", scopes={"a": "a"})), x_token_info_func="tests.fixtures.handlers.token_info", ) - token_info = load_token_info_func(scheme) - assert token_info - assert callable(token_info) + assert load_token_info_func(scheme) == handlers.token_info def test_build_basic_http_security_check(): diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index e4680fc..1e79c30 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -254,7 +254,7 @@ def test_register_handlers_registers_callables_with_correct_event_name_and_names ) server = AsynctionSocketIO(spec, True, True, None) - server._register_handlers([]) + server._register_handlers() assert len(server.handlers) == 1 registered_event, registered_handler, registered_namespace = server.handlers[0] assert registered_event == event_name @@ -282,7 +282,7 @@ def test_register_handlers_registers_channel_handlers( ) server = AsynctionSocketIO(spec, True, True, None) - server._register_handlers([]) + server._register_handlers() assert server.exception_handlers[namespace] == some_error for event_name, handler, handler_namespace in server.handlers: @@ -321,7 +321,7 @@ def test_register_handlers_adds_payload_validator_if_validation_is_enabled( ) server = AsynctionSocketIO(spec, True, True, None) - server._register_handlers([]) + server._register_handlers() _, registered_handler, _ = server.handlers[0] handler_with_validation = deep_unwrap(registered_handler, depth=1) actual_handler = deep_unwrap(handler_with_validation) @@ -365,7 +365,7 @@ def test_register_handlers_adds_ack_validator_if_validation_is_enabled( ) server = AsynctionSocketIO(spec, True, True, None) - server._register_handlers([]) + server._register_handlers() _, registered_handler, _ = server.handlers[0] handler_with_validation = deep_unwrap(registered_handler, depth=1) actual_handler = deep_unwrap(handler_with_validation) @@ -405,7 +405,7 @@ def test_register_handlers_skips_payload_validator_if_validation_is_disabled( ) server = AsynctionSocketIO(spec, False, True, None) - server._register_handlers([]) + server._register_handlers() _, registered_handler, _ = server.handlers[0] handler_with_validation = deep_unwrap(registered_handler, depth=1) actual_handler = deep_unwrap(handler_with_validation) @@ -431,7 +431,7 @@ def test_register_handlers_registers_default_error_handler( None, ) - server._register_handlers([], optional_error_handler) + server._register_handlers(default_error_handler=optional_error_handler) assert server.default_exception_handler == optional_error_handler diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index e8cbac3..e08cde7 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -216,7 +216,21 @@ def test_async_api_spec_from_and_to_dict(faker: Faker): } }, "components": { - "securitySchemes": {"test": {"type": "http", "scheme": "basic"}} + "securitySchemes": { + "test": {"type": "http", "scheme": "basic"}, + "test2": {"type": "http", "scheme": "bearer", "bearerFormat": "JWT"}, + "testApiKey": {"type": "httpApiKey", "name": "test", "in": "header"}, + "oauth2": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://localhost:12345", + "refreshUrl": "https://localhost:12345/refresh", + "scopes": {"a": "A", "b": "B"}, + } + }, + }, + } }, }