Skip to content

Commit

Permalink
Fix add_router.
Browse files Browse the repository at this point in the history
. Fix handling with path from import routes.
  • Loading branch information
tarsil committed Oct 26, 2022
1 parent bc2720b commit 0cbd571
Show file tree
Hide file tree
Showing 9 changed files with 47 additions and 100 deletions.
10 changes: 8 additions & 2 deletions esmerald/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ def __init__(
configurations=self.scheduler_configurations,
)

self.activate_openapi()

def activate_openapi(self) -> None:
if self.openapi_config and self.enable_openapi:
self.openapi_schema = self.openapi_config.create_openapi_schema_model(self)
gateway = gateways.Gateway(handler=self.openapi_config.openapi_apiview)
Expand Down Expand Up @@ -324,9 +327,9 @@ def add_router(self, router: "Router"):
)

if self.on_startup:
self.on_startup.append(router.on_startup)
self.on_startup.extend(router.on_startup)
if self.on_shutdown:
self.on_shutdown.append(router.on_shutdown)
self.on_shutdown.extend(router.on_shutdown)

self.router.routes.append(
gateway(
Expand All @@ -338,9 +341,12 @@ def add_router(self, router: "Router"):
permissions=route.permissions,
handler=route.handler,
parent=self.router,
is_from_router=True,
)
)

self.activate_openapi()

