Skip to content

Commit

Permalink
fix handling cookies is None in get_cookies
Browse files Browse the repository at this point in the history
Changes:
- fix handling cookies is None in get_cookies
- Optimize get_cookies
- fix wrong example
- remove left-over special handling of an empty string, lilya json
  encoding can do that now
  • Loading branch information
devkral committed Dec 27, 2024
1 parent 1b8dae9 commit 903c4e8
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 42 deletions.
9 changes: 3 additions & 6 deletions esmerald/responses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(
)
]
Response(response_cookies=response_cookies)
Response(cookies=response_cookies)
```
"""
),
Expand All @@ -163,12 +163,13 @@ def __init__(
encoders=[PydanticEncoder, MsgSpecEncoder]
]
Response(response_cookies=response_cookies)
Response(cookies=response_cookies)
```
"""
),
] = None,
) -> None:
self.cookies = cookies or []
super().__init__(
content=content,
status_code=status_code,
Expand All @@ -177,7 +178,6 @@ def __init__(
background=cast("BackgroundTask", background),
encoders=encoders,
)
self.cookies = cookies or []

def make_response(self, content: Any) -> bytes | memoryview | str:
if (
Expand All @@ -204,9 +204,6 @@ def make_response(self, content: Any) -> bytes | memoryview | str:
try:
# switch to a special mode for MediaType.JSON (default handlers)
if self.media_type == MediaType.JSON:
# "" should serialize to json
if content == "":
return b'""'
# keep it a serialized json object
transform_kwargs.setdefault("post_transform_fn", None)
# otherwise use default logic of lilya striping '"'
Expand Down
2 changes: 1 addition & 1 deletion esmerald/responses/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
if media_type == MediaType.JSON: # we assume this is the default
suffixes = PurePath(template_name).suffixes
for suffix in suffixes:
_type = guess_type("name" + suffix)[0]
_type = guess_type(f"name{suffix}")[0]
if _type:
media_type = _type
break
Expand Down
79 changes: 44 additions & 35 deletions esmerald/routing/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

from datetime import date, datetime, time, timedelta
from decimal import Decimal
from enum import Enum
from functools import partial
from inspect import Signature, isawaitable
from itertools import chain
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -54,7 +57,13 @@
from esmerald.permissions import BasePermission
from esmerald.permissions.types import Permission
from esmerald.routing.router import HTTPHandler
from esmerald.types import APIGateHandler, Dependencies, ResponseCookies, ResponseHeaders
from esmerald.types import (
APIGateHandler,
Cookie,
Dependencies,
ResponseCookies,
ResponseHeaders,
)
from esmerald.typing import AnyCallable

param_type_map = {
Expand All @@ -74,6 +83,7 @@


T = TypeVar("T", bound="Dispatcher")
_empty: tuple[Any, ...] = ()


class PathParameterSchema(TypedDict):
Expand Down Expand Up @@ -131,7 +141,7 @@ def create_signature_model(self, is_websocket: bool = False) -> None:
else:
self.websocket_parameter_model = transformer_model

def create_handler_transformer_model(self) -> "TransformerModel":
def create_handler_transformer_model(self) -> TransformerModel:
"""Method to create a TransformerModel for a given handler."""
dependencies = self.get_dependencies()
signature_model = get_signature(self)
Expand Down Expand Up @@ -262,10 +272,10 @@ def _get_default_status_code(self, data: Response) -> int:

def _get_response_container_handler(
self,
cookies: "ResponseCookies",
cookies: ResponseCookies,
headers: Dict[str, Any],
media_type: str,
) -> Callable[[ResponseContainer, Type["Esmerald"], Dict[str, Any]], LilyaResponse]:
) -> Callable[[ResponseContainer, Type[Esmerald], Dict[str, Any]], LilyaResponse]:
"""
Creates a handler for ResponseContainer types.
Expand Down Expand Up @@ -300,7 +310,7 @@ async def response_content(
)

def _get_json_response_handler(
self, cookies: "ResponseCookies", headers: Dict[str, Any]
self, cookies: ResponseCookies, headers: Dict[str, Any]
) -> Callable[[Response, Dict[str, Any]], LilyaResponse]:
"""
Creates a handler function for JSON responses.
Expand All @@ -314,7 +324,7 @@ def _get_json_response_handler(
"""

