diff --git a/aiohttp_apispec/middlewares.py b/aiohttp_apispec/middlewares.py index 98d5d1d..841fa87 100644 --- a/aiohttp_apispec/middlewares.py +++ b/aiohttp_apispec/middlewares.py @@ -26,8 +26,10 @@ async def validation_middleware(request: web.Request, handler) -> web.Response: if not hasattr(sub_handler, "__schemas__"): return await handler(request) schemas = sub_handler.__schemas__ + response_info = sub_handler.__apispec__["responses"] else: schemas = orig_handler.__schemas__ + response_info = orig_handler.__apispec__["responses"] result = {} for schema in schemas: data = await request.app["_apispec_parser"].parse( @@ -42,4 +44,8 @@ async def validation_middleware(request: web.Request, handler) -> web.Response: result = data break request[request.app["_apispec_request_data_name"]] = result - return await handler(request) + response = await handler(request) + response_schema = response_info.get(str(response.status)) + if response_schema: + response_schema["schema"].loads(response.text) + return response diff --git a/tests/conftest.py b/tests/conftest.py index 209529a..4e1ce07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import pytest from aiohttp import web -from marshmallow import Schema, fields +from marshmallow import EXCLUDE, Schema, fields from aiohttp_apispec import ( docs, @@ -51,6 +51,11 @@ class ResponseSchema(Schema): data = fields.Dict() +class ResponseSchemaIgnoreExtra(ResponseSchema): + class Meta: + unknown = EXCLUDE + + class MyException(Exception): def __init__(self, message): self.message = message @@ -127,6 +132,16 @@ async def handler_get_echo(request): async def handler_get_variable(request): return web.json_response(request["data"]) + @request_schema(RequestSchema) + @response_schema(ResponseSchema, 200) + async def handler_post_extra_field(request): + return web.json_response(request["data"]) + + @request_schema(RequestSchema) + @response_schema(ResponseSchemaIgnoreExtra, 200) + async def handler_post_ignore_extra_field(request): + return web.json_response(request["data"]) + class ViewClass(web.View): @docs( tags=["mytag"], @@ -194,6 +209,8 @@ async def validated_view(request: web.Request): web.post("/echo", handler_post_echo), web.get("/variable/{var}", handler_get_variable), web.post("/validate/{uuid}", validated_view), + web.post("/extra_field", handler_post_extra_field), + web.post("/ignore_extra_field", handler_post_ignore_extra_field), ] ) v1.middlewares.extend([intercept_error, validation_middleware]) @@ -219,6 +236,8 @@ async def validated_view(request: web.Request): web.post("/v1/echo", handler_post_echo), web.get("/v1/variable/{var}", handler_get_variable), web.post("/v1/validate/{uuid}", validated_view), + web.post("/v1/extra_field", handler_post_extra_field), + web.post("/v1/ignore_extra_field", handler_post_ignore_extra_field), ] ) app.middlewares.extend([intercept_error, validation_middleware]) diff --git a/tests/test_documentation.py b/tests/test_documentation.py index 9c3f841..aec781a 100644 --- a/tests/test_documentation.py +++ b/tests/test_documentation.py @@ -154,6 +154,10 @@ async def test_app_swagger_json(aiohttp_app, example_for_request_schema): "properties": {"data": {"type": "object"}, "msg": {"type": "string"}}, "type": "object", }, + "ResponseSchemaIgnoreExtra": { + "properties": {"data": {"type": "object"}, "msg": {"type": "string"}}, + "type": "object", + }, }, sort_keys=True, ) diff --git a/tests/test_web_app.py b/tests/test_web_app.py index 470569b..d30ceaf 100644 --- a/tests/test_web_app.py +++ b/tests/test_web_app.py @@ -162,3 +162,14 @@ async def test_swagger_path(aiohttp_app): async def test_swagger_static(aiohttp_app): assert (await aiohttp_app.get("/static/swagger/swagger-ui.css")).status == 200 \ or (await aiohttp_app.get("/v1/static/swagger/swagger-ui.css")).status == 200 + + +async def test_response_extra_fields(aiohttp_app): + res = await aiohttp_app.post("/v1/extra_field", json={"id": 1, "foo": 2}) + assert res.status == 500 + + +async def test_response_ignore_extra_fields(aiohttp_app): + res = await aiohttp_app.post("/v1/ignore_extra_field", json={"id": 1, "foo": 2}) + assert res.status == 200 + assert (await res.json()) == {"id": 1}