From 164e3aa36bd2cc8273f051c146ef90ac99237e7b Mon Sep 17 00:00:00 2001 From: tarsil Date: Mon, 9 Oct 2023 10:03:20 +0100 Subject: [PATCH 1/4] Add extra_allowed instead --- esmerald/routing/apis/views.py | 12 +++++++++++- esmerald/utils/helpers.py | 17 +++-------------- tests/routing/test_api_views.py | 14 +++++++++++++- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/esmerald/routing/apis/views.py b/esmerald/routing/apis/views.py index 878232a0..75d09e9b 100644 --- a/esmerald/routing/apis/views.py +++ b/esmerald/routing/apis/views.py @@ -13,7 +13,10 @@ class SimpleAPIMeta(type): def __new__(cls, name: str, bases: Tuple[Type, ...], attrs: Any) -> Any: """ Making sure the `http_allowed_methods` are extended if inheritance happens - in the subclass + in the subclass. + + The `http_allowed_methods` is the default for each type of generic but to allow + extra allowed methods, the `extra_allowed` must be added. """ view = super().__new__ @@ -36,6 +39,13 @@ def __new__(cls, name: str, bases: Tuple[Type, ...], attrs: Any) -> Any: ): simple_view.http_allowed_methods.extend(base.http_allowed_methods) + if hasattr(simple_view, "extra_allowed"): + assert isinstance( + simple_view.extra_allowed, list + ), "`extra_allowed` must be a list of strings allowed." + + simple_view.http_allowed_methods.extend(simple_view.extra_allowed) + allowed_methods: Set[str] = {method.lower() for method in simple_view.http_allowed_methods} simple_view.http_allowed_methods = list(allowed_methods) message = ", ".join(allowed_methods) diff --git a/esmerald/utils/helpers.py b/esmerald/utils/helpers.py index c940d5d4..2e7be6cf 100644 --- a/esmerald/utils/helpers.py +++ b/esmerald/utils/helpers.py @@ -1,12 +1,11 @@ -import asyncio -import functools import sys import typing from inspect import isclass -from typing import Any, Awaitable, Callable, TypeVar, Union +from typing import Any, Union import slugify -from typing_extensions import ParamSpec, TypeGuard, get_args, get_origin +from starlette._utils import is_async_callable as is_async_callable +from typing_extensions import get_args, get_origin if sys.version_info >= (3, 10): from types import UnionType @@ -15,16 +14,6 @@ else: # pragma: no cover UNION_TYPES = {Union} -P = ParamSpec("P") -T = TypeVar("T") - - -def is_async_callable(value: Callable[P, T]) -> TypeGuard[Callable[P, Awaitable[T]]]: - while isinstance(value, functools.partial): - value = value.func # type: ignore[unreachable] - - return asyncio.iscoroutinefunction(value) or asyncio.iscoroutinefunction(value.__call__) # type: ignore - def is_class_and_subclass(value: typing.Any, _type: typing.Any) -> bool: original = get_origin(value) diff --git a/tests/routing/test_api_views.py b/tests/routing/test_api_views.py index adfbcaa9..7d83ba40 100644 --- a/tests/routing/test_api_views.py +++ b/tests/routing/test_api_views.py @@ -141,7 +141,7 @@ async def get(self) -> str: @pytest.mark.parametrize("value,method", [("create_user", "post"), ("read_item", "get")]) def test_all_api_view_custom(test_client_factory, value, method): class GenericAPIView(CreateAPIView, ReadAPIView, DeleteAPIView): - http_allowed_methods: List[str] = ["create_user", "read_item"] + extra_allowed: List[str] = ["create_user", "read_item"] @post(status_code=200) async def create_user(self) -> str: @@ -155,3 +155,15 @@ async def read_item(self) -> str: response = getattr(client, method)("/") assert response.status_code == 200 assert response.json() == f"home {value}" + + +@pytest.mark.parametrize( + "value", + [("create_user",), {"create_user"}, {"name": "create_user"}], + ids=["tuple", "set", "dict"], +) +def test_all_api_view_custom_error(test_client_factory, value): + with pytest.raises(AssertionError): + + class GenericAPIView(CreateAPIView, ReadAPIView, DeleteAPIView): + extra_allowed: List[str] = ("create_user", "read_item") From 11cd6b171e2cd7a460a861aa28f507ee396a3335 Mon Sep 17 00:00:00 2001 From: tarsil Date: Mon, 9 Oct 2023 11:10:16 +0100 Subject: [PATCH 2/4] Split simpleapiview metaclass --- docs/routing/apiview.md | 32 +++++++++++-- docs_src/routing/generics/allowed.py | 2 +- docs_src/routing/generics/important.py | 15 ++++++ esmerald/routing/apis/_metaclasses.py | 58 +++++++++++++++++++++++ esmerald/routing/apis/views.py | 65 ++++---------------------- tests/routing/test_api_views.py | 25 ++++++++++ 6 files changed, 135 insertions(+), 62 deletions(-) create mode 100644 docs_src/routing/generics/important.py create mode 100644 esmerald/routing/apis/_metaclasses.py diff --git a/docs/routing/apiview.md b/docs/routing/apiview.md index 41f50d20..7d22a314 100644 --- a/docs/routing/apiview.md +++ b/docs/routing/apiview.md @@ -97,8 +97,8 @@ allows the creation of apis where the function name can be whatever you desire l So what does that mean? Means **you can only perform operations where the function name coincides with the http verb**. For example, `get`, `put`, `post` etc... -If you attempt to create a functionm where the name differs from a http verb, -an `ImproperlyConfigured` exception is raised. +If you attempt to create a function where the name differs from a http verb, +an `ImproperlyConfigured` exception is raised **unless the `extra_allowed` is declared**. The available http verbs are: @@ -113,6 +113,27 @@ The available http verbs are: Basically the same availability as the [handlers](./handlers.md). +### Important + +The generics **enforce** the name matching of the functions with the handlers. That means, if +you use a `ReadAPIView` that only allows the `get` and you use the wrong [handlers](./handlers.md) +on the top of it, for example a [post](./handlers.md#post), an `ImproperlyConfigured` exception +will be raised. + +Let us see what this means. + +```python hl_lines="13-14" +{!> ../docs_src/routing/generics/important.py !} +``` + +As you can see, the handler `post()` does not match the function name `get`. **It should always match**. + +An easy way of knowing this is simple, when it comes to the available http verbs, the function name +**should always match the handler**. + +Are there any exception? Yes but not for these specific cases, the exceptions are called +[extra_allowed](#extra_allowed) but more details about this later on. + ### SimpleAPIView This is the base of all generics, subclassing from this class will allow you to perform all the @@ -188,11 +209,12 @@ What if you want to combine them all? Of course you also can. **Combining them all is the same as using the [SimpleAPIView](#simpleapiview)**. -### http_allowed_methods +### extra_allowed All the generics subclass the [SimpleAPIView](#simpleapiview) as mentioned before and that superclass uses the `http_allowed_methods` to verify which methods are allowed or not to be passed inside -the API object. +the API object but also check if there is any `extra_allowed` list with any extra functions you +would like the view to deliver. This means that if you want to add a `read_item()` function to any of the generics you also do it easily. @@ -202,7 +224,7 @@ generics you also do it easily. ``` As you can see, to make it happen you would need to declare the function name inside the -`http_allowed_methods` to make sure that an `ImproperlyConfigured` is not raised. +`extra_allowed` to make sure that an `ImproperlyConfigured` is not raised. ## What to choose diff --git a/docs_src/routing/generics/allowed.py b/docs_src/routing/generics/allowed.py index f7d42542..95111825 100644 --- a/docs_src/routing/generics/allowed.py +++ b/docs_src/routing/generics/allowed.py @@ -10,7 +10,7 @@ class UserAPI(CreateAPIView): to be used by default. """ - http_allowed_methods: List[str] = ["read_item"] + extra_allowed: List[str] = ["read_item"] @post() def post(self) -> str: diff --git a/docs_src/routing/generics/important.py b/docs_src/routing/generics/important.py new file mode 100644 index 00000000..70f9322a --- /dev/null +++ b/docs_src/routing/generics/important.py @@ -0,0 +1,15 @@ +from typing import List + +from esmerald import get, patch, post, put +from esmerald.routing.apis.generics import CreateAPIView + + +class UserAPI(CreateAPIView): + """ + ImproperlyConfigured will be raised as the handler `post()` + name does not match the function name `post`. + """ + + @post() + def get(self) -> str: + ... diff --git a/esmerald/routing/apis/_metaclasses.py b/esmerald/routing/apis/_metaclasses.py new file mode 100644 index 00000000..1eeca81f --- /dev/null +++ b/esmerald/routing/apis/_metaclasses.py @@ -0,0 +1,58 @@ +from typing import TYPE_CHECKING, Any, List, Set, Tuple, Type, cast + +if TYPE_CHECKING: + from esmerald import SimpleAPIView + + +class SimpleAPIMeta(type): + """ + Metaclass responsible for making sure + only the CRUD objects are allowed. + """ + + def __new__(cls, name: str, bases: Tuple[Type, ...], attrs: Any) -> Any: + """ + Making sure the `http_allowed_methods` are extended if inheritance happens + in the subclass. + + The `http_allowed_methods` is the default for each type of generic but to allow + extra allowed methods, the `extra_allowed` must be added. + """ + view = super().__new__ + + parents = [parent for parent in bases if isinstance(parent, SimpleAPIMeta)] + if not parents: + return view(cls, name, bases, attrs) + + simple_view = cast("SimpleAPIView", view(cls, name, bases, attrs)) + filtered_handlers: List[str] = [ + attr + for attr in dir(simple_view) + if not attr.startswith("__") and not attr.endswith("__") + ] + + for base in bases: + if ( + hasattr(base, "http_allowed_methods") + and hasattr(base, "__is_generic__") + and getattr(base, "__is_generic__", False) not in [False, None] + ): + simple_view.http_allowed_methods.extend(base.http_allowed_methods) + + if hasattr(simple_view, "extra_allowed"): + assert isinstance( + simple_view.extra_allowed, list + ), "`extra_allowed` must be a list of strings allowed." + + simple_view.http_allowed_methods.extend(simple_view.extra_allowed) + + allowed_methods: Set[str] = {method.lower() for method in simple_view.http_allowed_methods} + simple_view.http_allowed_methods = list(allowed_methods) + message = ", ".join(allowed_methods) + + for handler_name in filtered_handlers: + for base in simple_view.__bases__: + attribute = getattr(simple_view, handler_name) + simple_view.is_method_allowed(handler_name, base, attribute, message) + + return simple_view diff --git a/esmerald/routing/apis/views.py b/esmerald/routing/apis/views.py index 75d09e9b..6cb5689b 100644 --- a/esmerald/routing/apis/views.py +++ b/esmerald/routing/apis/views.py @@ -1,63 +1,10 @@ -from typing import Any, Callable, List, Set, Tuple, Type, Union, cast +from typing import Any, Callable, List, Union from esmerald.exceptions import ImproperlyConfigured +from esmerald.routing.apis._metaclasses import SimpleAPIMeta from esmerald.routing.apis.base import View -class SimpleAPIMeta(type): - """ - Metaclass responsible for making sure - only the CRUD objects are allowed. - """ - - def __new__(cls, name: str, bases: Tuple[Type, ...], attrs: Any) -> Any: - """ - Making sure the `http_allowed_methods` are extended if inheritance happens - in the subclass. - - The `http_allowed_methods` is the default for each type of generic but to allow - extra allowed methods, the `extra_allowed` must be added. - """ - view = super().__new__ - - parents = [parent for parent in bases if isinstance(parent, SimpleAPIMeta)] - if not parents: - return view(cls, name, bases, attrs) - - simple_view = cast("SimpleAPIView", view(cls, name, bases, attrs)) - filtered_handlers: List[str] = [ - attr - for attr in dir(simple_view) - if not attr.startswith("__") and not attr.endswith("__") - ] - - for base in bases: - if ( - hasattr(base, "http_allowed_methods") - and hasattr(base, "__is_generic__") - and getattr(base, "__is_generic__", False) not in [False, None] - ): - simple_view.http_allowed_methods.extend(base.http_allowed_methods) - - if hasattr(simple_view, "extra_allowed"): - assert isinstance( - simple_view.extra_allowed, list - ), "`extra_allowed` must be a list of strings allowed." - - simple_view.http_allowed_methods.extend(simple_view.extra_allowed) - - allowed_methods: Set[str] = {method.lower() for method in simple_view.http_allowed_methods} - simple_view.http_allowed_methods = list(allowed_methods) - message = ", ".join(allowed_methods) - - for handler_name in filtered_handlers: - for base in simple_view.__bases__: - attribute = getattr(simple_view, handler_name) - simple_view.is_method_allowed(handler_name, base, attribute, message) - - return simple_view - - class SimpleAPIView(View, metaclass=SimpleAPIMeta): """The Esmerald SimpleAPIView class. @@ -89,10 +36,16 @@ def is_method_allowed( method, (HTTPHandler, WebSocketHandler, WebhookHandler), ): - if name.lower() not in cls.http_allowed_methods: # type: ignore[unreachable] + if hasattr(cls, "extra_allowed") and name.lower() in cls.extra_allowed: # type: ignore[unreachable] + return True + if name.lower() not in cls.http_allowed_methods: raise ImproperlyConfigured( f"{cls.__name__} only allows functions with the name(s) `{error_message}` to be implemented, got `{name.lower()}` instead." ) + elif name.lower() != method.__class__.__name__.lower(): + raise ImproperlyConfigured( + f"The function '{name.lower()}' must implement the '{name.lower()}()' handler, got '{method.__class__.__name__.lower()}()' instead." + ) return True diff --git a/tests/routing/test_api_views.py b/tests/routing/test_api_views.py index 7d83ba40..9ac08222 100644 --- a/tests/routing/test_api_views.py +++ b/tests/routing/test_api_views.py @@ -167,3 +167,28 @@ def test_all_api_view_custom_error(test_client_factory, value): class GenericAPIView(CreateAPIView, ReadAPIView, DeleteAPIView): extra_allowed: List[str] = ("create_user", "read_item") + + +@pytest.mark.parametrize( + "value", [value for value in SimpleAPIView.http_allowed_methods if value != "get"] +) +def test_default_parameters_raise_error_on_wrong_handler(test_client_factory, value): + handler = getattr(esmerald, value) + + with pytest.raises(ImproperlyConfigured) as raised: + + class GenericAPIView(CreateAPIView, ReadAPIView, DeleteAPIView): + extra_allowed: List[str] = ["create_user"] + + @handler("/") + def get(self) -> None: + ... + + @handler("/") + def create_user() -> None: + ... + + assert ( + raised.value.detail + == f"The function 'get' must implement the 'get()' handler, got '{value}()' instead." + ) From 26af6f3a1c110aa9348ad50579510be64f61e1d6 Mon Sep 17 00:00:00 2001 From: tarsil Date: Mon, 9 Oct 2023 11:35:11 +0100 Subject: [PATCH 3/4] Fix metaclass for SimpleAPIView --- esmerald/routing/apis/_metaclasses.py | 17 ++++++++++------- esmerald/routing/apis/views.py | 3 +++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/esmerald/routing/apis/_metaclasses.py b/esmerald/routing/apis/_metaclasses.py index 1eeca81f..2aacb813 100644 --- a/esmerald/routing/apis/_metaclasses.py +++ b/esmerald/routing/apis/_metaclasses.py @@ -24,6 +24,7 @@ def __new__(cls, name: str, bases: Tuple[Type, ...], attrs: Any) -> Any: if not parents: return view(cls, name, bases, attrs) + http_allowed_methods: List[str] = [] simple_view = cast("SimpleAPIView", view(cls, name, bases, attrs)) filtered_handlers: List[str] = [ attr @@ -37,22 +38,24 @@ def __new__(cls, name: str, bases: Tuple[Type, ...], attrs: Any) -> Any: and hasattr(base, "__is_generic__") and getattr(base, "__is_generic__", False) not in [False, None] ): - simple_view.http_allowed_methods.extend(base.http_allowed_methods) + http_allowed_methods.extend(base.http_allowed_methods) if hasattr(simple_view, "extra_allowed"): assert isinstance( simple_view.extra_allowed, list ), "`extra_allowed` must be a list of strings allowed." - simple_view.http_allowed_methods.extend(simple_view.extra_allowed) + http_allowed_methods.extend(simple_view.extra_allowed) - allowed_methods: Set[str] = {method.lower() for method in simple_view.http_allowed_methods} - simple_view.http_allowed_methods = list(allowed_methods) - message = ", ".join(allowed_methods) + http_allowed_methods.extend(simple_view.http_allowed_methods) + + # Remove any duplicates + allowed_methods: Set[str] = {method.lower() for method in http_allowed_methods} + # Reasign the new clean list + simple_view.http_allowed_methods = list(allowed_methods) for handler_name in filtered_handlers: for base in simple_view.__bases__: attribute = getattr(simple_view, handler_name) - simple_view.is_method_allowed(handler_name, base, attribute, message) - + simple_view.is_method_allowed(handler_name, base, attribute) return simple_view diff --git a/esmerald/routing/apis/views.py b/esmerald/routing/apis/views.py index 6cb5689b..a3fd92a9 100644 --- a/esmerald/routing/apis/views.py +++ b/esmerald/routing/apis/views.py @@ -39,6 +39,9 @@ def is_method_allowed( if hasattr(cls, "extra_allowed") and name.lower() in cls.extra_allowed: # type: ignore[unreachable] return True if name.lower() not in cls.http_allowed_methods: + if error_message is None: + error_message = ", ".join(cls.http_allowed_methods) + raise ImproperlyConfigured( f"{cls.__name__} only allows functions with the name(s) `{error_message}` to be implemented, got `{name.lower()}` instead." ) From 5f3454cb8c2b5aa71c24dc8fdecde2d30f8443ba Mon Sep 17 00:00:00 2001 From: tarsil Date: Mon, 9 Oct 2023 11:42:00 +0100 Subject: [PATCH 4/4] Update openapi tests for APIView --- tests/openapi/test_include_with_apiview.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/openapi/test_include_with_apiview.py b/tests/openapi/test_include_with_apiview.py index 8d497afb..3461acb6 100644 --- a/tests/openapi/test_include_with_apiview.py +++ b/tests/openapi/test_include_with_apiview.py @@ -21,7 +21,7 @@ def read_people() -> Dict[str, str]: def test_add_include_to_openapi(test_client_factory, value): class MyAPI(value): if issubclass(value, SimpleAPIView): - http_allowed_methods = ["read_item"] + extra_allowed = ["read_item"] @get( "/item", @@ -107,7 +107,7 @@ async def read_item(self) -> JSON: def test_include_no_include_in_schema(test_client_factory, value): class MyAPI(value): if issubclass(value, SimpleAPIView): - http_allowed_methods = ["read_item"] + extra_allowed = ["read_item"] @get( "/item",