From 3367fa2cd02d11f3390a032cc766cdea4a1134cb Mon Sep 17 00:00:00 2001 From: Alexander Date: Sun, 29 Dec 2024 15:27:38 +0100 Subject: [PATCH] Fix validation of parameters requiring a body (#467) - fix `data` and `payload` special keyword arguments so they are allowed when another method than one without a body (HEAD, GET) is available, - validate that if body-less methods are handled, form-like params are optional - validate that if only body-less methods are handled, form-like params are forbidden - fix missing call of the new validation method - fix pytest warnings --- docs/en/docs/release-notes.md | 2 ++ esmerald/routing/router.py | 43 +++++++++++++++++++----- tests/forms/test_forms_data.py | 50 +++++++++++++++------------- tests/forms/test_forms_payload.py | 50 +++++++++++++++------------- tests/forms/test_forms_route.py | 50 ++++++++++++++++++++++++++++ tests/routing/test_routing.py | 55 ++++++++++++++++++++++++++++++- 6 files changed, 197 insertions(+), 53 deletions(-) create mode 100644 tests/forms/test_forms_route.py diff --git a/docs/en/docs/release-notes.md b/docs/en/docs/release-notes.md index 5911e371..c7f388dd 100644 --- a/docs/en/docs/release-notes.md +++ b/docs/en/docs/release-notes.md @@ -20,9 +20,11 @@ hide: ### Fixed +- `data` and `payload` special kwargs are now allowed when a not-bodyless method is available for the handler. They default to None. - `bytes` won't be encoded as json when returned from a handler. This would unexpectly lead to a base64 encoding. - SessionConfig has a unneccessarily heavily restricted secret_key parameter. - Gracefully handle situations where cookies are None in `get_cookies`. +- Fix validation of parameters requiring a body. ## 3.6.1 diff --git a/esmerald/routing/router.py b/esmerald/routing/router.py index 1158632f..9265576d 100644 --- a/esmerald/routing/router.py +++ b/esmerald/routing/router.py @@ -55,6 +55,7 @@ from esmerald.interceptors.types import Interceptor from esmerald.openapi.datastructures import OpenAPIResponse from esmerald.openapi.utils import is_status_code_allowed +from esmerald.params import Form from esmerald.requests import Request from esmerald.responses import Response from esmerald.routing._internal import OpenAPIFieldInfoMixin @@ -66,7 +67,7 @@ from esmerald.transformers.utils import get_signature from esmerald.typing import Void, VoidType from esmerald.utils.constants import DATA, PAYLOAD, REDIRECT_STATUS_CODES, REQUEST, SOCKET -from esmerald.utils.helpers import is_async_callable, is_class_and_subclass +from esmerald.utils.helpers import is_async_callable, is_class_and_subclass, is_optional_union from esmerald.websockets import WebSocket, WebSocketClose if TYPE_CHECKING: # pragma: no cover @@ -1917,6 +1918,9 @@ def wrapper(func: Callable) -> Callable: return wrapper +_body_less_methods = frozenset({"GET", "HEAD", "OPTIONS", "TRACE"}) + + class HTTPHandler(Dispatcher, OpenAPIFieldInfoMixin, LilyaPath): __slots__ = ( "path", @@ -2218,17 +2222,39 @@ def validate_annotations(self) -> None: # pragma: no cover ]: self.media_type = MediaType.TEXT + def validate_bodyless_kwargs(self) -> None: + if _body_less_methods.isdisjoint(self.methods): + return + body_less_only = _body_less_methods.issuperset(self.methods) + for special in [DATA, PAYLOAD]: + if special in self.handler_signature.parameters: + if body_less_only: + raise ImproperlyConfigured( + f"'{special}' argument unsupported when only body-less methods like 'GET' and 'HEAD' are handled" + ) + elif not is_optional_union(self.handler_signature.parameters[special].annotation): + raise ImproperlyConfigured( + f"'{special}' argument must be optional when body-less methods like 'GET' and 'HEAD' are handled" + ) + for parameter_name, parameter in self.handler_signature.parameters.items(): + # don't check twice + if parameter_name == DATA or parameter_name == PAYLOAD: + continue + if isinstance(parameter.default, Form): + if body_less_only: + raise ImproperlyConfigured( + f"'{special}' uses Form() which is unsupported when only body-less methods " + "like 'GET' and 'HEAD' are handled" + ) + elif not is_optional_union(parameter.annotation): + raise ImproperlyConfigured( + f"'{special}' argument must be optional when body-less methods like 'GET' and 'HEAD' are handled" + ) + def validate_reserved_kwargs(self) -> None: # pragma: no cover """ Validates if special words are in the signature. """ - if DATA in self.handler_signature.parameters and "GET" in self.methods: - raise ImproperlyConfigured("'data' argument is unsupported for 'GET' request handlers") - - if PAYLOAD in self.handler_signature.parameters and "GET" in self.methods: - raise ImproperlyConfigured( - "'payload' argument is unsupported for 'GET' request handlers" - ) if SOCKET in self.handler_signature.parameters: raise ImproperlyConfigured("The 'socket' argument is not supported with http handlers") @@ -2236,6 +2262,7 @@ def validate_reserved_kwargs(self) -> None: # pragma: no cover def validate_handler(self) -> None: self.check_handler_function() self.validate_annotations() + self.validate_bodyless_kwargs() self.validate_reserved_kwargs() async def to_response(self, app: "Esmerald", data: Any) -> LilyaResponse: diff --git a/tests/forms/test_forms_data.py b/tests/forms/test_forms_data.py index a5918265..beec91c1 100644 --- a/tests/forms/test_forms_data.py +++ b/tests/forms/test_forms_data.py @@ -1,10 +1,12 @@ from dataclasses import dataclass -from typing import Any, Dict +from typing import Any, Dict, Optional +import pytest from pydantic import BaseModel from pydantic.dataclasses import dataclass as pydantic_dataclass -from esmerald import Form, Gateway, post +from esmerald import Form, Gateway, post, route +from esmerald.exceptions import ImproperlyConfigured from esmerald.testclient import create_client @@ -25,27 +27,11 @@ class UserModel(BaseModel): name: str -@post("/form") -async def test_form(data: Any = Form()) -> Dict[str, str]: - return {"name": data["name"]} - - -@post("/complex-form-pydantic") -async def test_complex_form_pydantic_dataclass(data: User = Form()) -> User: - return data - - -@post("/complex-form-dataclass") -async def test_complex_form_dataclass(data: UserOut = Form()) -> UserOut: - return data - - -@post("/complex-form-basemodel") -async def test_complex_form_basemodel(data: UserModel = Form()) -> UserModel: - return data - - def test_send_form(test_client_factory): + @post("/form") + async def test_form(data: Any = Form()) -> Dict[str, str]: + return {"name": data["name"]} + data = {"name": "Test"} with create_client(routes=[Gateway(handler=test_form)]) as client: @@ -56,6 +42,10 @@ def test_send_form(test_client_factory): def test_send_complex_form_pydantic_dataclass(test_client_factory): + @post("/complex-form-pydantic") + async def test_complex_form_pydantic_dataclass(data: User = Form()) -> User: + return data + data = {"id": 1, "name": "Test"} with create_client( routes=[Gateway(handler=test_complex_form_pydantic_dataclass)], @@ -67,6 +57,10 @@ def test_send_complex_form_pydantic_dataclass(test_client_factory): def test_send_complex_form_normal_dataclass(test_client_factory): + @post("/complex-form-dataclass") + async def test_complex_form_dataclass(data: UserOut = Form()) -> UserOut: + return data + data = {"id": 1, "name": "Test"} with create_client( routes=[Gateway(handler=test_complex_form_dataclass)], @@ -78,6 +72,10 @@ def test_send_complex_form_normal_dataclass(test_client_factory): def test_send_complex_form_base_model(test_client_factory): + @post("/complex-form-basemodel") + async def test_complex_form_basemodel(data: UserModel = Form()) -> UserModel: + return data + data = {"id": 1, "name": "Test"} with create_client( routes=[Gateway(handler=test_complex_form_basemodel)], @@ -86,3 +84,11 @@ def test_send_complex_form_base_model(test_client_factory): response = client.post("/complex-form-basemodel", data=data) assert response.status_code == 201, response.text assert response.json() == {"id": 1, "name": "Test"} + + +def test_get_and_head_data(): + with pytest.raises(ImproperlyConfigured): + + @route(methods=["GET", "HEAD"]) + async def start(data: Optional[UserModel]) -> bytes: + return b"hello world" diff --git a/tests/forms/test_forms_payload.py b/tests/forms/test_forms_payload.py index 5e35aff5..79c450b1 100644 --- a/tests/forms/test_forms_payload.py +++ b/tests/forms/test_forms_payload.py @@ -1,10 +1,12 @@ from dataclasses import dataclass -from typing import Any, Dict +from typing import Any, Dict, Optional +import pytest from pydantic import BaseModel from pydantic.dataclasses import dataclass as pydantic_dataclass -from esmerald import Form, Gateway, post +from esmerald import Form, Gateway, post, route +from esmerald.exceptions import ImproperlyConfigured from esmerald.testclient import create_client @@ -25,27 +27,11 @@ class UserModel(BaseModel): name: str -@post("/form") -async def test_form(payload: Any = Form()) -> Dict[str, str]: - return {"name": payload["name"]} - - -@post("/complex-form-pydantic") -async def test_complex_form_pydantic_dataclass(payload: User = Form()) -> User: - return payload - - -@post("/complex-form-dataclass") -async def test_complex_form_dataclass(payload: UserOut = Form()) -> UserOut: - return payload - - -@post("/complex-form-basemodel") -async def test_complex_form_basemodel(payload: UserModel = Form()) -> UserModel: - return payload - - def test_send_form(test_client_factory): + @post("/form") + async def test_form(payload: Any = Form()) -> Dict[str, str]: + return {"name": payload["name"]} + payload = {"name": "Test"} with create_client(routes=[Gateway(handler=test_form)]) as client: @@ -56,6 +42,10 @@ def test_send_form(test_client_factory): def test_send_complex_form_pydantic_dataclass(test_client_factory): + @post("/complex-form-pydantic") + async def test_complex_form_pydantic_dataclass(payload: User = Form()) -> User: + return payload + payload = {"id": 1, "name": "Test"} with create_client( routes=[Gateway(handler=test_complex_form_pydantic_dataclass)], @@ -67,6 +57,10 @@ def test_send_complex_form_pydantic_dataclass(test_client_factory): def test_send_complex_form_normal_dataclass(test_client_factory): + @post("/complex-form-dataclass") + async def test_complex_form_dataclass(payload: UserOut = Form()) -> UserOut: + return payload + payload = {"id": 1, "name": "Test"} with create_client( routes=[Gateway(handler=test_complex_form_dataclass)], @@ -78,6 +72,10 @@ def test_send_complex_form_normal_dataclass(test_client_factory): def test_send_complex_form_base_model(test_client_factory): + @post("/complex-form-basemodel") + async def test_complex_form_basemodel(payload: UserModel = Form()) -> UserModel: + return payload + payload = {"id": 1, "name": "Test"} with create_client( routes=[Gateway(handler=test_complex_form_basemodel)], @@ -86,3 +84,11 @@ def test_send_complex_form_base_model(test_client_factory): response = client.post("/complex-form-basemodel", data=payload) assert response.status_code == 201, response.text assert response.json() == {"id": 1, "name": "Test"} + + +def test_get_and_head_payload(): + with pytest.raises(ImproperlyConfigured): + + @route(methods=["GET", "HEAD"]) + async def start(payload: Optional[UserModel]) -> bytes: + return b"hello world" diff --git a/tests/forms/test_forms_route.py b/tests/forms/test_forms_route.py new file mode 100644 index 00000000..cd8287ce --- /dev/null +++ b/tests/forms/test_forms_route.py @@ -0,0 +1,50 @@ +from typing import Optional, Union + +import pytest +from pydantic import BaseModel + +from esmerald import Esmerald, Form, Request +from esmerald.exceptions import ImproperlyConfigured +from esmerald.routing.gateways import Gateway +from esmerald.routing.handlers import route +from esmerald.testclient import EsmeraldTestClient + + +class Model(BaseModel): + id: str + + +def test_get_and_post(): + @route(methods=["GET", "POST"]) + async def start(request: Request, form: Union[Model, None] = Form()) -> bytes: + return b"hello world" + + app = Esmerald( + debug=True, + routes=[Gateway("/", handler=start)], + ) + client = EsmeraldTestClient(app) + response = client.get("/") + assert response.status_code == 200 + + +def test_get_and_post_optional(): + @route(methods=["GET", "POST"]) + async def start(request: Request, form: Optional[Model] = Form()) -> bytes: + return b"hello world" + + app = Esmerald( + debug=True, + routes=[Gateway("/", handler=start)], + ) + client = EsmeraldTestClient(app) + response = client.get("/") + assert response.status_code == 200 + + +def test_get_and_head_form(): + with pytest.raises(ImproperlyConfigured): + + @route(methods=["GET", "HEAD"]) + async def start(form: Optional[Model] = Form()) -> bytes: + return b"hello world" diff --git a/tests/routing/test_routing.py b/tests/routing/test_routing.py index f639afc2..a23cec1d 100644 --- a/tests/routing/test_routing.py +++ b/tests/routing/test_routing.py @@ -1,6 +1,7 @@ import contextlib import uuid from dataclasses import dataclass +from typing import Optional import pytest from lilya.responses import JSONResponse, PlainText, Response as LilyaResponse @@ -10,13 +11,14 @@ from esmerald.applications import Esmerald from esmerald.enums import MediaType +from esmerald.exceptions import ImproperlyConfigured from esmerald.permissions import AllowAny, DenyAll from esmerald.requests import Request from esmerald.responses import Response from esmerald.responses.encoders import UJSONResponse from esmerald.routing.apis.views import APIView from esmerald.routing.gateways import Gateway, WebSocketGateway -from esmerald.routing.handlers import get, post, put, websocket +from esmerald.routing.handlers import get, post, put, route, websocket from esmerald.routing.router import Include, Router from esmerald.testclient import create_client @@ -1020,3 +1022,54 @@ def another_user(data: UserOut) -> UserOut: assert response.status_code == 200 assert response.json() == {"name": "test", "email": "esmerald@esmerald.dev"} + + +def test_get_and_post_data(test_app_client_factory): + @route(path="/another-user", status_code=200, methods=["GET", "POST"]) + def another_user(data: Optional[UserOut]) -> Optional[UserOut]: + return data + + data = {"name": "test", "email": "esmerald@esmerald.dev"} + app = Esmerald(routes=[Gateway(handler=another_user)]) + client = test_app_client_factory(app) + response = client.get("/another-user") + assert response.status_code == 200 + assert response.text == "" + response = client.post("/another-user", json=data) + + assert response.status_code == 200 + assert response.json() == {"name": "test", "email": "esmerald@esmerald.dev"} + + +def test_get_and_post_payload(test_app_client_factory): + @route(path="/another-user", status_code=200, methods=["GET", "POST"]) + def another_user(payload: Optional[UserOut]) -> Optional[UserOut]: + return payload + + data = {"name": "test", "email": "esmerald@esmerald.dev"} + app = Esmerald(routes=[Gateway(handler=another_user)]) + + client = test_app_client_factory(app) + response = client.get("/another-user") + assert response.status_code == 200 + assert response.text == "" + response = client.post("/another-user", json=data) + + assert response.status_code == 200 + assert response.json() == {"name": "test", "email": "esmerald@esmerald.dev"} + + +def test_get_and_head_data(): + with pytest.raises(ImproperlyConfigured): + + @route(path="/another-user", status_code=200, methods=["GET", "HEAD"]) + def another_user(data: UserOut) -> UserOut: + return data + + +def test_get_and_head_payload(): + with pytest.raises(ImproperlyConfigured): + + @route(path="/another-user", status_code=200, methods=["GET", "HEAD"]) + def another_user(payload: UserOut) -> UserOut: + return payload