diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index aea03de..cda6db3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -3,28 +3,56 @@ name: Tests on: [push, pull_request] jobs: + lint: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + tool: + - flake8 + - mypy + - black --diff --check + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.7 + uses: actions/setup-python@v4 + with: + python-version: "3.7" + cache: pip + cache-dependency-path: | + requirements/base.pip + requirements/test.pip + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/test.pip + - name: Run ${{ matrix.tool }} static check + run: | + python3 -m ${{ matrix.tool }} flask_pydantic/ tests/ + build: runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] os: [ubuntu-latest, macOS-latest] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: | + requirements/base.pip + requirements/test.pip - name: Install dependencies - if: steps.cache-pip.outputs.cache-hit != 'true' run: | python -m pip install --upgrade pip pip install -r requirements/test.pip - name: Test with pytest run: | python3 -m pytest - - name: Lint with flake8 - run: | - flake8 . diff --git a/README.md b/README.md index cb585b8..e6c2204 100644 --- a/README.md +++ b/README.md @@ -297,6 +297,44 @@ def my_route(): return MyModel(...) ``` +### Generic model type + +It's possible to use a validator on a generic model type. It can be useful, +e.g., when making a convenience API passthrough (example taken from +[headscale-webui](https://github.com/iFargle/headscale-webui/blob/cc67cf0e9cf0b33fd66d766ce74d86e0c9cda114/server.py#L186) +project): + +```python +RequestT = TypeVar("RequestT", bound=Message) +ResponseT = TypeVar("ResponseT", bound=Message) + +def api_passthrough( + route: str, + request_type: Type[RequestT], + api_method: Callable[[RequestT], ResponseT | str], +): + def api_passthrough_page(body: RequestT) -> ResponseT | str: + return api_method(body) + + api_passthrough_page.__name__ = route.replace("/", "_") + api_passthrough_page.__annotations__ = {"body": request_type} + + return app.route(route, methods=["POST"])( + validate()(api_passthrough_page) + ) + +api_passthrough( + "/api/endpoint1", + schema.GetMachineRequest, + backend_api_endpoint1, +) +api_passthrough( + "/api/endpoint2", + schema.DeleteMachineRequest, + backend_api_endpoint2, +) +``` + ### Example app For more complete examples see [example application](https://github.com/bauerji/flask_pydantic/tree/master/example_app). diff --git a/flask_pydantic/converters.py b/flask_pydantic/converters.py index 0fed087..f128631 100644 --- a/flask_pydantic/converters.py +++ b/flask_pydantic/converters.py @@ -1,12 +1,12 @@ -from typing import Type +from typing import Dict, List, Type, Union from pydantic import BaseModel -from werkzeug.datastructures import ImmutableMultiDict +from werkzeug.datastructures import MultiDict def convert_query_params( - query_params: ImmutableMultiDict, model: Type[BaseModel] -) -> dict: + query_params: "MultiDict[str, str]", model: Type[BaseModel] +) -> Dict[str, Union[str, List[str]]]: """ group query parameters into lists if model defines them diff --git a/flask_pydantic/core.py b/flask_pydantic/core.py index 912415a..d6ded8c 100644 --- a/flask_pydantic/core.py +++ b/flask_pydantic/core.py @@ -1,7 +1,31 @@ +import inspect +from collections.abc import Iterable from functools import wraps -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from flask import Response, current_app, jsonify, request +from flask.typing import ResponseReturnValue, RouteCallable +from typing_extensions import get_args + +try: + from flask_restful import ( # type: ignore + original_flask_make_response as make_response, + ) +except ImportError: + from flask import make_response -from flask import Response, current_app, jsonify, make_response, request from pydantic import BaseModel, ValidationError from pydantic.tools import parse_obj_as @@ -13,22 +37,33 @@ ) from .exceptions import ValidationError as FailedValidation -try: - from flask_restful import original_flask_make_response as make_response -except ImportError: - pass +if TYPE_CHECKING: + from pydantic.error_wrappers import ErrorDict + + +ModelResponseReturnValue = Union[ResponseReturnValue, BaseModel] +ModelRouteCallable = Union[ + Callable[..., ModelResponseReturnValue], + Callable[..., Awaitable[ModelResponseReturnValue]], +] def make_json_response( - content: Union[BaseModel, Iterable[BaseModel]], + content: "Union[BaseModel, Iterable[BaseModel]]", status_code: int, by_alias: bool, exclude_none: bool = False, - many: bool = False, ) -> Response: """serializes model, creates JSON response with given status code""" - if many: - js = f"[{', '.join([model.json(exclude_none=exclude_none, by_alias=by_alias) for model in content])}]" + if not isinstance(content, BaseModel): + js = "[" + js += ", ".join( + [ + model.json(exclude_none=exclude_none, by_alias=by_alias) + for model in content + ] + ) + js += "]" else: js = content.json(exclude_none=exclude_none, by_alias=by_alias) response = make_response(js, status_code) @@ -56,9 +91,9 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel return [model(**fields) for fields in content] except TypeError: # iteration through `content` fails - err = [ + err: List["ErrorDict"] = [ { - "loc": ["root"], + "loc": ("root",), "msg": "is not an array of objects", "type": "type_error.array", } @@ -68,8 +103,10 @@ def validate_many_models(model: Type[BaseModel], content: Any) -> List[BaseModel raise ManyModelValidationError(ve.errors()) -def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]: - errors = [] +def validate_path_params( + func: ModelRouteCallable, kwargs: Dict[str, Any] +) -> Tuple[Dict[str, Any], List["ErrorDict"]]: + errors: List["ErrorDict"] = [] validated = {} for name, type_ in func.__annotations__.items(): if name in {"query", "body", "form", "return"}: @@ -77,21 +114,74 @@ def validate_path_params(func: Callable, kwargs: dict) -> Tuple[dict, list]: try: value = parse_obj_as(type_, kwargs.get(name)) validated[name] = value - except ValidationError as e: - err = e.errors()[0] - err["loc"] = [name] + except ValidationError as error: + err = error.errors()[0] + err["loc"] = (name,) errors.append(err) kwargs = {**kwargs, **validated} return kwargs, errors -def get_body_dict(**params): - data = request.get_json(**params) +def get_body_dict(**params: Dict[str, Any]) -> Any: + data = request.get_json(**params) # type: ignore if data is None and params.get("silent"): return {} return data +def _get_type_generic(hint: Any): + """Extract type information from bound TypeVar or Type[TypeVar]. + + Examples: + + ``` + MyGeneric = TypeVar("MyGeneric", bound=str) + MyGenericType = Type[MyGeneric] + + _get_type_generic(str) -> str + _get_type_generic(MyGeneric) -> str + _get_type_generic(MyGenericType) -> str + ``` + """ + # First check for direct TypeVar hint. + if isinstance(hint, TypeVar): + assert ( + getattr(hint, "__bound__", None) is not None + ), "If using TypeVar, you need to specify bound model." + return getattr(hint, "__bound__") + + # Check for Type[TypeVar] hint. + args = get_args(hint) + if len(args) > 0 and isinstance(args[0], TypeVar): + assert ( + getattr(args[0], "__bound__", None) is not None + ), "If using TypeVar, you need to specify bound model." + return getattr(args[0], "__bound__") + + # Otherwise, use the type hint directly. + return hint + + +def _ensure_model_kwarg( + kwarg_name: str, + from_validate: Optional[Type[BaseModel]], + func: ModelRouteCallable, +) -> Tuple[Optional[Type[BaseModel]], bool]: + """Get model information either from wrapped function or validate kwargs.""" + func_spec = inspect.getfullargspec(func) + in_func_arg = kwarg_name in func_spec.args or kwarg_name in func_spec.kwonlyargs + from_func = _get_type_generic(func_spec.annotations.get(kwarg_name)) + if from_func is None or not isinstance(from_func, type): + return _get_type_generic(from_validate), in_func_arg + + # Ensure that the most "detailed" model is used. + if from_validate is None: + return from_func, in_func_arg + if issubclass(from_func, from_validate): + return from_func, in_func_arg + return from_validate, in_func_arg + + def validate( body: Optional[Type[BaseModel]] = None, query: Optional[Type[BaseModel]] = None, @@ -100,7 +190,7 @@ def validate( response_many: bool = False, request_body_many: bool = False, response_by_alias: bool = False, - get_json_params: Optional[dict] = None, + get_json_params: Optional[Dict[str, Any]] = None, form: Optional[Type[BaseModel]] = None, ): """ @@ -163,105 +253,93 @@ def test_route_kwargs(query:Query, body:Body, form:Form): -> that will render JSON response with serialized MyModel instance """ - def decorate(func: Callable) -> Callable: + def decorate(func: ModelRouteCallable) -> RouteCallable: @wraps(func) - def wrapper(*args, **kwargs): - q, b, f, err = None, None, None, {} - kwargs, path_err = validate_path_params(func, kwargs) - if path_err: - err["path_params"] = path_err - query_in_kwargs = func.__annotations__.get("query") - query_model = query_in_kwargs or query - if query_model: + def wrapper(*args: Any, **kwargs: Dict[str, Any]) -> ResponseReturnValue: + q, b, f, err = None, None, None, FailedValidation() + func_kwargs, path_err = validate_path_params(func, kwargs) + if len(path_err) > 0: + err.path_params = path_err + query_model, query_in_kwargs = _ensure_model_kwarg("query", query, func) + if query_model is not None: query_params = convert_query_params(request.args, query_model) try: q = query_model(**query_params) except ValidationError as ve: - err["query_params"] = ve.errors() - body_in_kwargs = func.__annotations__.get("body") - body_model = body_in_kwargs or body - if body_model: + err.query_params = ve.errors() + body_model, body_in_kwargs = _ensure_model_kwarg("body", body, func) + if body_model is not None: body_params = get_body_dict(**(get_json_params or {})) - if "__root__" in body_model.__fields__: - try: - b = body_model(__root__=body_params).__root__ - except ValidationError as ve: - err["body_params"] = ve.errors() - elif request_body_many: - try: + try: + if "__root__" in body_model.__fields__: + b = body_model(__root__=body_params).__root__ # type: ignore + elif request_body_many: b = validate_many_models(body_model, body_params) - except ManyModelValidationError as e: - err["body_params"] = e.errors() - else: - try: + else: b = body_model(**body_params) - except TypeError: - content_type = request.headers.get("Content-Type", "").lower() - media_type = content_type.split(";")[0] - if media_type != "application/json": - return unsupported_media_type_response(content_type) - else: - raise JsonBodyParsingError() - except ValidationError as ve: - err["body_params"] = ve.errors() - form_in_kwargs = func.__annotations__.get("form") - form_model = form_in_kwargs or form - if form_model: + except (ValidationError, ManyModelValidationError) as error: + err.body_params = error.errors() + except TypeError as error: + content_type = request.headers.get("Content-Type", "").lower() + media_type = content_type.split(";")[0] + if media_type != "application/json": + return unsupported_media_type_response(content_type) + else: + raise JsonBodyParsingError() from error + form_model, form_in_kwargs = _ensure_model_kwarg("form", form, func) + if form_model is not None: form_params = request.form - if "__root__" in form_model.__fields__: - try: - f = form_model(__root__=form_params).__root__ - except ValidationError as ve: - err["form_params"] = ve.errors() - else: - try: + try: + if "__root__" in form_model.__fields__: + f = form_model(__root__=form_params).__root__ # type: ignore + else: f = form_model(**form_params) - except TypeError: - content_type = request.headers.get("Content-Type", "").lower() - media_type = content_type.split(";")[0] - if media_type != "multipart/form-data": - return unsupported_media_type_response(content_type) - else: - raise JsonBodyParsingError - except ValidationError as ve: - err["form_params"] = ve.errors() - request.query_params = q - request.body_params = b - request.form_params = f + except TypeError as error: + content_type = request.headers.get("Content-Type", "").lower() + media_type = content_type.split(";")[0] + if media_type != "multipart/form-data": + return unsupported_media_type_response(content_type) + else: + raise JsonBodyParsingError() from error + except ValidationError as ve: + err.form_params = ve.errors() + request.query_params = q # type: ignore + request.body_params = b # type: ignore + request.form_params = f # type: ignore if query_in_kwargs: - kwargs["query"] = q + func_kwargs["query"] = q if body_in_kwargs: - kwargs["body"] = b + func_kwargs["body"] = b if form_in_kwargs: - kwargs["form"] = f + func_kwargs["form"] = f - if err: + if err.check(): if current_app.config.get( "FLASK_PYDANTIC_VALIDATION_ERROR_RAISE", False ): - raise FailedValidation(**err) + raise err else: status_code = current_app.config.get( "FLASK_PYDANTIC_VALIDATION_ERROR_STATUS_CODE", 400 ) return make_response( - jsonify({"validation_error": err}), - status_code + jsonify({"validation_error": err.to_dict()}), status_code ) - res = func(*args, **kwargs) + res: ModelResponseReturnValue = current_app.ensure_sync(func)( + *args, **func_kwargs + ) if response_many: - if is_iterable_of_models(res): - return make_json_response( - res, - on_success_status, - by_alias=response_by_alias, - exclude_none=exclude_none, - many=True, - ) - else: + if not is_iterable_of_models(res): raise InvalidIterableOfModelsException(res) + return make_json_response( + res, # type: ignore # Iterability and type is ensured above. + on_success_status, + by_alias=response_by_alias, + exclude_none=exclude_none, + ) + if isinstance(res, BaseModel): return make_json_response( res, @@ -275,23 +353,29 @@ def wrapper(*args, **kwargs): and len(res) in [2, 3] and isinstance(res[0], BaseModel) ): - headers = None + headers: Optional[ + Union[Dict[str, Any], Tuple[Any, ...], List[Any]] + ] = None status = on_success_status if isinstance(res[1], (dict, tuple, list)): headers = res[1] - elif len(res) == 3 and isinstance(res[2], (dict, tuple, list)): - status = res[1] - headers = res[2] - else: + elif isinstance(res[1], int): status = res[1] + # Following type ignores should be fixed once + # https://github.com/python/mypy/issues/1178 is fixed. + if len(res) == 3 and isinstance( + res[2], (dict, tuple, list) # type: ignore[misc] + ): + headers = res[2] # type: ignore[misc] + ret = make_json_response( res[0], status, exclude_none=exclude_none, by_alias=response_by_alias, ) - if headers: + if headers is not None: ret.headers.update(headers) return ret diff --git a/flask_pydantic/exceptions.py b/flask_pydantic/exceptions.py index e214cc1..14eb7ba 100644 --- a/flask_pydantic/exceptions.py +++ b/flask_pydantic/exceptions.py @@ -1,4 +1,8 @@ -from typing import List, Optional +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any, List, Optional + +if TYPE_CHECKING: + from pydantic.error_wrappers import ErrorDict class BaseFlaskPydanticException(Exception): @@ -24,7 +28,7 @@ class ManyModelValidationError(BaseFlaskPydanticException): """This exception is raised if there is a failure during validation of many models in an iterable""" - def __init__(self, errors: List[dict], *args): + def __init__(self, errors: List["ErrorDict"], *args: Any): self._errors = errors super().__init__(*args) @@ -32,19 +36,19 @@ def errors(self): return self._errors +@dataclass class ValidationError(BaseFlaskPydanticException): """This exception is raised if there is a failure during validation if the user has configured an exception to be raised instead of a response""" - def __init__( - self, - body_params: Optional[List[dict]] = None, - form_params: Optional[List[dict]] = None, - path_params: Optional[List[dict]] = None, - query_params: Optional[List[dict]] = None, - ): - super().__init__() - self.body_params = body_params - self.form_params = form_params - self.path_params = path_params - self.query_params = query_params + body_params: Optional[List["ErrorDict"]] = None + form_params: Optional[List["ErrorDict"]] = None + path_params: Optional[List["ErrorDict"]] = None + query_params: Optional[List["ErrorDict"]] = None + + def check(self) -> bool: + """Check if any param resulted in error.""" + return any(value is not None for _, value in asdict(self).items()) + + def to_dict(self): + return {key: value for key, value in asdict(self).items() if value is not None} diff --git a/flask_pydantic/py.typed b/flask_pydantic/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/requirements/base.pip b/requirements/base.pip index 6640194..0b1c0de 100644 --- a/requirements/base.pip +++ b/requirements/base.pip @@ -1,2 +1,3 @@ Flask pydantic>=1.7 +typing-extensions>=4.1.1 diff --git a/requirements/test.pip b/requirements/test.pip index 14530b2..d5f68a5 100644 --- a/requirements/test.pip +++ b/requirements/test.pip @@ -6,3 +6,7 @@ pytest-flake8 pytest-coverage pytest-black pytest-mock + +black +flake8 +mypy diff --git a/setup.cfg b/setup.cfg index 3acda29..c828549 100644 --- a/setup.cfg +++ b/setup.cfg @@ -19,4 +19,7 @@ ignore = E121 E122 E123 E124 E125 E126 E127 E128 E711 E712 F811 F841 H803 E501 E exclude = .circleci, .github, - venv \ No newline at end of file + venv + +[isort] +profile = black diff --git a/setup.py b/setup.py index 09be0eb..1a4c595 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ def find_version(file_path: Path = VERSION_FILE_PATH) -> str: long_description=README, long_description_content_type="text/markdown", packages=["flask_pydantic"], + package_data={"flask_pydantic": ["py.typed"]}, install_requires=list(get_install_requires()), python_requires=">=3.6", classifiers=[ diff --git a/tests/conftest.py b/tests/conftest.py index e2b8ed1..6b94bb6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,15 @@ -from typing import List, Optional, Type +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar import pytest from flask import Flask, request -from flask_pydantic import validate from pydantic import BaseModel +from pydantic.generics import GenericModel + +from flask_pydantic import validate @pytest.fixture -def posts() -> List[dict]: +def posts() -> List[Dict[str, Any]]: return [ {"title": "title 1", "text": "random text", "views": 1}, {"title": "2", "text": "another text", "views": 2}, @@ -16,30 +18,33 @@ def posts() -> List[dict]: ] +class Query(BaseModel): + limit: int = 2 + min_views: Optional[int] + + @pytest.fixture def query_model() -> Type[BaseModel]: - class Query(BaseModel): - limit: int = 2 - min_views: Optional[int] - return Query +class Body(BaseModel): + search_term: str + exclude: Optional[str] + + @pytest.fixture def body_model() -> Type[BaseModel]: - class Body(BaseModel): - search_term: str - exclude: Optional[str] - return Body +class Form(BaseModel): + search_term: str + exclude: Optional[str] + + @pytest.fixture def form_model() -> Type[BaseModel]: - class Form(BaseModel): - search_term: str - exclude: Optional[str] - return Form @@ -53,23 +58,30 @@ class Post(BaseModel): return Post -@pytest.fixture -def response_model(post_model: BaseModel) -> Type[BaseModel]: - class Response(BaseModel): - results: List[post_model] - count: int +PostModelT = TypeVar("PostModelT", bound=BaseModel) - return Response +class Response(GenericModel, Generic[PostModelT]): + results: List[PostModelT] + count: int -def is_excluded(post: dict, exclude: Optional[str]) -> bool: + +@pytest.fixture +def response_model(post_model: PostModelT) -> Type[Response[PostModelT]]: + return Response[PostModelT] + + +def is_excluded(post: Dict[str, Any], exclude: Optional[str]) -> bool: if exclude is None: return False return exclude in post["title"] or exclude in post["text"] def pass_search( - post: dict, search_term: str, exclude: Optional[str], min_views: Optional[int] + post: Dict[str, Any], + search_term: str, + exclude: Optional[str], + min_views: Optional[int], ) -> bool: return ( (search_term in post["title"] or search_term in post["text"]) @@ -78,8 +90,20 @@ def pass_search( ) +QueryModelT = TypeVar("QueryModelT", bound=Query) +BodyModelT = TypeVar("BodyModelT", bound=Body) +FormModelT = TypeVar("FormModelT", bound=Form) + + @pytest.fixture -def app(posts, response_model, query_model, body_model, post_model, form_model): +def app( + posts: List[Dict[str, Any]], + response_model: Type[Response[PostModelT]], + query_model: Type[QueryModelT], + body_model: Type[BodyModelT], + post_model: Type[PostModelT], + form_model: Type[FormModelT], +) -> Flask: app = Flask("test_app") app.config["DEBUG"] = True app.config["TESTING"] = True @@ -88,7 +112,7 @@ def app(posts, response_model, query_model, body_model, post_model, form_model): @validate(query=query_model, body=body_model) def post(): query_params = request.query_params - body = request.body_params + body: BodyModelT = request.body_params results = [ post_model(**p) for p in posts @@ -98,7 +122,7 @@ def post(): @app.route("/search/kwargs", methods=["POST"]) @validate() - def post_kwargs(query: query_model, body: body_model): + def post_kwargs(query: Type[QueryModelT], body: Type[BodyModelT]): results = [ post_model(**p) for p in posts @@ -108,7 +132,7 @@ def post_kwargs(query: query_model, body: body_model): @app.route("/search/form/kwargs", methods=["POST"]) @validate() - def post_kwargs_form(query: query_model, form: form_model): + def post_kwargs_form(query: Type[QueryModelT], form: Type[BodyModelT]): results = [ post_model(**p) for p in posts diff --git a/tests/func/test_app.py b/tests/func/test_app.py index a01edd7..8660955 100644 --- a/tests/func/test_app.py +++ b/tests/func/test_app.py @@ -97,7 +97,7 @@ class PersonBulk(BaseModel): @app.route("/root_type", methods=["POST"]) @validate() def root_type(body: PersonBulk): - return {"number": len(body)} + return {"number": len(body)} # type: ignore @pytest.fixture @@ -293,7 +293,7 @@ def test_custom_headers(client): @pytest.mark.usefixtures("app_with_custom_headers_status") -def test_custom_headers(client): +def test_custom_headers_201_status(client): response = client.get("/custom_headers_status") assert response.json == {"test": 1} assert response.status_code == 201 diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index 62b1baf..67ef111 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -1,9 +1,14 @@ -from typing import Any, List, NamedTuple, Optional, Type, Union +from typing import Any, List, NamedTuple, Optional, Type, TypeVar, Union import pytest from flask import jsonify from flask_pydantic import validate, ValidationError -from flask_pydantic.core import convert_query_params, is_iterable_of_models +from flask_pydantic.core import ( + _ensure_model_kwarg, + _get_type_generic, + convert_query_params, + is_iterable_of_models, +) from flask_pydantic.exceptions import ( InvalidIterableOfModelsException, JsonBodyParsingError, @@ -16,7 +21,7 @@ class ValidateParams(NamedTuple): body_model: Optional[Type[BaseModel]] = None query_model: Optional[Type[BaseModel]] = None form_model: Optional[Type[BaseModel]] = None - response_model: Type[BaseModel] = None + response_model: Optional[Type[BaseModel]] = None on_success_status: int = 200 request_query: ImmutableMultiDict = ImmutableMultiDict({}) request_body: Union[dict, List[dict]] = {} @@ -47,7 +52,7 @@ class RequestBodyModel(BaseModel): class FormModel(BaseModel): f1: int - f2: str = None + f2: Optional[str] = None class RequestBodyModelRoot(BaseModel): @@ -228,10 +233,11 @@ def test_validate_kwargs(self, mocker, request_ctx, parameters: ValidateParams): mock_request.form = parameters.request_form def f( - body: parameters.body_model, - query: parameters.query_model, - form: parameters.form_model, + body: parameters.body_model, # type: ignore + query: parameters.query_model, # type: ignore + form: parameters.form_model, # type: ignore ): + assert parameters.response_model is not None return parameters.response_model( **body.dict(), **query.dict(), **form.dict() ) @@ -408,6 +414,7 @@ def f() -> Any: body = mock_request.body_params.dict() if mock_request.query_params: query = mock_request.query_params.dict() + assert parameters.response_model is not None return parameters.response_model(**body, **query) response = validate( @@ -567,3 +574,65 @@ class Model(BaseModel): d: Optional[List[int]] assert convert_query_params(query_params, Model) == expected_result + + +def test_get_type_generic(): + MyGeneric = TypeVar("MyGeneric", bound=str) + MyGenericType = Type[MyGeneric] + + assert _get_type_generic(str) == str + assert _get_type_generic(MyGeneric) == str + assert _get_type_generic(MyGenericType) == str + + +def test_ensure_model_kwarg(): + class ParentModel(BaseModel): + pass + + class ChildModel(ParentModel): + pass + + def func_arg_hint_parent(body: ParentModel) -> str: + """Demonstrate less detailed hint in function type hint.""" + return "" + + assert _ensure_model_kwarg("body", ChildModel, func_arg_hint_parent) == ( + ChildModel, + True, + ), "Function has less detailed model, so `from_validate` should be chosen over it." + + def func_arg_hint_child(body: ChildModel) -> str: + """Demonstrate more detailed hint in function type hint.""" + return "" + + assert _ensure_model_kwarg("body", ParentModel, func_arg_hint_child) == ( + ChildModel, + True, + ), "Function has more detailed model, so it should be chosen over `from_validate`." + + def func_kwarg_hint_child(body: ChildModel = ChildModel()) -> str: + """Demonstrate function with kwarg instead of arg.""" + return "" + + assert _ensure_model_kwarg("body", ParentModel, func_kwarg_hint_child) == ( + ChildModel, + True, + ), "Function has more detailed model, so it should be chosen over `from_validate`." + + def func_arg_nohint(body) -> str: + """Demonstrate function with argument but without hint.""" + return "" + + assert _ensure_model_kwarg("body", ParentModel, func_arg_nohint) == ( + ParentModel, + True, + ), "There is no type hint in function but the argument itself exists." + + def func_noarg() -> str: + """Demonstrate function without the checked argument.""" + return "" + + assert _ensure_model_kwarg("body", ParentModel, func_noarg) == ( + ParentModel, + False, + ), "Function doesn't have this argument but model is declared in `from_validate`."