Skip to content

Commit

Permalink
Address review items
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-zywicki committed Nov 23, 2021
1 parent acc2b78 commit 256cfff
Show file tree
Hide file tree
Showing 11 changed files with 62 additions and 62 deletions.
3 changes: 2 additions & 1 deletion asynction/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class MessageAckValidationException(ValidationException):

class SecurityException(AsynctionException, ConnectionRefusedError):
"""
Base Security Exception type.
Raised when an incoming connection fails to meet the requirements of
any of the specified security schemes.
"""

pass
2 changes: 1 addition & 1 deletion asynction/mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def from_spec(

def _register_handlers(
self,
server_security: Sequence[SecurityRequirement],
server_security: Sequence[SecurityRequirement] = (),
default_error_handler: Optional[ErrorHandler] = None,
) -> None:
for namespace, channel in self.spec.channels.items():
Expand Down
29 changes: 9 additions & 20 deletions asynction/security.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import base64
import http.cookies
from functools import wraps
from http.cookies import SimpleCookie
from typing import Callable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

from flask import Request
from flask import request as current_flask_request
Expand All @@ -17,6 +16,7 @@
from asynction.types import SecurityRequirement
from asynction.types import SecurityScheme
from asynction.types import SecuritySchemesType
from asynction.utils import load_handler

TokenInfoFunc = Callable[[str], Mapping]
BasicInfoFunc = Callable[[str, str, Optional[Sequence[str]]], Mapping]
Expand Down Expand Up @@ -58,7 +58,7 @@ def extract_auth_header(request: Request) -> Optional[Tuple[str, str]]:

def validate_basic(
request: Request, basic_info_func: BasicInfoFunc, required_scopes: Sequence[str]
) -> Union[Mapping, None]:
) -> Optional[Mapping]:
auth = extract_auth_header(request)
if not auth:
return None
Expand Down Expand Up @@ -86,7 +86,7 @@ def validate_basic(

def validate_authorization_header(
request: Request, token_info_func: TokenInfoFunc
) -> Union[Mapping, None]:
) -> Optional[Mapping]:
"""Check that the provided request contains a properly formatted Authorization
header and invokes the token_info_func on the token inside of the header.
"""
Expand All @@ -112,7 +112,7 @@ def validate_api_key(
api_key_info_func: APIKeyInfoFunc,
required_scopes: Sequence[str],
bearer_format: Optional[str] = None,
) -> Union[Mapping, None]:
) -> Optional[Mapping]:
"""
Adapted from: https://github.com/zalando/connexion/blob/main/connexion/security/security_handler_factory.py#L221 # noqa: 501
"""
Expand Down Expand Up @@ -146,9 +146,6 @@ def validate_scopes(


def load_scope_validate_func(scheme: SecurityScheme) -> ScopeValidateFunc:
# importing here because doing it at the top leads to a circular import
from asynction.server import load_handler

scope_validate_func = None
if scheme.x_scope_validate_func:
scope_validate_func = load_handler(scheme.x_scope_validate_func)
Expand All @@ -160,9 +157,6 @@ def load_scope_validate_func(scheme: SecurityScheme) -> ScopeValidateFunc:


def load_basic_info_func(scheme: SecurityScheme) -> BasicInfoFunc:
# importing here because doing it at the top leads to a circular import
from asynction.server import load_handler

if scheme.x_basic_info_func is not None:
basic_info_func = load_handler(scheme.x_basic_info_func)
if not basic_info_func:
Expand All @@ -173,9 +167,6 @@ def load_basic_info_func(scheme: SecurityScheme) -> BasicInfoFunc:


def load_token_info_func(scheme: SecurityScheme) -> TokenInfoFunc:
# importing here because doing it at the top leads to a circular import
from asynction.server import load_handler

if scheme.x_token_info_func is not None:
token_info_func = load_handler(scheme.x_token_info_func)
if not token_info_func:
Expand All @@ -186,9 +177,6 @@ def load_token_info_func(scheme: SecurityScheme) -> TokenInfoFunc:


def load_api_key_info_func(scheme: SecurityScheme) -> APIKeyInfoFunc:
# importing here because doing it at the top leads to a circular import
from asynction.server import load_handler

if scheme.x_api_key_info_func is not None:
token_info_func = load_handler(scheme.x_api_key_info_func)
if not token_info_func:
Expand Down Expand Up @@ -252,15 +240,15 @@ def http_bearer_security_check(
return None


def get_cookie_value(cookies, name):
def get_cookie_value(cookies: str, name: str) -> Optional[str]:
"""
Returns cookie value by its name. None if no such value.
:param cookies: str: cookies raw data
:param name: str: cookies key
Borrowed from https://github.com/zalando/connexion/blob/main/connexion/security/security_handler_factory.py#L206 # noqa: 501
"""
cookie_parser = http.cookies.SimpleCookie()
cookie_parser: SimpleCookie = SimpleCookie()
cookie_parser.load(str(cookies))
try:
return cookie_parser[name].value
Expand Down Expand Up @@ -293,7 +281,8 @@ def http_api_key_security_check(request: Request) -> InternalSecurityCheckRespon
api_key = request.headers.get(api_key_name)
elif api_key_in == "cookie":
cookies_list = request.headers.get("Cookie")
api_key = get_cookie_value(cookies_list, api_key_name)
if cookies_list and api_key_name is not None:
api_key = get_cookie_value(cookies_list, api_key_name)
else:
return None, None

Expand Down
12 changes: 2 additions & 10 deletions asynction/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
server with an additional factory classmethod.
"""
from functools import singledispatch
from importlib import import_module
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Optional
from typing import Sequence
from urllib.parse import urlparse
Expand All @@ -26,6 +24,7 @@
from asynction.types import ErrorHandler
from asynction.types import JSONMapping
from asynction.types import SecurityRequirement
from asynction.utils import load_handler
from asynction.validation import bindings_validator_factory
from asynction.validation import callback_validator_factory
from asynction.validation import publish_message_validator_factory
Expand Down Expand Up @@ -78,13 +77,6 @@ def load_spec(spec_path: Path) -> AsyncApiSpec:
return AsyncApiSpec.from_dict(raw_resolved)


def load_handler(handler_id: str) -> Callable:
*module_path_elements, object_name = handler_id.split(".")
module = import_module(".".join(module_path_elements))

return getattr(module, object_name)


def _noop_handler(*args, **kwargs) -> None:
return None

Expand Down Expand Up @@ -228,7 +220,7 @@ def _register_namespace_handlers(

def _register_handlers(
self,
server_security: Sequence[SecurityRequirement],
server_security: Sequence[SecurityRequirement] = (),
default_error_handler: Optional[ErrorHandler] = None,
) -> None:
for namespace, channel in self.spec.channels.items():
Expand Down
9 changes: 4 additions & 5 deletions asynction/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,7 @@ def __post_init__(self):
raise ValueError("Authorization code OAuth flow is missing Token URL")

def supported_scopes(self) -> Iterator[str]:
# note Cannot lru_cache this
# TypeError: unhashable type: 'OAuth2Flows'
for f in fields(self): # dataclasses.fields
for f in fields(self):
flow = getattr(self, f.name)
if flow:
for scope in flow.scopes:
Expand Down Expand Up @@ -184,7 +182,8 @@ def __post_init__(self):
options = ["query", "header", "cookie"]
if not self.in_ or self.in_ not in options:
raise ValueError(
f'"in" field must be one of {options} for {self.type} security schemes' # noqa: 501
f'"in" field must be one of {options} '
f"for {self.type} security schemes"
)
if not self.name:
raise ValueError(f'"name" is required for {self.type} security schemes')
Expand Down Expand Up @@ -472,7 +471,7 @@ def __post_init__(self):
if scope not in supported_scopes:
raise ValueError(
f"OAuth {scope} is not defined within "
"the {security_scheme_name} security scheme"
f"the {security_scheme_name} security scheme"
)

@staticmethod
Expand Down
9 changes: 9 additions & 0 deletions asynction/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from importlib import import_module
from typing import Callable


def load_handler(handler_id: str) -> Callable:
*module_path_elements, object_name = handler_id.split(".")
module = import_module(".".join(module_path_elements))

return getattr(module, object_name)
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Exceptions
.. autoexception:: asynction.PayloadValidationException
.. autoexception:: asynction.BindingsValidationException
.. autoexception:: asynction.MessageAckValidationException
.. autoexception:: asynction.SecurityException

Indices and tables
==================
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/test_mock_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_register_handlers_registers_noop_handler_for_message_with_no_ack(
)
server = new_mock_asynction_socket_io(spec)

server._register_handlers([])
server._register_handlers()
assert len(server.handlers) == 2 # connect handler included as well
registered_event, registered_handler, registered_namespace = server.handlers[0]
assert registered_event == event_name
Expand Down Expand Up @@ -224,7 +224,7 @@ def test_register_handlers_registers_valid_handler_for_message_with_ack(
)
server = new_mock_asynction_socket_io(spec)

server._register_handlers([])
server._register_handlers()
assert len(server.handlers) == 2 # connect handler included as well
registered_event, registered_handler, registered_namespace = server.handlers[0]
assert registered_event == event_name
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_register_handlers_adds_payload_validator_if_validation_is_enabled(
)
server = new_mock_asynction_socket_io(spec)

server._register_handlers([])
server._register_handlers()
_, registered_handler, _ = server.handlers[0]
handler_with_validation = deep_unwrap(registered_handler, depth=1)
actual_handler = deep_unwrap(handler_with_validation)
Expand All @@ -282,7 +282,7 @@ def test_register_handlers_registers_connection_handler(
)
server = new_mock_asynction_socket_io(spec)

server._register_handlers([])
server._register_handlers()

assert len(server.handlers) == 1
registered_event, registered_handler, registered_namespace = server.handlers[0]
Expand Down Expand Up @@ -312,7 +312,7 @@ def test_register_handlers_registers_connection_handler_with_bindings_validation
server = new_mock_asynction_socket_io(spec)
flask_app = Flask(__name__)

server._register_handlers([])
server._register_handlers()
_, registered_handler, _ = server.handlers[0]

handler_with_validation = deep_unwrap(registered_handler, depth=1)
Expand All @@ -338,7 +338,7 @@ def test_register_handlers_registers_default_error_handler(
AsyncApiSpec(asyncapi=faker.pystr(), info=server_info, channels={})
)

server._register_handlers([], optional_error_handler)
server._register_handlers(default_error_handler=optional_error_handler)
assert server.default_exception_handler == optional_error_handler


Expand Down Expand Up @@ -373,7 +373,7 @@ def test_run_spawns_background_tasks_and_calls_super_run(
)
flask_app = Flask(__name__)
server = new_mock_asynction_socket_io(spec, flask_app)
server._register_handlers([])
server._register_handlers()

background_tasks: MutableSequence[MockThread] = []

Expand Down Expand Up @@ -426,7 +426,7 @@ def start_background_task_mock(target, *args, **kwargs):

flask_app = Flask(__name__)
server = new_mock_asynction_socket_io(spec, flask_app)
server._register_handlers([])
server._register_handlers()

with patch.object(SocketIO, "run"):
with patch.object(server, "start_background_task", start_background_task_mock):
Expand Down Expand Up @@ -469,7 +469,7 @@ def start_background_task_mock(target, *args, **kwargs):

flask_app = Flask(__name__)
server = new_mock_asynction_socket_io(spec, flask_app)
server._register_handlers([])
server._register_handlers()

with patch.object(SocketIO, "run"):
with patch.object(server, "start_background_task", start_background_task_mock):
Expand Down
13 changes: 4 additions & 9 deletions tests/unit/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from asynction.types import SecurityScheme
from asynction.types import SecuritySchemesType
from tests.fixtures import FixturePaths
from tests.fixtures import handlers


def test_load_basic_info_func():
Expand All @@ -19,9 +20,7 @@ def test_load_basic_info_func():
scheme=HTTPAuthenticationScheme.BASIC,
x_basic_info_func="tests.fixtures.handlers.basic_info",
)
basic_info = load_basic_info_func(scheme)
assert basic_info
assert callable(basic_info)
assert load_basic_info_func(scheme) == handlers.basic_info


def test_load_api_key_info_func(fixture_paths: FixturePaths):
Expand All @@ -31,9 +30,7 @@ def test_load_api_key_info_func(fixture_paths: FixturePaths):
in_="query",
x_api_key_info_func="tests.fixtures.handlers.api_key_info",
)
api_key_info = load_api_key_info_func(scheme)
assert api_key_info
assert callable(api_key_info)
assert load_api_key_info_func(scheme) == handlers.api_key_info


def test_load_token_info_func(fixture_paths: FixturePaths):
Expand All @@ -42,9 +39,7 @@ def test_load_token_info_func(fixture_paths: FixturePaths):
flows=OAuth2Flows(implicit=OAuth2Flow(authorization_url="", scopes={"a": "a"})),
x_token_info_func="tests.fixtures.handlers.token_info",
)
token_info = load_token_info_func(scheme)
assert token_info
assert callable(token_info)
assert load_token_info_func(scheme) == handlers.token_info


def test_build_basic_http_security_check():
Expand Down
Loading

0 comments on commit 256cfff

Please sign in to comment.