From 687bda3001ff2922aa2431bd1b9101e01abf5cb3 Mon Sep 17 00:00:00 2001 From: kushagra Date: Sun, 24 Dec 2023 16:16:31 +0530 Subject: [PATCH 1/3] Added flask app wrapper to create custom config for foca --- foca/api/register_openapi.py | 2 +- foca/config/config_parser.py | 10 +++--- foca/errors/exceptions.py | 3 +- foca/factories/celery_app.py | 8 ++--- foca/factories/connexion_app.py | 4 ++- foca/factories/flask_app.py | 35 +++++++++++++++++++ foca/models/config.py | 7 ++-- .../access_control/access_control_server.py | 35 ++++++++----------- .../foca_casbin_adapter/adapter.py | 6 ++-- foca/security/auth.py | 13 ++++--- tests/errors/test_errors.py | 13 ++++--- tests/factories/test_flask_app.py | 13 +++++++ tests/mock_data.py | 2 +- .../foca_casbin_adapter/test_adapter.py | 2 +- tests/utils/test_misc.py | 4 +-- 15 files changed, 108 insertions(+), 49 deletions(-) create mode 100644 foca/factories/flask_app.py create mode 100644 tests/factories/test_flask_app.py diff --git a/foca/api/register_openapi.py b/foca/api/register_openapi.py index 663c8a43..dfb7caae 100644 --- a/foca/api/register_openapi.py +++ b/foca/api/register_openapi.py @@ -68,7 +68,7 @@ def register_openapi( # OpenAPI 3 sec_schemes = spec_parsed.get( 'components', {'securitySchemes': {}} - ).get('securitySchemes', {}) # type: ignore + ).get('securitySchemes', {}) for sec_scheme in sec_schemes.values(): sec_scheme[key] = val logger.debug(f"Added security fields: {spec.add_security_fields}") diff --git a/foca/config/config_parser.py b/foca/config/config_parser.py index 10a1ec43..c4d2dc85 100644 --- a/foca/config/config_parser.py +++ b/foca/config/config_parser.py @@ -4,7 +4,7 @@ import logging from logging.config import dictConfig from pathlib import Path -from typing import (Dict, Optional) +from typing import (Dict, Optional, Callable) from addict import Dict as Addict from pydantic import BaseModel @@ -157,7 +157,10 @@ def parse_custom_config(self, model: str) -> BaseModel: module = Path(model).stem model_class = Path(model).suffix[1:] try: - model_class = getattr(import_module(module), model_class) + model_class_instance: Callable = getattr( + import_module(module), + model_class + ) except ModuleNotFoundError: raise ValueError( f"failed validating custom configuration: module '{module}' " @@ -169,8 +172,7 @@ def parse_custom_config(self, model: str) -> BaseModel: f"has no class {model_class} or could not be imported" ) try: - custom_config = model_class( # type: ignore[operator] - **self.config.custom) # type: ignore[attr-defined] + custom_config = model_class_instance(**self.config.custom) except Exception as exc: raise ValueError( "failed validating custom configuration: provided custom " diff --git a/foca/errors/exceptions.py b/foca/errors/exceptions.py index 0a62223e..e831d2e6 100644 --- a/foca/errors/exceptions.py +++ b/foca/errors/exceptions.py @@ -211,7 +211,8 @@ def _problem_handler_json(exception: Exception) -> Response: JSON-formatted error response. """ # Look up exception & get status code - conf = current_app.config.foca.exceptions # type: ignore[attr-defined] + foca_conf = getattr(current_app.config, 'foca') + conf = foca_conf.exceptions exc = type(exception) if exc not in conf.mapping: exc = Exception diff --git a/foca/factories/celery_app.py b/foca/factories/celery_app.py index 271783ef..26abd19f 100644 --- a/foca/factories/celery_app.py +++ b/foca/factories/celery_app.py @@ -3,14 +3,14 @@ from inspect import stack import logging -from flask import Flask +from connexion import FlaskApp from celery import Celery # Get logger instance logger = logging.getLogger(__name__) -def create_celery_app(app: Flask) -> Celery: +def create_celery_app(app: FlaskApp) -> Celery: """Create and configure Celery application instance. Args: @@ -19,7 +19,7 @@ def create_celery_app(app: Flask) -> Celery: Returns: Celery application instance. """ - conf = app.config.foca.jobs # type: ignore[attr-defined] + conf = app.config.foca.jobs # Instantiate Celery app celery = Celery( @@ -32,7 +32,7 @@ def create_celery_app(app: Flask) -> Celery: logger.debug(f"Celery app created from '{calling_module}'.") # Update Celery app configuration with Flask app configuration - setattr(celery.conf, 'foca', app.config.foca) # type: ignore[attr-defined] + setattr(celery.conf, 'foca', app.config.foca) logger.debug('Celery app configured.') class ContextTask(celery.Task): # type: ignore diff --git a/foca/factories/connexion_app.py b/foca/factories/connexion_app.py index b39b7f6c..04881ffe 100644 --- a/foca/factories/connexion_app.py +++ b/foca/factories/connexion_app.py @@ -7,6 +7,7 @@ from connexion import App from foca.models.config import Config +from foca.factories.flask_app import create_flask_app # Get logger instance logger = logging.getLogger(__name__) @@ -26,6 +27,7 @@ def create_connexion_app(config: Optional[Config] = None) -> App: __name__, skip_error_handlers=True, ) + app.app = create_flask_app() calling_module = ':'.join([stack()[1].filename, stack()[1].function]) logger.debug(f"Connexion app created from '{calling_module}'.") @@ -71,7 +73,7 @@ def __add_config_to_connexion_app( logger.debug('* {}: {}'.format(key, value)) # Add user configuration to Flask app config - setattr(app.app.config, 'foca', config) + app.app.config.foca = config logger.debug('Connexion app configured.') return app diff --git a/foca/factories/flask_app.py b/foca/factories/flask_app.py new file mode 100644 index 00000000..04505658 --- /dev/null +++ b/foca/factories/flask_app.py @@ -0,0 +1,35 @@ +"""Factory for creating and configuring Connexion application instances.""" + +from flask import Config, Flask +from inspect import stack +import logging +from typing import Optional + +from foca.models.config import Config as FocaConfig + +# Get logger instance +logger = logging.getLogger(__name__) + + +class FocaFlaskAppConfig(Config): + """Custom config class wrapper to include foca as an attribute + within config. + """ + foca: Optional[FocaConfig] + + +def create_flask_app() -> Flask: + """Create and configure Flask application instance for connexion + context. + + Returns: + Flask application with custom foca config configured. + """ + + flask_app = Flask(__name__) + flask_app.config.from_object(FocaFlaskAppConfig) + + calling_module = ':'.join([stack()[1].filename, stack()[1].function]) + logger.debug(f"Flask app created from '{calling_module}'.") + + return flask_app diff --git a/foca/models/config.py b/foca/models/config.py index 6b41bfa9..9bd0d41b 100644 --- a/foca/models/config.py +++ b/foca/models/config.py @@ -52,7 +52,7 @@ def _get_by_path( Returns: Value of innermost key. """ - return reduce(operator.getitem, key_sequence, obj) # type: ignore + return reduce(operator.getitem, key_sequence, obj) class ExceptionLoggingEnum(Enum): @@ -1200,6 +1200,7 @@ class Config(FOCABaseConfig): db: Database config parameters. jobs: Background job config parameters. log: Logger config parameters. + custom: Custom config parameters. (Added by consumers) Attributes: server: Server config parameters. @@ -1209,6 +1210,7 @@ class Config(FOCABaseConfig): db: Database config parameters. jobs: Background job config parameters. log: Logger config parameters. + custom: Custom config parameters. (Added by consumers) Raises: pydantic.ValidationError: The class was instantianted with an illegal @@ -1248,7 +1250,7 @@ class 'werkzeug.exceptions.BadGateway'>: {'title': 'Bad Gateway', 'status': 50\ time}: {levelname:<8}] {message} [{name}]')}, handlers={'console': LogHandlerC\ onfig(class_handler='logging.StreamHandler', level=20, formatter='standard', s\ tream='ext://sys.stderr')}, root=LogRootConfig(level=10, handlers=['console'])\ -)) +custom=None)) """ server: ServerConfig = ServerConfig() exceptions: ExceptionConfig = ExceptionConfig() @@ -1257,6 +1259,7 @@ class 'werkzeug.exceptions.BadGateway'>: {'title': 'Bad Gateway', 'status': 50\ db: Optional[MongoConfig] = None jobs: Optional[JobsConfig] = None log: LogConfig = LogConfig() + custom: Any = None class Config: """Configuration for Pydantic model class.""" diff --git a/foca/security/access_control/access_control_server.py b/foca/security/access_control/access_control_server.py index 6521c017..686440d4 100644 --- a/foca/security/access_control/access_control_server.py +++ b/foca/security/access_control/access_control_server.py @@ -10,6 +10,7 @@ from foca.utils.logging import log_traffic from foca.errors.exceptions import BadRequest +from foca.models.config import AccessControlConfig logger = logging.getLogger(__name__) @@ -62,16 +63,13 @@ def putPermission( """ request_json = request.json if isinstance(request_json, dict): - app_config = current_app.config + foca_conf = getattr(current_app.config, 'foca') try: - security_conf = \ - app_config.foca.security # type: ignore[attr-defined] - access_control_config = \ - security_conf.access_control # type: ignore[attr-defined] + access_control_conf: AccessControlConfig = \ + foca_conf.security.access_control db_coll_permission: Collection = ( - app_config.foca.db.dbs[ # type: ignore[attr-defined] - access_control_config.db_name] - .collections[access_control_config.collection_name].client + foca_conf.db.dbs[access_control_conf.db_name] + .collections[access_control_conf.collection_name].client ) permission_data = request_json.get("rule", {}) @@ -102,11 +100,10 @@ def getAllPermissions(limit=None) -> List[Dict]: Returns: List of permission dicts. """ - app_config = current_app.config - access_control_config = \ - app_config.foca.security.access_control # type: ignore[attr-defined] + foca_conf = getattr(current_app.config, 'foca') + access_control_config = foca_conf.security.access_control db_coll_permission: Collection = ( - app_config.foca.db.dbs[ # type: ignore[attr-defined] + foca_conf.db.dbs[ access_control_config.db_name ].collections[access_control_config.collection_name].client ) @@ -145,11 +142,10 @@ def getPermission( Returns: Permission data for the given id. """ - app_config = current_app.config - access_control_config = \ - app_config.foca.security.access_control # type: ignore[attr-defined] + foca_conf = getattr(current_app.config, 'foca') + access_control_config = foca_conf.security.access_control db_coll_permission: Collection = ( - app_config.foca.db.dbs[ # type: ignore[attr-defined] + foca_conf.db.dbs[ access_control_config.db_name ].collections[access_control_config.collection_name].client ) @@ -181,11 +177,10 @@ def deletePermission( Returns: Delete permission identifier. """ - app_config = current_app.config - access_control_config = \ - app_config.foca.security.access_control # type: ignore[attr-defined] + foca_conf = getattr(current_app.config, 'foca') + access_control_config = foca_conf.security.access_control db_coll_permission: Collection = ( - app_config.foca.db.dbs[ # type: ignore[attr-defined] + foca_conf.db.dbs[ access_control_config.db_name ].collections[access_control_config.collection_name].client ) diff --git a/foca/security/access_control/foca_casbin_adapter/adapter.py b/foca/security/access_control/foca_casbin_adapter/adapter.py index 76bcf3ca..00eb1a3c 100644 --- a/foca/security/access_control/foca_casbin_adapter/adapter.py +++ b/foca/security/access_control/foca_casbin_adapter/adapter.py @@ -2,7 +2,7 @@ from casbin import persist from casbin.model import Model -from typing import (List, Optional) +from typing import (Any, List, Optional) from pymongo import MongoClient from foca.security.access_control.foca_casbin_adapter.casbin_rule import ( @@ -170,10 +170,10 @@ def remove_filtered_policy( if not (1 <= field_index + len(field_values) <= 6): return False - query = {} + query: dict[str, Any] = {} for index, value in enumerate(field_values): query[f"v{index + field_index}"] = value - query["ptype"] = ptype # type: ignore[assignment] + query["ptype"] = ptype results = self._collection.delete_many(query) return results.deleted_count > 0 diff --git a/foca/security/auth.py b/foca/security/auth.py index 5d389e41..11994fe9 100644 --- a/foca/security/auth.py +++ b/foca/security/auth.py @@ -2,11 +2,11 @@ from connexion.exceptions import Unauthorized import logging -from typing import (Dict, Iterable, List, Optional) +from typing import (Any, cast, Dict, Iterable, List, Optional) from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey -from flask import current_app, request +from flask import (current_app, request) import jwt from jwt.exceptions import InvalidKeyError import requests @@ -14,6 +14,8 @@ import json from werkzeug.datastructures import ImmutableMultiDict +from foca.models.config import Config + # Get logger instance logger = logging.getLogger(__name__) @@ -36,7 +38,8 @@ def validate_token(token: str) -> Dict: oidc_config_claim_public_keys: str = 'jwks_uri' # Fetch security parameters - conf = current_app.config.foca.security.auth # type: ignore[attr-defined] + foca_conf = cast(Config, getattr(current_app.config, 'foca')) + conf = foca_conf.security.auth add_key_to_claims: bool = conf.add_key_to_claims allow_expired: bool = conf.allow_expired audience: Optional[Iterable[str]] = conf.audience @@ -245,10 +248,10 @@ def _validate_jwt_public_key( validation_options['verify_exp'] = False # Try public keys one after the other - used_key: Dict = {} + used_key: Any = {} claims = {} for key in public_keys.values(): - used_key = key # type: ignore[assignment] + used_key = key # Decode JWT and validate via public key try: diff --git a/tests/errors/test_errors.py b/tests/errors/test_errors.py index 93eb91dd..6e3887cf 100644 --- a/tests/errors/test_errors.py +++ b/tests/errors/test_errors.py @@ -18,6 +18,7 @@ _subset_nested_dict, ) from foca.models.config import Config +from foca.factories.flask_app import FocaFlaskAppConfig EXCEPTION_INSTANCE = Exception() INVALID_LOG_FORMAT = 'unknown_log_format' @@ -103,7 +104,8 @@ def test__exclude_key_nested_dict(): def test__problem_handler_json(): """Test problem handler with instance of custom, unlisted error.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + app.config.from_object(FocaFlaskAppConfig) + app.config.foca = Config() EXPECTED_RESPONSE = app.config.foca.exceptions.mapping[Exception] with app.app_context(): res = _problem_handler_json(UnknownException()) @@ -117,7 +119,8 @@ def test__problem_handler_json(): def test__problem_handler_json_no_fallback_exception(): """Test problem handler; unlisted error without fallback.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + app.config.from_object(FocaFlaskAppConfig) + app.config.foca = Config() del app.config.foca.exceptions.mapping[Exception] with app.app_context(): res = _problem_handler_json(UnknownException()) @@ -131,7 +134,8 @@ def test__problem_handler_json_no_fallback_exception(): def test__problem_handler_json_with_public_members(): """Test problem handler with public members.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + app.config.from_object(FocaFlaskAppConfig) + app.config.foca = Config() app.config.foca.exceptions.public_members = PUBLIC_MEMBERS with app.app_context(): res = _problem_handler_json(UnknownException()) @@ -143,7 +147,8 @@ def test__problem_handler_json_with_public_members(): def test__problem_handler_json_with_private_members(): """Test problem handler with private members.""" app = Flask(__name__) - setattr(app.config, 'foca', Config()) + app.config.from_object(FocaFlaskAppConfig) + app.config.foca = Config() app.config.foca.exceptions.private_members = PRIVATE_MEMBERS with app.app_context(): res = _problem_handler_json(UnknownException()) diff --git a/tests/factories/test_flask_app.py b/tests/factories/test_flask_app.py new file mode 100644 index 00000000..45d4233a --- /dev/null +++ b/tests/factories/test_flask_app.py @@ -0,0 +1,13 @@ +"""Tests for foca.factories.flask_app.""" +from flask import Flask + +from foca.factories.flask_app import create_flask_app +from foca.models.config import Config + + +def test_create_flask_app(): + """Test Connexion app creation without config.""" + flask_app = create_flask_app() + assert isinstance(flask_app, Flask) + flask_app.config.foca = Config() + assert isinstance(flask_app.config.foca, Config) diff --git a/tests/mock_data.py b/tests/mock_data.py index 4b830286..c6749bb2 100644 --- a/tests/mock_data.py +++ b/tests/mock_data.py @@ -15,7 +15,7 @@ } MONGO_CONFIG = { "host": "mongodb", - "port": 12345, + "port": 27017, "dbs": { "access_control_db": DB_CONFIG, }, diff --git a/tests/security/access_control/foca_casbin_adapter/test_adapter.py b/tests/security/access_control/foca_casbin_adapter/test_adapter.py index 4bb4827f..d38027ad 100644 --- a/tests/security/access_control/foca_casbin_adapter/test_adapter.py +++ b/tests/security/access_control/foca_casbin_adapter/test_adapter.py @@ -35,7 +35,7 @@ class TestAdapter(TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.db_name = "casbin_test" - self.db_port = 12345 + self.db_port = 27017 def setUp(self): self.clear_db() diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py index a5f9a40e..fc2c029c 100644 --- a/tests/utils/test_misc.py +++ b/tests/utils/test_misc.py @@ -49,7 +49,7 @@ def test_evaluation_error(self): """Evaulation of `length` raises an exception.""" charset = int with pytest.raises(TypeError): - generate_id(charset=charset) # type: ignore + generate_id(charset=charset) def test_length(self): """Non-default argument to `length`.""" @@ -61,7 +61,7 @@ def test_length_not_int(self): """Argument to `length` is not an integer.""" length = "" with pytest.raises(TypeError): - generate_id(length=length) # type: ignore + generate_id(length=length) def test_length_not_positive(self): """Argument to `length` is not a positive integer.""" From 05cbb2363c774ba223d03afdd83c44c859ecf483 Mon Sep 17 00:00:00 2001 From: kushagra Date: Sun, 24 Dec 2023 16:54:49 +0530 Subject: [PATCH 2/3] Fix dict issue mypy --- foca/security/access_control/foca_casbin_adapter/adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/foca/security/access_control/foca_casbin_adapter/adapter.py b/foca/security/access_control/foca_casbin_adapter/adapter.py index 00eb1a3c..96db470a 100644 --- a/foca/security/access_control/foca_casbin_adapter/adapter.py +++ b/foca/security/access_control/foca_casbin_adapter/adapter.py @@ -2,7 +2,7 @@ from casbin import persist from casbin.model import Model -from typing import (Any, List, Optional) +from typing import (Any, Dict, List, Optional) from pymongo import MongoClient from foca.security.access_control.foca_casbin_adapter.casbin_rule import ( @@ -170,7 +170,7 @@ def remove_filtered_policy( if not (1 <= field_index + len(field_values) <= 6): return False - query: dict[str, Any] = {} + query: Dict[str, Any] = {} for index, value in enumerate(field_values): query[f"v{index + field_index}"] = value From 70bae415cc2b6df369dcfb57cee83947a4076c36 Mon Sep 17 00:00:00 2001 From: kushagra Date: Sun, 24 Dec 2023 17:00:20 +0530 Subject: [PATCH 3/3] Revert db ports altered for testing --- tests/mock_data.py | 2 +- .../security/access_control/foca_casbin_adapter/test_adapter.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/mock_data.py b/tests/mock_data.py index c6749bb2..4b830286 100644 --- a/tests/mock_data.py +++ b/tests/mock_data.py @@ -15,7 +15,7 @@ } MONGO_CONFIG = { "host": "mongodb", - "port": 27017, + "port": 12345, "dbs": { "access_control_db": DB_CONFIG, }, diff --git a/tests/security/access_control/foca_casbin_adapter/test_adapter.py b/tests/security/access_control/foca_casbin_adapter/test_adapter.py index d38027ad..4bb4827f 100644 --- a/tests/security/access_control/foca_casbin_adapter/test_adapter.py +++ b/tests/security/access_control/foca_casbin_adapter/test_adapter.py @@ -35,7 +35,7 @@ class TestAdapter(TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.db_name = "casbin_test" - self.db_port = 27017 + self.db_port = 12345 def setUp(self): self.clear_db()