diff --git a/AUTHORS.rst b/AUTHORS.rst index 754e6b8d..ca42b1c7 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -25,3 +25,4 @@ Contributors (chronological) - Choudhury Noor `@Cnoor0171 `_ - Dmitry Erlikh `@derlikh-smart `_ - 0x78f1935 `@0x78f1935 `_ +- Cory Laughlin `@Aesonus `_ diff --git a/flask_smorest/__init__.py b/flask_smorest/__init__.py index 317592e4..4f3104dd 100644 --- a/flask_smorest/__init__.py +++ b/flask_smorest/__init__.py @@ -88,6 +88,9 @@ def register_blueprint(self, blp, *, parameters=None, **options): self._app.register_blueprint(blp, **options) + for bp_plugin in getattr(blp, "_smore_plugins", []): + bp_plugin.visit_api(self) + # Register views in API documentation for this resource blp.register_views_in_doc( self, diff --git a/flask_smorest/blueprint.py b/flask_smorest/blueprint.py index c8c001aa..a256dbd4 100644 --- a/flask_smorest/blueprint.py +++ b/flask_smorest/blueprint.py @@ -72,6 +72,9 @@ def __init__(self, *args, **kwargs): self.description = kwargs.pop("description", "") + # This is where smore plugins are stored + self._smore_plugins = kwargs.pop("smore_plugins", []) + super().__init__(*args, **kwargs) # _docs stores information used at init time to produce documentation. @@ -97,6 +100,7 @@ def __init__(self, *args, **kwargs): self._prepare_response_doc, self._prepare_pagination_doc, self._prepare_etag_doc, + *[plugin.register_method_docs for plugin in self._smore_plugins], ] def add_url_rule( diff --git a/flask_smorest/plugin/__init__.py b/flask_smorest/plugin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/flask_smorest/plugin/abc.py b/flask_smorest/plugin/abc.py new file mode 100644 index 00000000..3de89c58 --- /dev/null +++ b/flask_smorest/plugin/abc.py @@ -0,0 +1,27 @@ +import abc + + +class Plugin(abc.ABC): + """Abstract base class to structure smore plugins""" + + @abc.abstractmethod + def register_method_docs(self, doc, doc_info, *, api, spec, **kwargs): + """ + Call when the views are registered in doc + + :param dict doc: The current operation doc + :param dict doc_info: Doc info stored by decorators + :param Api api: The Api() instance + :param APISpec spec: The APISpec() instance + """ + + @abc.abstractmethod + def visit_api(self, api, **kwargs): + """ + Visit the api + + This should be used to register toplevel objects on the spec + + :param api: The APISpec() instance + """ + pass diff --git a/flask_smorest/plugin/built_in.py b/flask_smorest/plugin/built_in.py new file mode 100644 index 00000000..1724109b --- /dev/null +++ b/flask_smorest/plugin/built_in.py @@ -0,0 +1,33 @@ +from flask_smorest import Api + +from ..utils import deepupdate +from . import abc + + +class APIKeySecurityPlugin(abc.Plugin): + def __init__(self, schema_name, parameter_name, in_="header") -> None: + self._schema_name = schema_name + self._parameter_name = parameter_name + self._in = in_ + + def security(self, keys): + def decorator(func): + func._apidoc = deepupdate( + getattr(func, "_apidoc", {}), {"security": [{key: [] for key in keys}]} + ) + return func + + return decorator + + def register_method_docs(self, doc, doc_info, *, api, spec, **kwargs): + # No need to attempt to add "security" to doc if it is not in doc_info + if "security" in doc_info: + doc = deepupdate(doc, {"security": doc_info["security"]}) + return doc + + def visit_api(self, api: Api, **kwargs) -> None: + """Visits the api and registers security objects""" + api.spec.components.security_scheme( + self._schema_name, + {"type": "apiKey", "in": self._in, "name": self._parameter_name}, + ) diff --git a/tests/test_plugin.py b/tests/test_plugin.py new file mode 100644 index 00000000..6e7430f5 --- /dev/null +++ b/tests/test_plugin.py @@ -0,0 +1,39 @@ +import pytest + +from flask_smorest import Api +from flask_smorest.blueprint import Blueprint +from flask_smorest.plugin.built_in import APIKeySecurityPlugin + + +class TestSecurityPlugin: + @pytest.fixture + def security_plugin(self): + return APIKeySecurityPlugin("testApiKey", "X-API-Key") + + @pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2")) + def test_spec_contains_security_requirement( + self, app, security_plugin, openapi_version + ): + app.config["OPENAPI_VERSION"] = openapi_version + api = Api(app) + + blp = Blueprint( + "test", __name__, url_prefix="/test", smore_plugins=[security_plugin] + ) + + @blp.route("/") + @security_plugin.security(["testApiKey"]) + def func(): + """Dummy view func""" + + api.register_blueprint(blp) + + spec = api.spec.to_dict() + assert spec["paths"]["/test/"]["get"]["security"] == [{"testApiKey": []}] + if openapi_version == "3.0.2": + security_schemes = spec["components"]["securitySchemes"] + else: # Version 2.0 + security_schemes = spec["securityDefinitions"] + assert security_schemes == { + "testApiKey": {"type": "apiKey", "in": "header", "name": "X-API-Key"} + }