async def response_content(data: Response, **kwargs: Dict[str, Any]) -> LilyaResponse:
_cookies = self.get_cookies(cookies, [])
_cookies = self.get_cookies(cookies)
_headers = {
**self.get_headers(headers),
**data.headers,
Expand All @@ -334,7 +344,7 @@ async def response_content(data: Response, **kwargs: Dict[str, Any]) -> LilyaRes
return cast(Callable[[Response, Dict[str, Any]], LilyaResponse], response_content)

def _get_response_handler(
self, cookies: "ResponseCookies", headers: Dict[str, Any], media_type: str
self, cookies: ResponseCookies, headers: Dict[str, Any], media_type: str
) -> Callable[[Response, Dict[str, Any]], LilyaResponse]:
"""
Creates a handler function for Response types.
Expand Down Expand Up @@ -372,7 +382,7 @@ async def response_content(data: Response, **kwargs: Dict[str, Any]) -> LilyaRes
return cast(Callable[[Response, Dict[str, Any]], LilyaResponse], response_content)

def _get_lilya_response_handler(
self, cookies: "ResponseCookies", headers: Dict[str, Any]
self, cookies: ResponseCookies, headers: Dict[str, Any]
) -> Callable[[LilyaResponse, Dict[str, Any]], LilyaResponse]:
"""
Creates a handler function for Lilya Responses.
Expand All @@ -386,7 +396,7 @@ def _get_lilya_response_handler(
"""

async def response_content(data: LilyaResponse, **kwargs: Dict[str, Any]) -> LilyaResponse:
_cookies = self.get_cookies(cookies, [])
_cookies = self.get_cookies(cookies)
_headers = {
**self.get_headers(headers),
**data.headers,
Expand All @@ -404,7 +414,7 @@ async def response_content(data: LilyaResponse, **kwargs: Dict[str, Any]) -> Lil

def _get_default_handler(
self,
cookies: "ResponseCookies",
cookies: ResponseCookies,
headers: Dict[str, Any],
media_type: str,
response_class: Any,
Expand All @@ -424,7 +434,7 @@ def _get_default_handler(

async def response_content(data: Any, **kwargs: Dict[str, Any]) -> LilyaResponse:
data = await self.get_response_data(data=data)
_cookies = self.get_cookies(cookies, [])
_cookies = self.get_cookies(cookies)
if isinstance(data, JSONResponse):
response = data
response.status_code = self.status_code
Expand Down Expand Up @@ -677,7 +687,7 @@ def dependency_names(self) -> Set[str]:
level_dependencies = (level.dependencies or {} for level in self.parent_levels)
return {name for level in level_dependencies for name in level.keys()}

def get_permissions(self) -> List["AsyncCallable"]:
def get_permissions(self) -> List[AsyncCallable]:
"""
Returns all the permissions in the handler scope from the ownership layers.
Expand All @@ -699,7 +709,7 @@ def get_permissions(self) -> List["AsyncCallable"]:
- The permissions are collected from all parent levels, ensuring that there are no duplicate permissions in the final list.
"""
if self._permissions is Void:
self._permissions: Union[List["Permission"], "VoidType"] = []
self._permissions: Union[List[Permission], VoidType] = []
for layer in self.parent_levels:
self._permissions.extend(layer.permissions or [])
self._permissions = cast(
Expand All @@ -708,7 +718,7 @@ def get_permissions(self) -> List["AsyncCallable"]:
)
return cast("List[AsyncCallable]", self._permissions)

def get_dependencies(self) -> "Dependencies":
def get_dependencies(self) -> Dependencies:
"""
Returns all dependencies of the handler function's starting from the parent levels.
Expand Down Expand Up @@ -737,7 +747,7 @@ def get_dependencies(self) -> "Dependencies":
)

if not self._dependencies or self._dependencies is Void:
self._dependencies: "Dependencies" = {}
self._dependencies: Dependencies = {}
for level in self.parent_levels:
for key, value in (level.dependencies or {}).items():
self.is_unique_dependency(
Expand All @@ -749,7 +759,7 @@ def get_dependencies(self) -> "Dependencies":
return self._dependencies

@staticmethod
def is_unique_dependency(dependencies: "Dependencies", key: str, injector: Inject) -> None:
def is_unique_dependency(dependencies: Dependencies, key: str, injector: Inject) -> None:
"""
Validates that a given inject has not been already defined under a different key in any of the levels.
Expand Down Expand Up @@ -784,7 +794,9 @@ def is_unique_dependency(dependencies: "Dependencies", key: str, injector: Injec
)

def get_cookies(
self, local_cookies: "ResponseCookies", other_cookies: "ResponseCookies"
self,
local_cookies: ResponseCookies | None,
other_cookies: ResponseCookies | None = None,
) -> List[Dict[str, Any]]: # pragma: no cover
"""
Returns a unique list of cookies.
Expand Down Expand Up @@ -820,18 +832,15 @@ def get_cookies(
This will output the list of normalized cookies.
"""
filtered_cookies = [*local_cookies]
for cookie in other_cookies:
if not any(cookie.key == c.key for c in filtered_cookies):
filtered_cookies.append(cookie)
normalized_cookies: List[Dict[str, Any]] = []
for cookie in filtered_cookies:
normalized_cookies.append(
cookie.model_dump(exclude_none=True, exclude={"description"})
)
return normalized_cookies
filtered_cookies: dict[str, Cookie] = {}
for cookie in chain(local_cookies or _empty, other_cookies or _empty):
filtered_cookies.setdefault(cookie.key, cookie)
return [
cookie.model_dump(exclude_none=True, exclude={"description"})
for cookie in filtered_cookies.values()
]

def get_headers(self, headers: "ResponseHeaders") -> Dict[str, Any]:
def get_headers(self, headers: ResponseHeaders) -> Dict[str, Any]:
"""
Returns a dictionary of response headers.
Expand Down Expand Up @@ -898,12 +907,12 @@ async def allow_connection(self, connection: "Connection") -> None: # pragma: n
- PermissionDenied: If the connection is not allowed.
"""
for permission in self.get_permissions():
awaitable: "BasePermission" = cast("BasePermission", await permission())
request: "Request" = cast("Request", connection)
awaitable: BasePermission = cast("BasePermission", await permission())
request: Request = cast("Request", connection)
handler = cast("APIGateHandler", self)
await continue_or_raise_permission_exception(request, handler, awaitable)

def get_security_schemes(self) -> List["SecurityScheme"]:
def get_security_schemes(self) -> List[SecurityScheme]:
"""
Returns a list of all security schemes associated with the handler.
Expand All @@ -924,7 +933,7 @@ def get_security_schemes(self) -> List["SecurityScheme"]:
- Each security scheme is represented by an instance of the SecurityScheme class.
- The SecurityScheme class has attributes such as name, type, scheme, bearer_format, in_, and name, which provide information about the security scheme.
"""
security_schemes: List["SecurityScheme"] = []
security_schemes: List[SecurityScheme] = []
for layer in self.parent_levels:
security_schemes.extend(layer.security or [])
return security_schemes
Expand Down Expand Up @@ -961,7 +970,7 @@ def get_handler_tags(self) -> List[str]:

return tags_clean if tags_clean else None

def get_interceptors(self) -> List["AsyncCallable"]:
def get_interceptors(self) -> List[AsyncCallable]:
"""
Returns a list of all the interceptors in the handler scope from the ownership layers.
If the interceptors have not been initialized, it initializes them by collecting interceptors from each parent level.
Expand All @@ -981,7 +990,7 @@ def get_interceptors(self) -> List["AsyncCallable"]:
- The AsyncCallable class provides a way to call the interceptor asynchronously.
"""
if self._interceptors is Void:
self._interceptors: Union[List["Interceptor"], "VoidType"] = []
self._interceptors: Union[List[Interceptor], VoidType] = []
for layer in self.parent_levels:
self._interceptors.extend(layer.interceptors or [])
self._interceptors = cast(
Expand Down Expand Up @@ -1015,5 +1024,5 @@ async def intercept(self, scope: "Scope", receive: "Receive", send: "Send") -> N
- The `intercept` method is responsible for executing the interceptors in the handler scope.
"""
for interceptor in self.get_interceptors():
awaitable: "EsmeraldInterceptor" = await interceptor()
awaitable: EsmeraldInterceptor = await interceptor()
await awaitable.intercept(scope, receive, send)

0 comments on commit 903c4e8

Please sign in to comment.