def get_default_exception_handlers(self) -> None:
"""
Default exception handlers added to the application.
Expand Down
4 changes: 1 addition & 3 deletions esmerald/backgound.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@


class BackgroundTask(StarletteBackgroundTask):
def __init__(
self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> None:
def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None:
super().__init__(func, *args, **kwargs)


Expand Down
43 changes: 11 additions & 32 deletions esmerald/kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ class ParameterDefinition(NamedTuple):
class Dependency:
__slots__ = ("key", "inject", "dependencies")

def __init__(
self, key: str, inject: Inject, dependencies: List["Dependency"]
) -> None:
def __init__(self, key: str, inject: Inject, dependencies: List["Dependency"]) -> None:
self.key = key
self.inject = inject
self.dependencies = dependencies
Expand Down Expand Up @@ -202,9 +200,7 @@ def create_for_signature_model(
signature_model_fields=signature_model.__fields__,
)

expected_path_parameters = {
p for p in param_definitions if p.param_type == ParamType.PATH
}
expected_path_parameters = {p for p in param_definitions if p.param_type == ParamType.PATH}
expected_header_parameters = {
p for p in param_definitions if p.param_type == ParamType.HEADER
}
Expand Down Expand Up @@ -258,9 +254,7 @@ def create_for_signature_model(
expected_form_data=expected_form_data,
dependency_kwargs_model=dependency_kwargs_model,
)
expected_reserved_kwargs.update(
dependency_kwargs_model.expected_reserved_kwargs
)
expected_reserved_kwargs.update(dependency_kwargs_model.expected_reserved_kwargs)

return KwargsModel(
expected_form_data=expected_form_data,
Expand All @@ -269,9 +263,7 @@ def create_for_signature_model(
expected_query_params=expected_query_parameters,
expected_cookie_params=expected_cookie_parameters,
expected_header_params=expected_header_parameters,
expected_reserved_kwargs=cast(
"Set[ReservedKwargs]", expected_reserved_kwargs
),
expected_reserved_kwargs=cast("Set[ReservedKwargs]", expected_reserved_kwargs),
sequence_query_parameter_names=sequence_query_parameter_names,
is_data_optional=is_optional(signature_model.__fields__["data"])
if "data" in expected_reserved_kwargs
Expand All @@ -280,8 +272,7 @@ def create_for_signature_model(

def to_kwargs(self, connection: Union["WebSocket", "Request"]) -> Dict[str, Any]:
connection_query_params = {
k: self._sequence_or_scalar_param(k, v)
for k, v in connection.query_params.items()
k: self._sequence_or_scalar_param(k, v) for k, v in connection.query_params.items()
}

query_params = self._collect_params(
Expand Down Expand Up @@ -323,9 +314,7 @@ def to_kwargs(self, connection: Union["WebSocket", "Request"]) -> Dict[str, Any]
if "socket" in self.expected_reserved_kwargs:
reserved_kwargs["socket"] = connection
if "data" in self.expected_reserved_kwargs:
reserved_kwargs["data"] = self._get_request_data(
request=cast("Request", connection)
)
reserved_kwargs["data"] = self._get_request_data(request=cast("Request", connection))
return {
**reserved_kwargs,
**path_params,
Expand All @@ -340,17 +329,13 @@ def _collect_params(
) -> Dict[str, Any]:
"""Collects request params, checking for missing required values."""
missing_params = [
p.field_alias
for p in expected
if p.is_required and p.field_alias not in params
p.field_alias for p in expected if p.is_required and p.field_alias not in params
]
if missing_params:
raise ValidationErrorException(
f"Missing required parameter(s) {', '.join(missing_params)} for url {url}"
)
return {
p.field_name: params.get(p.field_alias, p.default_value) for p in expected
}
return {p.field_name: params.get(p.field_alias, p.default_value) for p in expected}

async def resolve_dependency(
self,
Expand All @@ -369,9 +354,7 @@ async def resolve_dependency(
return await dependency.inject(**dependency_kwargs)

@classmethod
def _create_dependency_graph(
cls, key: str, dependencies: Dict[str, Inject]
) -> Dependency:
def _create_dependency_graph(cls, key: str, dependencies: Dict[str, Inject]) -> Dependency:
inject = dependencies[key]
sub_dependency_keys = [
k for k in get_signature_model(inject).__fields__ if k in dependencies
Expand All @@ -395,9 +378,7 @@ def _create_parameter_definition(
) -> ParameterDefinition:
extra = field_info.extra
is_required = extra.get(EXTRA_KEY_REQUIRED, True)
default_value = (
field_info.default if field_info.default is not Undefined else None
)
default_value = field_info.default if field_info.default is not Undefined else None

field_alias = extra.get(ParamType.QUERY) or field_name
param_type = getattr(field_info, "in_", ParamType.QUERY)
Expand Down Expand Up @@ -483,9 +464,7 @@ def _validate_raw_kwargs(
f"The following kwargs have been used: {', '.join(used_reserved_kwargs)}"
)

def _sequence_or_scalar_param(
self, key: str, value: List[str]
) -> Union[str, List[str]]:
def _sequence_or_scalar_param(self, key: str, value: List[str]) -> Union[str, List[str]]:
return (
value[0]
if key not in self.sequence_query_parameter_names and len(value) == 1
Expand Down
12 changes: 3 additions & 9 deletions esmerald/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
_false_values = {"False", "false"}


def _query_param_reducer(
acc: Dict[str, List[str]], cur: Tuple[str, str]
) -> Dict[str, List[str]]:
def _query_param_reducer(acc: Dict[str, List[str]], cur: Tuple[str, str]) -> Dict[str, List[str]]:
key, value = cur

if value in _true_values:
Expand All @@ -43,18 +41,14 @@ def parse_query_params(connection: "HTTPConnection") -> Dict[str, Any]:
return reduce(
_query_param_reducer,
parse_qsl(
query_string
if isinstance(query_string, str)
else query_string.decode("latin-1"),
query_string if isinstance(query_string, str) else query_string.decode("latin-1"),
keep_blank_values=True,
),
{},
)


def parse_form_data(
media_type: "EncodingType", form_data: "FormData", field: "ModelField"
) -> Any:
def parse_form_data(media_type: "EncodingType", form_data: "FormData", field: "ModelField") -> Any:
values_dict: Dict[str, Any] = {}
for key, value in form_data.multi_items():
if not isinstance(value, (UploadFile, StarletteUploadFile)):
Expand Down
48 changes: 12 additions & 36 deletions esmerald/routing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ class PathParameterSchema(TypedDict):


class OpenAPIDefinitionMixin:
def parse_path(
self, path: str
) -> Tuple[str, str, List[Union[str, PathParameterSchema]]]:
def parse_path(self, path: str) -> Tuple[str, str, List[Union[str, PathParameterSchema]]]:
"""
Using the Starlette CONVERTORS and the application registered convertors,
transforms the path into a PathParameterSchema used for the OpenAPI definition.
Expand Down Expand Up @@ -183,9 +181,7 @@ def create_response_handler(
status_code: Optional[int] = None,
media_type: Optional[str] = MediaType.TEXT,
) -> "AsyncAnyCallable":
async def response_content(
data: Response, **kwargs: Dict[str, Any]
) -> StarletteResponse:
async def response_content(data: Response, **kwargs: Dict[str, Any]) -> StarletteResponse:
_cookies = self.get_cookies(data.cookies, cookies)
_headers = {
**self.get_headers(headers),
Expand Down Expand Up @@ -213,9 +209,7 @@ def create_json_response_handler(
) -> "AsyncAnyCallable":
"""Creates a handler function for Esmerald JSON responses"""

async def response_content(
data: Response, **kwargs: Dict[str, Any]
) -> StarletteResponse:
async def response_content(data: Response, **kwargs: Dict[str, Any]) -> StarletteResponse:
if status_code:
data.status_code = status_code
return data
Expand Down Expand Up @@ -260,9 +254,7 @@ def create_handler(
response_class: "ResponseType",
status_code: int,
) -> "AsyncAnyCallable":
async def response_content(
data: Any, **kwargs: Dict[str, Any]
) -> StarletteResponse:
async def response_content(data: Any, **kwargs: Dict[str, Any]) -> StarletteResponse:

data = await self.get_response_data(data=data)
_cookies = self.get_cookies(cookies, [])
Expand Down Expand Up @@ -372,18 +364,14 @@ def get_response_handler(self) -> Callable[[Any], Awaitable[StarletteResponse]]:
"""
if self._response_handler is Void:
media_type = (
self.media_type.value
if isinstance(self.media_type, Enum)
else self.media_type
self.media_type.value if isinstance(self.media_type, Enum) else self.media_type
)

response_class = self.get_response_class()
headers = self.get_response_headers()
cookies = self.get_response_cookies()

if is_class_and_subclass(
self.signature.return_annotation, ResponseContainer
):
if is_class_and_subclass(self.signature.return_annotation, ResponseContainer):
handler = self.create_response_container_handler(
cookies=cookies,
media_type=self.media_type,
Expand All @@ -394,19 +382,15 @@ def get_response_handler(self) -> Callable[[Any], Awaitable[StarletteResponse]]:
self.signature.return_annotation,
(JSONResponse, ORJSONResponse, UJSONResponse),
):
handler = self.create_json_response_handler(
status_code=self.status_code
)
handler = self.create_json_response_handler(status_code=self.status_code)
elif is_class_and_subclass(self.signature.return_annotation, Response):
handler = self.create_response_handler(
cookies=cookies,
status_code=self.status_code,
media_type=self.media_type,
headers=headers,
)
elif is_class_and_subclass(
self.signature.return_annotation, StarletteResponse
):
elif is_class_and_subclass(self.signature.return_annotation, StarletteResponse):
handler = self.create_starlette_response_handler(
cookies=cookies,
media_type=self.media_type,
Expand Down Expand Up @@ -454,9 +438,7 @@ def normalised_path_params(self) -> List[Dict[str, str]]:
Gets the path parameters in a PathParameterSchema format.
"""
path_components = self.parse_path(self.path)
parameters = [
component for component in path_components if isinstance(component, dict)
]
parameters = [component for component in path_components if isinstance(component, dict)]
return parameters

@property
Expand All @@ -475,9 +457,7 @@ def parent_layers(self) -> List[Union[T, "ParentType"]]:
def dependency_names(self) -> Set[str]:
"""A unique set of all dependency names provided in the handlers parent
layers."""
layered_dependencies = (
layer.dependencies or {} for layer in self.parent_layers
)
layered_dependencies = (layer.dependencies or {} for layer in self.parent_layers)
return {name for layer in layered_dependencies for name in layer.keys()}

def resolve_permissions(self) -> List["Permission"]:
Expand Down Expand Up @@ -515,9 +495,7 @@ def get_dependencies(self) -> Dict[str, Inject]:
return cast("Dict[str, Inject]", self._dependencies)

@staticmethod
def has_dependency_unique(
dependencies: Dict[str, Inject], key: str, injector: Inject
) -> None:
def has_dependency_unique(dependencies: Dict[str, Inject], key: str, injector: Inject) -> None:
"""
Validates that a given inject has not been already defined under a
different key in any of the layers.
Expand Down Expand Up @@ -551,9 +529,7 @@ def get_cookies(
filtered_cookies.append(cookie)
normalized_cookies: List[Dict[str, Any]] = []
for cookie in filtered_cookies:
normalized_cookies.append(
cookie.dict(exclude_none=True, exclude={"description"})
)
normalized_cookies.append(cookie.dict(exclude_none=True, exclude={"description"}))
return normalized_cookies

def get_headers(self, headers: "ResponseHeaders") -> Dict[str, Any]:
Expand Down
7 changes: 6 additions & 1 deletion esmerald/routing/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,17 @@ def __init__(
permissions: Optional["Permission"] = None,
exception_handlers: Optional["ExceptionHandlers"] = None,
deprecated: Optional[bool] = None,
is_from_router: bool = False,
) -> None:
if not path:
path = "/"
if is_class_and_subclass(handler, APIView):
handler = handler(parent=self)
self.path = clean_path(path + handler.path)

if not is_from_router:
self.path = clean_path(path + handler.path)
else:
self.path = clean_path(path)
self.methods = getattr(handler, "methods", None)

if not name:
Expand Down
15 changes: 4 additions & 11 deletions esmerald/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,9 @@ def is_server_error(cls, error: "ErrorDict") -> bool:
return error["loc"][-1] in cls.dependency_names

@staticmethod
def get_connection_method_and_url(
connection: Union[Request, WebSocket]
) -> Tuple[str, "URL"]:
def get_connection_method_and_url(connection: Union[Request, WebSocket]) -> Tuple[str, "URL"]:
"""Extract method and URL from Request or WebSocket."""
method = (
ScopeType.WEBSOCKET
if isinstance(connection, WebSocket)
else connection.method
)
method = ScopeType.WEBSOCKET if isinstance(connection, WebSocket) else connection.method
return method, connection.url


Expand Down Expand Up @@ -200,9 +194,8 @@ def signature_parameters(self) -> Generator[SignatureParameter, None, None]:
yield SignatureParameter(self.fn_name, name, parameter)

def should_skip_parameter_validation(self, parameter: SignatureParameter) -> bool:
return (
parameter.name in SKIP_VALIDATION_NAMES
or should_skip_dependency_validation(parameter.default)
return parameter.name in SKIP_VALIDATION_NAMES or should_skip_dependency_validation(
parameter.default
)

def create_signature_model(self) -> Type[SignatureModel]:
Expand Down
4 changes: 1 addition & 3 deletions esmerald/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ def is_async_callable(value: Callable[P, T]) -> TypeGuard[Callable[P, Awaitable[
while isinstance(value, functools.partial):
value = value.func # type: ignore[unreachable]

return asyncio.iscoroutinefunction(value) or asyncio.iscoroutinefunction(
value.__call__
) # ty
return asyncio.iscoroutinefunction(value) or asyncio.iscoroutinefunction(value.__call__) # ty


def is_class_and_subclass(value: typing.Any, type_: typing.Any) -> bool:
Expand Down
4 changes: 1 addition & 3 deletions esmerald/utils/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
return await self.fn(*args, **kwargs)


def as_async_callable_list(
value: Union[Callable, List[Callable]]
) -> List[AsyncCallable]:
def as_async_callable_list(value: Union[Callable, List[Callable]]) -> List[AsyncCallable]:
if not isinstance(value, list):
return [AsyncCallable(value)]
return [AsyncCallable(v) for v in value]

0 comments on commit 0cbd571

Please sign in to comment.