From f6c2afc96dd43b36cc86cd594221eef84aafacfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Piku=C5=82a?= Date: Tue, 18 Apr 2023 18:36:07 +0200 Subject: [PATCH] Improve typing in the module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #65 - make it pass mypy tests (i.e., improve type hinting) - add mypy test to CI - add py.typed to setup so that other packages recognize it's typed - enable async view function decoration (via ensure_sync() as noted on https://flask.palletsprojects.com/en/latest/async-await/#extensions) Signed-off-by: Marek PikuĊ‚a --- .github/workflows/tests.yml | 20 +++ flask_pydantic/converters.py | 8 +- flask_pydantic/core.py | 245 +++++++++++++++++++++-------------- flask_pydantic/exceptions.py | 32 +++-- flask_pydantic/py.typed | 0 requirements/test.pip | 2 + setup.py | 1 + 7 files changed, 192 insertions(+), 116 deletions(-) create mode 100644 flask_pydantic/py.typed diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index aea03de..c165eee 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -3,6 +3,26 @@ name: Tests on: [push, pull_request] jobs: + mypy: + runs-on: ubuntu-latest + strategy: + fail-fast: false + + steps: + - uses: actions/checkout@v1 + - name: Set up Python 3.7 + uses: actions/setup-python@v1 + with: + python-version: "3.7" + - name: Install dependencies + if: steps.cache-pip.outputs.cache-hit != 'true' + run: | + python -m pip install --upgrade pip + pip install -r requirements/test.pip + - name: Run mypy satic check + run: | + python3 -m mypy flask_pydantic/ + build: runs-on: ${{ matrix.os }} strategy: diff --git a/flask_pydantic/converters.py b/flask_pydantic/converters.py index 0fed087..f128631 100644 --- a/flask_pydantic/converters.py +++ b/flask_pydantic/converters.py @@ -1,12 +1,12 @@ -from typing import Type +from typing import Dict, List, Type, Union from pydantic import BaseModel -from werkzeug.datastructures import ImmutableMultiDict +from werkzeug.datastructures import MultiDict def convert_query_params( - query_params: ImmutableMultiDict, model: Type[BaseModel] -) -> dict: + query_params: "MultiDict[str, str]", model: Type[BaseModel] +) -> Dict[str, Union[str, List[str]]]: """ group query parameters into lists if model defines them diff --git a/flask_pydantic/core.py b/flask_pydantic/core.py index 912415a..294f36a 100644 --- a/flask_pydantic/core.py +++ b/flask_pydantic/core.py @@ -1,7 +1,28 @@ +from collections.abc import Iterable from functools import wraps -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) + +from flask import Response, current_app, jsonify, request +from flask.typing import ResponseReturnValue, RouteCallable + +try: + from flask_restful import ( # type: ignore + original_flask_make_response as make_response, + ) +except ImportError: + from flask import make_response -from flask import Response, current_app, jsonify, make_response, request from pydantic import BaseModel, ValidationError from pydantic.tools import parse_obj_as @@ -13,22 +34,33 @@ ) from .exceptions import ValidationError as FailedValidation -try: - from flask_restful import original_flask_make_response as make_response -except ImportError: - pass +if TYPE_CHECKING: + from pydantic.error_wrappers import ErrorDict + + +ModelResponseReturnValue = Union[ResponseReturnValue, BaseModel] +ModelRouteCallable = Union[ + Callable[..., ModelResponseReturnValue], + Callable[..., Awaitable[ModelResponseReturnValue]], +] def make_json_response( - content: Union[BaseModel, Iterable[BaseModel]], + content: "Union[BaseModel, Iterable[BaseModel]]", status_code: int, by_alias: bool, exclude_none: bool = False, - many: bool = False, ) -> Response: """serializes model, creates JSON response with given status code""" - if many: - js = f"[{', '.join([model.json(exclude_none=exclude_none, by_alias=by_alias) for model in content])}]" + if not isinstance(content, BaseModel): + js = "[" + js += ", ".join( + [ + model.json(exclude_none=exclude_none, by_alias=by_alias) + for model in content + ] + ) + js += "]" else: js = content.json(exclude_none=exclude_none, by_alias=by_alias) response = make_response(js, status_code) @@ -56,9 +88,9 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel return [model(**fields) for fields in content] except TypeError: # iteration through `content` fails - err = [ + err: List["ErrorDict"] = [ { - "loc": ["root"], + "loc": ("root",), "msg": "is not an array of objects", "type": "type_error.array", } @@ -68,8 +100,10 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel raise ManyModelValidationError(ve.errors()) -def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]: - errors = [] +def validate_path_params( + func: ModelRouteCallable, kwargs: Dict[str, Any] +) -> Tuple[Dict[str, Any], List["ErrorDict"]]: + errors: List["ErrorDict"] = [] validated = {} for name, type_ in func.__annotations__.items(): if name in {"query", "body", "form", "return"}: @@ -77,21 +111,42 @@ def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]: try: value = parse_obj_as(type_, kwargs.get(name)) validated[name] = value - except ValidationError as e: - err = e.errors()[0] - err["loc"] = [name] + except ValidationError as error: + err = error.errors()[0] + err["loc"] = (name,) errors.append(err) kwargs = {**kwargs, **validated} return kwargs, errors -def get_body_dict(**params): - data = request.get_json(**params) +def get_body_dict(**params: Dict[str, Any]) -> Any: + data = request.get_json(**params) # type: ignore if data is None and params.get("silent"): return {} return data +def _ensure_model_kwarg( + kwarg_name: str, + from_validate: Optional[Type[BaseModel]], + func: ModelRouteCallable, +) -> Tuple[Optional[Type[BaseModel]], bool]: + """Get model information either from wrapped function or validate kwargs.""" + in_func_kwargs = func.__annotations__.get(kwarg_name) + if in_func_kwargs is None: + return from_validate, False + assert isinstance(in_func_kwargs, type) and issubclass( + in_func_kwargs, BaseModel + ), "Model in function arguments needs to be a BaseModel." + + # Ensure that the most "detailed" model is used. + if from_validate is None: + return in_func_kwargs, True + if issubclass(in_func_kwargs, from_validate): + return in_func_kwargs, True + return from_validate, True + + def validate( body: Optional[Type[BaseModel]] = None, query: Optional[Type[BaseModel]] = None, @@ -100,7 +155,7 @@ def validate( response_many: bool = False, request_body_many: bool = False, response_by_alias: bool = False, - get_json_params: Optional[dict] = None, + get_json_params: Optional[Dict[str, Any]] = None, form: Optional[Type[BaseModel]] = None, ): """ @@ -163,105 +218,93 @@ def test_route_kwargs(query:Query, body:Body, form:Form): -> that will render JSON response with serialized MyModel instance """ - def decorate(func: Callable) -> Callable: + def decorate(func: ModelRouteCallable) -> RouteCallable: @wraps(func) - def wrapper(*args, **kwargs): - q, b, f, err = None, None, None, {} - kwargs, path_err = validate_path_params(func, kwargs) - if path_err: - err["path_params"] = path_err - query_in_kwargs = func.__annotations__.get("query") - query_model = query_in_kwargs or query - if query_model: + def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> ResponseReturnValue: + q, b, f, err = None, None, None, FailedValidation() + func_kwargs, path_err = validate_path_params(func, kwargs) + if len(path_err) > 0: + err.path_params = path_err + query_model, query_in_kwargs = _ensure_model_kwarg("query", query, func) + if query_model is not None: query_params = convert_query_params(request.args, query_model) try: q = query_model(**query_params) except ValidationError as ve: - err["query_params"] = ve.errors() - body_in_kwargs = func.__annotations__.get("body") - body_model = body_in_kwargs or body - if body_model: + err.query_params = ve.errors() + body_model, body_in_kwargs = _ensure_model_kwarg("body", body, func) + if body_model is not None: body_params = get_body_dict(**(get_json_params or {})) - if "__root__" in body_model.__fields__: - try: - b = body_model(__root__=body_params).__root__ - except ValidationError as ve: - err["body_params"] = ve.errors() - elif request_body_many: - try: + try: + if "__root__" in body_model.__fields__: + b = body_model(__root__=body_params).__root__ # type: ignore + elif request_body_many: b = validate_many_models(body_model, body_params) - except ManyModelValidationError as e: - err["body_params"] = e.errors() - else: - try: + else: b = body_model(**body_params) - except TypeError: - content_type = request.headers.get("Content-Type", "").lower() - media_type = content_type.split(";")[0] - if media_type != "application/json": - return unsupported_media_type_response(content_type) - else: - raise JsonBodyParsingError() - except ValidationError as ve: - err["body_params"] = ve.errors() - form_in_kwargs = func.__annotations__.get("form") - form_model = form_in_kwargs or form - if form_model: + except (ValidationError, ManyModelValidationError) as error: + err.body_params = error.errors() + except TypeError as error: + content_type = request.headers.get("Content-Type", "").lower() + media_type = content_type.split(";")[0] + if media_type != "application/json": + return unsupported_media_type_response(content_type) + else: + raise JsonBodyParsingError() from error + form_model, form_in_kwargs = _ensure_model_kwarg("form", form, func) + if form_model is not None: form_params = request.form - if "__root__" in form_model.__fields__: - try: - f = form_model(__root__=form_params).__root__ - except ValidationError as ve: - err["form_params"] = ve.errors() - else: - try: + try: + if "__root__" in form_model.__fields__: + f = form_model(__root__=form_params).__root__ # type: ignore + else: f = form_model(**form_params) - except TypeError: - content_type = request.headers.get("Content-Type", "").lower() - media_type = content_type.split(";")[0] - if media_type != "multipart/form-data": - return unsupported_media_type_response(content_type) - else: - raise JsonBodyParsingError - except ValidationError as ve: - err["form_params"] = ve.errors() - request.query_params = q - request.body_params = b - request.form_params = f + except TypeError as error: + content_type = request.headers.get("Content-Type", "").lower() + media_type = content_type.split(";")[0] + if media_type != "multipart/form-data": + return unsupported_media_type_response(content_type) + else: + raise JsonBodyParsingError() from error + except ValidationError as ve: + err.form_params = ve.errors() + request.query_params = q # type: ignore + request.body_params = b # type: ignore + request.form_params = f # type: ignore if query_in_kwargs: - kwargs["query"] = q + func_kwargs["query"] = q if body_in_kwargs: - kwargs["body"] = b + func_kwargs["body"] = b if form_in_kwargs: - kwargs["form"] = f + func_kwargs["form"] = f - if err: + if err.check(): if current_app.config.get( "FLASK_PYDANTIC_VALIDATION_ERROR_RAISE", False ): - raise FailedValidation(**err) + raise err else: status_code = current_app.config.get( "FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400 ) return make_response( - jsonify({"validation_error": err}), - status_code + jsonify({"validation_error": err.to_dict()}), status_code ) - res = func(*args, **kwargs) + res: ModelResponseReturnValue = current_app.ensure_sync(func)( + *args, **func_kwargs + ) if response_many: - if is_iterable_of_models(res): - return make_json_response( - res, - on_success_status, - by_alias=response_by_alias, - exclude_none=exclude_none, - many=True, - ) - else: + if not is_iterable_of_models(res): raise InvalidIterableOfModelsException(res) + return make_json_response( + res, # type: ignore # Iterability and type is ensured above. + on_success_status, + by_alias=response_by_alias, + exclude_none=exclude_none, + ) + if isinstance(res, BaseModel): return make_json_response( res, @@ -275,23 +318,29 @@ def wrapper(*args, **kwargs): and len(res) in [2, 3] and isinstance(res[0], BaseModel) ): - headers = None + headers: Optional[ + Union[Dict[str, Any], Tuple[Any, ...], List[Any]] + ] = None status = on_success_status if isinstance(res[1], (dict, tuple, list)): headers = res[1] - elif len(res) == 3 and isinstance(res[2], (dict, tuple, list)): - status = res[1] - headers = res[2] - else: + elif isinstance(res[1], int): status = res[1] + # Following type ignores should be fixed once + # https://github.com/python/mypy/issues/1178 is fixed. + if len(res) == 3 and isinstance( + res[2], (dict, tuple, list) # type: ignore[misc] + ): + headers = res[2] # type: ignore[misc] + ret = make_json_response( res[0], status, exclude_none=exclude_none, by_alias=response_by_alias, ) - if headers: + if headers is not None: ret.headers.update(headers) return ret diff --git a/flask_pydantic/exceptions.py b/flask_pydantic/exceptions.py index e214cc1..14eb7ba 100644 --- a/flask_pydantic/exceptions.py +++ b/flask_pydantic/exceptions.py @@ -1,4 +1,8 @@ -from typing import List, Optional +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any, List, Optional + +if TYPE_CHECKING: + from pydantic.error_wrappers import ErrorDict class BaseFlaskPydanticException(Exception): @@ -24,7 +28,7 @@ class ManyModelValidationError(BaseFlaskPydanticException): """This exception is raised if there is a failure during validation of many models in an iterable""" - def __init__(self, errors: List[dict], *args): + def __init__(self, errors: List["ErrorDict"], *args: Any): self._errors = errors super().__init__(*args) @@ -32,19 +36,19 @@ def errors(self): return self._errors +@dataclass class ValidationError(BaseFlaskPydanticException): """This exception is raised if there is a failure during validation if the user has configured an exception to be raised instead of a response""" - def __init__( - self, - body_params: Optional[List[dict]] = None, - form_params: Optional[List[dict]] = None, - path_params: Optional[List[dict]] = None, - query_params: Optional[List[dict]] = None, - ): - super().__init__() - self.body_params = body_params - self.form_params = form_params - self.path_params = path_params - self.query_params = query_params + body_params: Optional[List["ErrorDict"]] = None + form_params: Optional[List["ErrorDict"]] = None + path_params: Optional[List["ErrorDict"]] = None + query_params: Optional[List["ErrorDict"]] = None + + def check(self) -> bool: + """Check if any param resulted in error.""" + return any(value is not None for _, value in asdict(self).items()) + + def to_dict(self): + return {key: value for key, value in asdict(self).items() if value is not None} diff --git a/flask_pydantic/py.typed b/flask_pydantic/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/requirements/test.pip b/requirements/test.pip index 14530b2..a959c83 100644 --- a/requirements/test.pip +++ b/requirements/test.pip @@ -6,3 +6,5 @@ pytest-flake8 pytest-coverage pytest-black pytest-mock + +mypy diff --git a/setup.py b/setup.py index 09be0eb..1a4c595 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ def find_version(file_path: Path = VERSION_FILE_PATH) -> str: long_description=README, long_description_content_type="text/markdown", packages=["flask_pydantic"], + package_data={"flask_pydantic": ["py.typed"]}, install_requires=list(get_install_requires()), python_requires=">=3.6", classifiers=[