Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for OpenAPI 3.0 specification. #123

Merged
merged 2 commits into from
Jan 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions aiohttp_apispec/aiohttp_apispec.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import copy
import enum
import json
import os
from pathlib import Path
from typing import Awaitable, Callable
import json
from typing import Awaitable, Callable, Union

from aiohttp import web
from aiohttp.hdrs import METH_ALL, METH_ANY
from apispec import APISpec
Expand Down Expand Up @@ -36,6 +38,14 @@ def resolver(schema):
return name


class OpenApiVersion(str, enum.Enum):
V20 = "2.0"
V300 = "3.0.0"
V301 = "3.0.1"
V302 = "3.0.2"
V303 = "3.0.3"


class AiohttpApiSpec:
def __init__(
self,
Expand All @@ -48,11 +58,23 @@ def __init__(
in_place=False,
prefix='',
schema_name_resolver=resolver,
openapi_version=None,
**kwargs,
):
openapi_version = openapi_version or OpenApiVersion.V20
try:
openapi_version = OpenApiVersion(openapi_version)
except ValueError:
raise ValueError(
f"Invalid `openapi_version`: {openapi_version!r}"
) from None

self.plugin = MarshmallowPlugin(schema_name_resolver=schema_name_resolver)
self.spec = APISpec(plugins=(self.plugin,), openapi_version="2.0", **kwargs)
self.spec = APISpec(
plugins=(self.plugin,),
openapi_version=openapi_version.value,
**kwargs,
)

self.url = url
self.swagger_path = swagger_path
Expand Down Expand Up @@ -200,7 +222,14 @@ def _update_paths(self, data: dict, method: str, url_path: str):
for k, v in raw_parameters.items()
if k in VALID_RESPONSE_FIELDS
}
updated_params['schema'] = actual_params["schema"]
if self.spec.components.openapi_version.major < 3:
updated_params['schema'] = actual_params["schema"]
else:
updated_params["content"] = {
"application/json": {
"schema": actual_params["schema"],
},
}
for extra_info in ("description", "headers", "examples"):
if extra_info in actual_params:
updated_params[extra_info] = actual_params[extra_info]
Expand Down Expand Up @@ -247,6 +276,7 @@ def setup_aiohttp_apispec(
in_place: bool = False,
prefix: str = '',
schema_name_resolver: Callable = resolver,
openapi_version: Union[str, OpenApiVersion] = OpenApiVersion.V20,
**kwargs,
) -> AiohttpApiSpec:
"""
Expand Down Expand Up @@ -304,6 +334,7 @@ async def index(request):
If True, be sure all routes are added to router
:param prefix: prefix to add to all registered routes
:param schema_name_resolver: custom schema_name_resolver for MarshmallowPlugin.
:param openapi_version: version of OpenAPI schema
:param kwargs: any apispec.APISpec kwargs
:return: return instance of AiohttpApiSpec class
:rtype: AiohttpApiSpec
Expand All @@ -320,5 +351,6 @@ async def index(request):
in_place=in_place,
prefix=prefix,
schema_name_resolver=schema_name_resolver,
openapi_version=openapi_version,
**kwargs,
)