diff --git a/aiohttp_apispec/aiohttp_apispec.py b/aiohttp_apispec/aiohttp_apispec.py index e96e5bd..cd27266 100644 --- a/aiohttp_apispec/aiohttp_apispec.py +++ b/aiohttp_apispec/aiohttp_apispec.py @@ -1,8 +1,10 @@ import copy +import enum +import json import os from pathlib import Path -from typing import Awaitable, Callable -import json +from typing import Awaitable, Callable, Union + from aiohttp import web from aiohttp.hdrs import METH_ALL, METH_ANY from apispec import APISpec @@ -36,6 +38,14 @@ def resolver(schema): return name +class OpenApiVersion(str, enum.Enum): + V20 = "2.0" + V300 = "3.0.0" + V301 = "3.0.1" + V302 = "3.0.2" + V303 = "3.0.3" + + class AiohttpApiSpec: def __init__( self, @@ -48,11 +58,23 @@ def __init__( in_place=False, prefix='', schema_name_resolver=resolver, + openapi_version=None, **kwargs, ): + openapi_version = openapi_version or OpenApiVersion.V20 + try: + openapi_version = OpenApiVersion(openapi_version) + except ValueError: + raise ValueError( + f"Invalid `openapi_version`: {openapi_version!r}" + ) from None self.plugin = MarshmallowPlugin(schema_name_resolver=schema_name_resolver) - self.spec = APISpec(plugins=(self.plugin,), openapi_version="2.0", **kwargs) + self.spec = APISpec( + plugins=(self.plugin,), + openapi_version=openapi_version.value, + **kwargs, + ) self.url = url self.swagger_path = swagger_path @@ -200,7 +222,14 @@ def _update_paths(self, data: dict, method: str, url_path: str): for k, v in raw_parameters.items() if k in VALID_RESPONSE_FIELDS } - updated_params['schema'] = actual_params["schema"] + if self.spec.components.openapi_version.major < 3: + updated_params['schema'] = actual_params["schema"] + else: + updated_params["content"] = { + "application/json": { + "schema": actual_params["schema"], + }, + } for extra_info in ("description", "headers", "examples"): if extra_info in actual_params: updated_params[extra_info] = actual_params[extra_info] @@ -247,6 +276,7 @@ def setup_aiohttp_apispec( in_place: bool = False, prefix: str = '', schema_name_resolver: Callable = resolver, + openapi_version: Union[str, OpenApiVersion] = OpenApiVersion.V20, **kwargs, ) -> AiohttpApiSpec: """ @@ -304,6 +334,7 @@ async def index(request): If True, be sure all routes are added to router :param prefix: prefix to add to all registered routes :param schema_name_resolver: custom schema_name_resolver for MarshmallowPlugin. + :param openapi_version: version of OpenAPI schema :param kwargs: any apispec.APISpec kwargs :return: return instance of AiohttpApiSpec class :rtype: AiohttpApiSpec @@ -320,5 +351,6 @@ async def index(request): in_place=in_place, prefix=prefix, schema_name_resolver=schema_name_resolver, + openapi_version=openapi_version, **kwargs, )