From 0cbd57181b8d9e4a4cb99080d9fd7a7ed945247a Mon Sep 17 00:00:00 2001 From: tarsil Date: Wed, 26 Oct 2022 15:08:11 +0100 Subject: [PATCH] Fix add_router. . Fix handling with path from import routes. --- esmerald/applications.py | 10 ++++++-- esmerald/backgound.py | 4 +-- esmerald/kwargs.py | 43 +++++++++----------------------- esmerald/parsers.py | 12 +++------ esmerald/routing/base.py | 48 +++++++++--------------------------- esmerald/routing/gateways.py | 7 +++++- esmerald/signature.py | 15 +++-------- esmerald/utils/helpers.py | 4 +-- esmerald/utils/sync.py | 4 +-- 9 files changed, 47 insertions(+), 100 deletions(-) diff --git a/esmerald/applications.py b/esmerald/applications.py index 4d1a865b..c985d20f 100644 --- a/esmerald/applications.py +++ b/esmerald/applications.py @@ -233,6 +233,9 @@ def __init__( configurations=self.scheduler_configurations, ) + self.activate_openapi() + + def activate_openapi(self) -> None: if self.openapi_config and self.enable_openapi: self.openapi_schema = self.openapi_config.create_openapi_schema_model(self) gateway = gateways.Gateway(handler=self.openapi_config.openapi_apiview) @@ -324,9 +327,9 @@ def add_router(self, router: "Router"): ) if self.on_startup: - self.on_startup.append(router.on_startup) + self.on_startup.extend(router.on_startup) if self.on_shutdown: - self.on_shutdown.append(router.on_shutdown) + self.on_shutdown.extend(router.on_shutdown) self.router.routes.append( gateway( @@ -338,9 +341,12 @@ def add_router(self, router: "Router"): permissions=route.permissions, handler=route.handler, parent=self.router, + is_from_router=True, ) ) + self.activate_openapi() + def get_default_exception_handlers(self) -> None: """ Default exception handlers added to the application. diff --git a/esmerald/backgound.py b/esmerald/backgound.py index c4a7a387..ae71cf50 100644 --- a/esmerald/backgound.py +++ b/esmerald/backgound.py @@ -10,9 +10,7 @@ class BackgroundTask(StarletteBackgroundTask): - def __init__( - self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs - ) -> None: + def __init__(self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs) -> None: super().__init__(func, *args, **kwargs) diff --git a/esmerald/kwargs.py b/esmerald/kwargs.py index 74a1a82e..0f323084 100644 --- a/esmerald/kwargs.py +++ b/esmerald/kwargs.py @@ -53,9 +53,7 @@ class ParameterDefinition(NamedTuple): class Dependency: __slots__ = ("key", "inject", "dependencies") - def __init__( - self, key: str, inject: Inject, dependencies: List["Dependency"] - ) -> None: + def __init__(self, key: str, inject: Inject, dependencies: List["Dependency"]) -> None: self.key = key self.inject = inject self.dependencies = dependencies @@ -202,9 +200,7 @@ def create_for_signature_model( signature_model_fields=signature_model.__fields__, ) - expected_path_parameters = { - p for p in param_definitions if p.param_type == ParamType.PATH - } + expected_path_parameters = {p for p in param_definitions if p.param_type == ParamType.PATH} expected_header_parameters = { p for p in param_definitions if p.param_type == ParamType.HEADER } @@ -258,9 +254,7 @@ def create_for_signature_model( expected_form_data=expected_form_data, dependency_kwargs_model=dependency_kwargs_model, ) - expected_reserved_kwargs.update( - dependency_kwargs_model.expected_reserved_kwargs - ) + expected_reserved_kwargs.update(dependency_kwargs_model.expected_reserved_kwargs) return KwargsModel( expected_form_data=expected_form_data, @@ -269,9 +263,7 @@ def create_for_signature_model( expected_query_params=expected_query_parameters, expected_cookie_params=expected_cookie_parameters, expected_header_params=expected_header_parameters, - expected_reserved_kwargs=cast( - "Set[ReservedKwargs]", expected_reserved_kwargs - ), + expected_reserved_kwargs=cast("Set[ReservedKwargs]", expected_reserved_kwargs), sequence_query_parameter_names=sequence_query_parameter_names, is_data_optional=is_optional(signature_model.__fields__["data"]) if "data" in expected_reserved_kwargs @@ -280,8 +272,7 @@ def create_for_signature_model( def to_kwargs(self, connection: Union["WebSocket", "Request"]) -> Dict[str, Any]: connection_query_params = { - k: self._sequence_or_scalar_param(k, v) - for k, v in connection.query_params.items() + k: self._sequence_or_scalar_param(k, v) for k, v in connection.query_params.items() } query_params = self._collect_params( @@ -323,9 +314,7 @@ def to_kwargs(self, connection: Union["WebSocket", "Request"]) -> Dict[str, Any] if "socket" in self.expected_reserved_kwargs: reserved_kwargs["socket"] = connection if "data" in self.expected_reserved_kwargs: - reserved_kwargs["data"] = self._get_request_data( - request=cast("Request", connection) - ) + reserved_kwargs["data"] = self._get_request_data(request=cast("Request", connection)) return { **reserved_kwargs, **path_params, @@ -340,17 +329,13 @@ def _collect_params( ) -> Dict[str, Any]: """Collects request params, checking for missing required values.""" missing_params = [ - p.field_alias - for p in expected - if p.is_required and p.field_alias not in params + p.field_alias for p in expected if p.is_required and p.field_alias not in params ] if missing_params: raise ValidationErrorException( f"Missing required parameter(s) {', '.join(missing_params)} for url {url}" ) - return { - p.field_name: params.get(p.field_alias, p.default_value) for p in expected - } + return {p.field_name: params.get(p.field_alias, p.default_value) for p in expected} async def resolve_dependency( self, @@ -369,9 +354,7 @@ async def resolve_dependency( return await dependency.inject(**dependency_kwargs) @classmethod - def _create_dependency_graph( - cls, key: str, dependencies: Dict[str, Inject] - ) -> Dependency: + def _create_dependency_graph(cls, key: str, dependencies: Dict[str, Inject]) -> Dependency: inject = dependencies[key] sub_dependency_keys = [ k for k in get_signature_model(inject).__fields__ if k in dependencies @@ -395,9 +378,7 @@ def _create_parameter_definition( ) -> ParameterDefinition: extra = field_info.extra is_required = extra.get(EXTRA_KEY_REQUIRED, True) - default_value = ( - field_info.default if field_info.default is not Undefined else None - ) + default_value = field_info.default if field_info.default is not Undefined else None field_alias = extra.get(ParamType.QUERY) or field_name param_type = getattr(field_info, "in_", ParamType.QUERY) @@ -483,9 +464,7 @@ def _validate_raw_kwargs( f"The following kwargs have been used: {', '.join(used_reserved_kwargs)}" ) - def _sequence_or_scalar_param( - self, key: str, value: List[str] - ) -> Union[str, List[str]]: + def _sequence_or_scalar_param(self, key: str, value: List[str]) -> Union[str, List[str]]: return ( value[0] if key not in self.sequence_query_parameter_names and len(value) == 1 diff --git a/esmerald/parsers.py b/esmerald/parsers.py index 4b2139ad..a2b2a122 100644 --- a/esmerald/parsers.py +++ b/esmerald/parsers.py @@ -20,9 +20,7 @@ _false_values = {"False", "false"} -def _query_param_reducer( - acc: Dict[str, List[str]], cur: Tuple[str, str] -) -> Dict[str, List[str]]: +def _query_param_reducer(acc: Dict[str, List[str]], cur: Tuple[str, str]) -> Dict[str, List[str]]: key, value = cur if value in _true_values: @@ -43,18 +41,14 @@ def parse_query_params(connection: "HTTPConnection") -> Dict[str, Any]: return reduce( _query_param_reducer, parse_qsl( - query_string - if isinstance(query_string, str) - else query_string.decode("latin-1"), + query_string if isinstance(query_string, str) else query_string.decode("latin-1"), keep_blank_values=True, ), {}, ) -def parse_form_data( - media_type: "EncodingType", form_data: "FormData", field: "ModelField" -) -> Any: +def parse_form_data(media_type: "EncodingType", form_data: "FormData", field: "ModelField") -> Any: values_dict: Dict[str, Any] = {} for key, value in form_data.multi_items(): if not isinstance(value, (UploadFile, StarletteUploadFile)): diff --git a/esmerald/routing/base.py b/esmerald/routing/base.py index d184d22b..d1bf4daa 100644 --- a/esmerald/routing/base.py +++ b/esmerald/routing/base.py @@ -84,9 +84,7 @@ class PathParameterSchema(TypedDict): class OpenAPIDefinitionMixin: - def parse_path( - self, path: str - ) -> Tuple[str, str, List[Union[str, PathParameterSchema]]]: + def parse_path(self, path: str) -> Tuple[str, str, List[Union[str, PathParameterSchema]]]: """ Using the Starlette CONVERTORS and the application registered convertors, transforms the path into a PathParameterSchema used for the OpenAPI definition. @@ -183,9 +181,7 @@ def create_response_handler( status_code: Optional[int] = None, media_type: Optional[str] = MediaType.TEXT, ) -> "AsyncAnyCallable": - async def response_content( - data: Response, **kwargs: Dict[str, Any] - ) -> StarletteResponse: + async def response_content(data: Response, **kwargs: Dict[str, Any]) -> StarletteResponse: _cookies = self.get_cookies(data.cookies, cookies) _headers = { **self.get_headers(headers), @@ -213,9 +209,7 @@ def create_json_response_handler( ) -> "AsyncAnyCallable": """Creates a handler function for Esmerald JSON responses""" - async def response_content( - data: Response, **kwargs: Dict[str, Any] - ) -> StarletteResponse: + async def response_content(data: Response, **kwargs: Dict[str, Any]) -> StarletteResponse: if status_code: data.status_code = status_code return data @@ -260,9 +254,7 @@ def create_handler( response_class: "ResponseType", status_code: int, ) -> "AsyncAnyCallable": - async def response_content( - data: Any, **kwargs: Dict[str, Any] - ) -> StarletteResponse: + async def response_content(data: Any, **kwargs: Dict[str, Any]) -> StarletteResponse: data = await self.get_response_data(data=data) _cookies = self.get_cookies(cookies, []) @@ -372,18 +364,14 @@ def get_response_handler(self) -> Callable[[Any], Awaitable[StarletteResponse]]: """ if self._response_handler is Void: media_type = ( - self.media_type.value - if isinstance(self.media_type, Enum) - else self.media_type + self.media_type.value if isinstance(self.media_type, Enum) else self.media_type ) response_class = self.get_response_class() headers = self.get_response_headers() cookies = self.get_response_cookies() - if is_class_and_subclass( - self.signature.return_annotation, ResponseContainer - ): + if is_class_and_subclass(self.signature.return_annotation, ResponseContainer): handler = self.create_response_container_handler( cookies=cookies, media_type=self.media_type, @@ -394,9 +382,7 @@ def get_response_handler(self) -> Callable[[Any], Awaitable[StarletteResponse]]: self.signature.return_annotation, (JSONResponse, ORJSONResponse, UJSONResponse), ): - handler = self.create_json_response_handler( - status_code=self.status_code - ) + handler = self.create_json_response_handler(status_code=self.status_code) elif is_class_and_subclass(self.signature.return_annotation, Response): handler = self.create_response_handler( cookies=cookies, @@ -404,9 +390,7 @@ def get_response_handler(self) -> Callable[[Any], Awaitable[StarletteResponse]]: media_type=self.media_type, headers=headers, ) - elif is_class_and_subclass( - self.signature.return_annotation, StarletteResponse - ): + elif is_class_and_subclass(self.signature.return_annotation, StarletteResponse): handler = self.create_starlette_response_handler( cookies=cookies, media_type=self.media_type, @@ -454,9 +438,7 @@ def normalised_path_params(self) -> List[Dict[str, str]]: Gets the path parameters in a PathParameterSchema format. """ path_components = self.parse_path(self.path) - parameters = [ - component for component in path_components if isinstance(component, dict) - ] + parameters = [component for component in path_components if isinstance(component, dict)] return parameters @property @@ -475,9 +457,7 @@ def parent_layers(self) -> List[Union[T, "ParentType"]]: def dependency_names(self) -> Set[str]: """A unique set of all dependency names provided in the handlers parent layers.""" - layered_dependencies = ( - layer.dependencies or {} for layer in self.parent_layers - ) + layered_dependencies = (layer.dependencies or {} for layer in self.parent_layers) return {name for layer in layered_dependencies for name in layer.keys()} def resolve_permissions(self) -> List["Permission"]: @@ -515,9 +495,7 @@ def get_dependencies(self) -> Dict[str, Inject]: return cast("Dict[str, Inject]", self._dependencies) @staticmethod - def has_dependency_unique( - dependencies: Dict[str, Inject], key: str, injector: Inject - ) -> None: + def has_dependency_unique(dependencies: Dict[str, Inject], key: str, injector: Inject) -> None: """ Validates that a given inject has not been already defined under a different key in any of the layers. @@ -551,9 +529,7 @@ def get_cookies( filtered_cookies.append(cookie) normalized_cookies: List[Dict[str, Any]] = [] for cookie in filtered_cookies: - normalized_cookies.append( - cookie.dict(exclude_none=True, exclude={"description"}) - ) + normalized_cookies.append(cookie.dict(exclude_none=True, exclude={"description"})) return normalized_cookies def get_headers(self, headers: "ResponseHeaders") -> Dict[str, Any]: diff --git a/esmerald/routing/gateways.py b/esmerald/routing/gateways.py index 40d36c38..873da503 100644 --- a/esmerald/routing/gateways.py +++ b/esmerald/routing/gateways.py @@ -43,12 +43,17 @@ def __init__( permissions: Optional["Permission"] = None, exception_handlers: Optional["ExceptionHandlers"] = None, deprecated: Optional[bool] = None, + is_from_router: bool = False, ) -> None: if not path: path = "/" if is_class_and_subclass(handler, APIView): handler = handler(parent=self) - self.path = clean_path(path + handler.path) + + if not is_from_router: + self.path = clean_path(path + handler.path) + else: + self.path = clean_path(path) self.methods = getattr(handler, "methods", None) if not name: diff --git a/esmerald/signature.py b/esmerald/signature.py index 485919ed..9cd57b27 100644 --- a/esmerald/signature.py +++ b/esmerald/signature.py @@ -93,15 +93,9 @@ def is_server_error(cls, error: "ErrorDict") -> bool: return error["loc"][-1] in cls.dependency_names @staticmethod - def get_connection_method_and_url( - connection: Union[Request, WebSocket] - ) -> Tuple[str, "URL"]: + def get_connection_method_and_url(connection: Union[Request, WebSocket]) -> Tuple[str, "URL"]: """Extract method and URL from Request or WebSocket.""" - method = ( - ScopeType.WEBSOCKET - if isinstance(connection, WebSocket) - else connection.method - ) + method = ScopeType.WEBSOCKET if isinstance(connection, WebSocket) else connection.method return method, connection.url @@ -200,9 +194,8 @@ def signature_parameters(self) -> Generator[SignatureParameter, None, None]: yield SignatureParameter(self.fn_name, name, parameter) def should_skip_parameter_validation(self, parameter: SignatureParameter) -> bool: - return ( - parameter.name in SKIP_VALIDATION_NAMES - or should_skip_dependency_validation(parameter.default) + return parameter.name in SKIP_VALIDATION_NAMES or should_skip_dependency_validation( + parameter.default ) def create_signature_model(self) -> Type[SignatureModel]: diff --git a/esmerald/utils/helpers.py b/esmerald/utils/helpers.py index 4f735439..0a864737 100644 --- a/esmerald/utils/helpers.py +++ b/esmerald/utils/helpers.py @@ -23,9 +23,7 @@ def is_async_callable(value: Callable[P, T]) -> TypeGuard[Callable[P, Awaitable[ while isinstance(value, functools.partial): value = value.func # type: ignore[unreachable] - return asyncio.iscoroutinefunction(value) or asyncio.iscoroutinefunction( - value.__call__ - ) # ty + return asyncio.iscoroutinefunction(value) or asyncio.iscoroutinefunction(value.__call__) # ty def is_class_and_subclass(value: typing.Any, type_: typing.Any) -> bool: diff --git a/esmerald/utils/sync.py b/esmerald/utils/sync.py index a7bb02a8..26673d4e 100644 --- a/esmerald/utils/sync.py +++ b/esmerald/utils/sync.py @@ -23,9 +23,7 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: return await self.fn(*args, **kwargs) -def as_async_callable_list( - value: Union[Callable, List[Callable]] -) -> List[AsyncCallable]: +def as_async_callable_list(value: Union[Callable, List[Callable]]) -> List[AsyncCallable]: if not isinstance(value, list): return [AsyncCallable(value)] return [AsyncCallable(v) for v in value]