Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V2 #184

Merged
merged 5 commits into from
Oct 9, 2023
Merged

V2 #184

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions docs/routing/apiview.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ allows the creation of apis where the function name can be whatever you desire l
So what does that mean? Means **you can only perform operations where the function name coincides with the http verb**.
For example, `get`, `put`, `post` etc...

If you attempt to create a functionm where the name differs from a http verb,
an `ImproperlyConfigured` exception is raised.
If you attempt to create a function where the name differs from a http verb,
an `ImproperlyConfigured` exception is raised **unless the `extra_allowed` is declared**.

The available http verbs are:

Expand All @@ -113,6 +113,27 @@ The available http verbs are:

Basically the same availability as the [handlers](./handlers.md).

### Important

The generics **enforce** the name matching of the functions with the handlers. That means, if
you use a `ReadAPIView` that only allows the `get` and you use the wrong [handlers](./handlers.md)
on the top of it, for example a [post](./handlers.md#post), an `ImproperlyConfigured` exception
will be raised.

Let us see what this means.

```python hl_lines="13-14"
{!> ../docs_src/routing/generics/important.py !}
```

As you can see, the handler `post()` does not match the function name `get`. **It should always match**.

An easy way of knowing this is simple, when it comes to the available http verbs, the function name
**should always match the handler**.

Are there any exception? Yes but not for these specific cases, the exceptions are called
[extra_allowed](#extra_allowed) but more details about this later on.

### SimpleAPIView

This is the base of all generics, subclassing from this class will allow you to perform all the
Expand Down Expand Up @@ -188,11 +209,12 @@ What if you want to combine them all? Of course you also can.

**Combining them all is the same as using the [SimpleAPIView](#simpleapiview)**.

### http_allowed_methods
### extra_allowed

All the generics subclass the [SimpleAPIView](#simpleapiview) as mentioned before and that superclass
uses the `http_allowed_methods` to verify which methods are allowed or not to be passed inside
the API object.
the API object but also check if there is any `extra_allowed` list with any extra functions you
would like the view to deliver.

This means that if you want to add a `read_item()` function to any of the
generics you also do it easily.
Expand All @@ -202,7 +224,7 @@ generics you also do it easily.
```

As you can see, to make it happen you would need to declare the function name inside the
`http_allowed_methods` to make sure that an `ImproperlyConfigured` is not raised.
`extra_allowed` to make sure that an `ImproperlyConfigured` is not raised.

## What to choose

Expand Down
2 changes: 1 addition & 1 deletion docs_src/routing/generics/allowed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class UserAPI(CreateAPIView):
to be used by default.
"""

http_allowed_methods: List[str] = ["read_item"]
extra_allowed: List[str] = ["read_item"]

@post()
def post(self) -> str:
Expand Down
15 changes: 15 additions & 0 deletions docs_src/routing/generics/important.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import List

from esmerald import get, patch, post, put
from esmerald.routing.apis.generics import CreateAPIView


class UserAPI(CreateAPIView):
"""
ImproperlyConfigured will be raised as the handler `post()`
name does not match the function name `post`.
"""

@post()
def get(self) -> str:
...
61 changes: 61 additions & 0 deletions esmerald/routing/apis/_metaclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import TYPE_CHECKING, Any, List, Set, Tuple, Type, cast

if TYPE_CHECKING:
from esmerald import SimpleAPIView


class SimpleAPIMeta(type):
"""
Metaclass responsible for making sure
only the CRUD objects are allowed.
"""

def __new__(cls, name: str, bases: Tuple[Type, ...], attrs: Any) -> Any:
"""
Making sure the `http_allowed_methods` are extended if inheritance happens
in the subclass.

The `http_allowed_methods` is the default for each type of generic but to allow
extra allowed methods, the `extra_allowed` must be added.
"""
view = super().__new__

parents = [parent for parent in bases if isinstance(parent, SimpleAPIMeta)]
if not parents:
return view(cls, name, bases, attrs)

http_allowed_methods: List[str] = []
simple_view = cast("SimpleAPIView", view(cls, name, bases, attrs))
filtered_handlers: List[str] = [
attr
for attr in dir(simple_view)
if not attr.startswith("__") and not attr.endswith("__")
]

for base in bases:
if (
hasattr(base, "http_allowed_methods")
and hasattr(base, "__is_generic__")
and getattr(base, "__is_generic__", False) not in [False, None]
):
http_allowed_methods.extend(base.http_allowed_methods)

if hasattr(simple_view, "extra_allowed"):
assert isinstance(
simple_view.extra_allowed, list
), "`extra_allowed` must be a list of strings allowed."

http_allowed_methods.extend(simple_view.extra_allowed)

http_allowed_methods.extend(simple_view.http_allowed_methods)

# Remove any duplicates
allowed_methods: Set[str] = {method.lower() for method in http_allowed_methods}

# Reasign the new clean list
simple_view.http_allowed_methods = list(allowed_methods)
for handler_name in filtered_handlers:
for base in simple_view.__bases__:
attribute = getattr(simple_view, handler_name)
simple_view.is_method_allowed(handler_name, base, attribute)
return simple_view
58 changes: 12 additions & 46 deletions esmerald/routing/apis/views.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,10 @@
from typing import Any, Callable, List, Set, Tuple, Type, Union, cast
from typing import Any, Callable, List, Union

from esmerald.exceptions import ImproperlyConfigured
from esmerald.routing.apis._metaclasses import SimpleAPIMeta
from esmerald.routing.apis.base import View


class SimpleAPIMeta(type):
"""
Metaclass responsible for making sure
only the CRUD objects are allowed.
"""

def __new__(cls, name: str, bases: Tuple[Type, ...], attrs: Any) -> Any:
"""
Making sure the `http_allowed_methods` are extended if inheritance happens
in the subclass
"""
view = super().__new__

parents = [parent for parent in bases if isinstance(parent, SimpleAPIMeta)]
if not parents:
return view(cls, name, bases, attrs)

simple_view = cast("SimpleAPIView", view(cls, name, bases, attrs))
filtered_handlers: List[str] = [
attr
for attr in dir(simple_view)
if not attr.startswith("__") and not attr.endswith("__")
]

for base in bases:
if (
hasattr(base, "http_allowed_methods")
and hasattr(base, "__is_generic__")
and getattr(base, "__is_generic__", False) not in [False, None]
):
simple_view.http_allowed_methods.extend(base.http_allowed_methods)

allowed_methods: Set[str] = {method.lower() for method in simple_view.http_allowed_methods}
simple_view.http_allowed_methods = list(allowed_methods)
message = ", ".join(allowed_methods)

for handler_name in filtered_handlers:
for base in simple_view.__bases__:
attribute = getattr(simple_view, handler_name)
simple_view.is_method_allowed(handler_name, base, attribute, message)

return simple_view


class SimpleAPIView(View, metaclass=SimpleAPIMeta):
"""The Esmerald SimpleAPIView class.

Expand Down Expand Up @@ -79,10 +36,19 @@ def is_method_allowed(
method,
(HTTPHandler, WebSocketHandler, WebhookHandler),
):
if name.lower() not in cls.http_allowed_methods: # type: ignore[unreachable]
if hasattr(cls, "extra_allowed") and name.lower() in cls.extra_allowed: # type: ignore[unreachable]
return True
if name.lower() not in cls.http_allowed_methods:
if error_message is None:
error_message = ", ".join(cls.http_allowed_methods)

raise ImproperlyConfigured(
f"{cls.__name__} only allows functions with the name(s) `{error_message}` to be implemented, got `{name.lower()}` instead."
)
elif name.lower() != method.__class__.__name__.lower():
raise ImproperlyConfigured(
f"The function '{name.lower()}' must implement the '{name.lower()}()' handler, got '{method.__class__.__name__.lower()}()' instead."
)
return True


Expand Down
17 changes: 3 additions & 14 deletions esmerald/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import asyncio
import functools
import sys
import typing
from inspect import isclass
from typing import Any, Awaitable, Callable, TypeVar, Union
from typing import Any, Union

import slugify
from typing_extensions import ParamSpec, TypeGuard, get_args, get_origin
from starlette._utils import is_async_callable as is_async_callable
from typing_extensions import get_args, get_origin

if sys.version_info >= (3, 10):
from types import UnionType
Expand All @@ -15,16 +14,6 @@
else: # pragma: no cover
UNION_TYPES = {Union}

P = ParamSpec("P")
T = TypeVar("T")


def is_async_callable(value: Callable[P, T]) -> TypeGuard[Callable[P, Awaitable[T]]]:
while isinstance(value, functools.partial):
value = value.func # type: ignore[unreachable]

return asyncio.iscoroutinefunction(value) or asyncio.iscoroutinefunction(value.__call__) # type: ignore


def is_class_and_subclass(value: typing.Any, _type: typing.Any) -> bool:
original = get_origin(value)
Expand Down
4 changes: 2 additions & 2 deletions tests/openapi/test_include_with_apiview.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def read_people() -> Dict[str, str]:
def test_add_include_to_openapi(test_client_factory, value):
class MyAPI(value):
if issubclass(value, SimpleAPIView):
http_allowed_methods = ["read_item"]
extra_allowed = ["read_item"]

@get(
"/item",
Expand Down Expand Up @@ -107,7 +107,7 @@ async def read_item(self) -> JSON:
def test_include_no_include_in_schema(test_client_factory, value):
class MyAPI(value):
if issubclass(value, SimpleAPIView):
http_allowed_methods = ["read_item"]
extra_allowed = ["read_item"]

@get(
"/item",
Expand Down
39 changes: 38 additions & 1 deletion tests/routing/test_api_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def get(self) -> str:
@pytest.mark.parametrize("value,method", [("create_user", "post"), ("read_item", "get")])
def test_all_api_view_custom(test_client_factory, value, method):
class GenericAPIView(CreateAPIView, ReadAPIView, DeleteAPIView):
http_allowed_methods: List[str] = ["create_user", "read_item"]
extra_allowed: List[str] = ["create_user", "read_item"]

@post(status_code=200)
async def create_user(self) -> str:
Expand All @@ -155,3 +155,40 @@ async def read_item(self) -> str:
response = getattr(client, method)("/")
assert response.status_code == 200
assert response.json() == f"home {value}"


@pytest.mark.parametrize(
"value",
[("create_user",), {"create_user"}, {"name": "create_user"}],
ids=["tuple", "set", "dict"],
)
def test_all_api_view_custom_error(test_client_factory, value):
with pytest.raises(AssertionError):

class GenericAPIView(CreateAPIView, ReadAPIView, DeleteAPIView):
extra_allowed: List[str] = ("create_user", "read_item")


@pytest.mark.parametrize(
"value", [value for value in SimpleAPIView.http_allowed_methods if value != "get"]
)
def test_default_parameters_raise_error_on_wrong_handler(test_client_factory, value):
handler = getattr(esmerald, value)

with pytest.raises(ImproperlyConfigured) as raised:

class GenericAPIView(CreateAPIView, ReadAPIView, DeleteAPIView):
extra_allowed: List[str] = ["create_user"]

@handler("/")
def get(self) -> None:
...

@handler("/")
def create_user() -> None:
...

assert (
raised.value.detail
== f"The function 'get' must implement the 'get()' handler, got '{value}()' instead."
)