From c7e645fd2d7e6d7367d961443817508b3c9e1dd8 Mon Sep 17 00:00:00 2001 From: Adrian Galvan Date: Tue, 19 Nov 2024 10:30:44 -0800 Subject: [PATCH] Adding retries to new database task sessions (#5448) --- CHANGELOG.md | 1 + src/fides/api/api/deps.py | 3 ++ src/fides/api/db/session.py | 42 ++++++++++++--- src/fides/api/tasks/__init__.py | 34 +++++++++++- src/fides/config/database_settings.py | 24 +++++++++ tests/ops/tasks/test_celery.py | 37 +------------ tests/ops/tasks/test_database_task.py | 78 +++++++++++++++++++++++++++ 7 files changed, 173 insertions(+), 46 deletions(-) create mode 100644 tests/ops/tasks/test_database_task.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c02ab5679..ecbc608f1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ The types of changes are: ### Changed - Allow hiding systems via a `hidden` parameter and add two flags on the `/system` api endpoint; `show_hidden` and `dnd_relevant`, to display only systems with integrations [#5484](https://github.com/ethyca/fides/pull/5484) - Updated POST taxonomy endpoints to handle creating resources without specifying fides_key [#5468](https://github.com/ethyca/fides/pull/5468) +- Disabled connection pooling for task workers and added retries and keep-alive configurations for database connections [#5448](https://github.com/ethyca/fides/pull/5448) ### Developer Experience - Fixing BigQuery integration tests [#5491](https://github.com/ethyca/fides/pull/5491) diff --git a/src/fides/api/api/deps.py b/src/fides/api/api/deps.py index 16867b17d3..0bea65a248 100644 --- a/src/fides/api/api/deps.py +++ b/src/fides/api/api/deps.py @@ -46,6 +46,9 @@ def get_api_session() -> Session: config=CONFIG, pool_size=CONFIG.database.api_engine_pool_size, max_overflow=CONFIG.database.api_engine_max_overflow, + keepalives_idle=CONFIG.database.api_engine_keepalives_idle, + keepalives_interval=CONFIG.database.api_engine_keepalives_interval, + keepalives_count=CONFIG.database.api_engine_keepalives_count, ) 
SessionLocal = get_db_session(CONFIG, engine=_engine) db = SessionLocal() diff --git a/src/fides/api/db/session.py b/src/fides/api/db/session.py index 0b3f09700d..de96141286 100644 --- a/src/fides/api/db/session.py +++ b/src/fides/api/db/session.py @@ -1,10 +1,13 @@ from __future__ import annotations +from typing import Any, Dict + from loguru import logger from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import NullPool from fides.api.common_exceptions import MissingConfig from fides.api.db.util import custom_json_deserializer, custom_json_serializer @@ -17,6 +20,10 @@ def get_db_engine( database_uri: str | URL | None = None, pool_size: int = 50, max_overflow: int = 50, + keepalives_idle: int | None = None, + keepalives_interval: int | None = None, + keepalives_count: int | None = None, + disable_pooling: bool = False, ) -> Engine: """Return a database engine. 
@@ -32,14 +39,33 @@ def get_db_engine( database_uri = config.database.sqlalchemy_test_database_uri else: database_uri = config.database.sqlalchemy_database_uri - return create_engine( - database_uri, - pool_pre_ping=True, - pool_size=pool_size, - max_overflow=max_overflow, - json_serializer=custom_json_serializer, - json_deserializer=custom_json_deserializer, - ) + + engine_args: Dict[str, Any] = { + "json_serializer": custom_json_serializer, + "json_deserializer": custom_json_deserializer, + } + + # keepalives settings + connect_args = {} + if keepalives_idle: + connect_args["keepalives_idle"] = keepalives_idle + if keepalives_interval: + connect_args["keepalives_interval"] = keepalives_interval + if keepalives_count: + connect_args["keepalives_count"] = keepalives_count + + if connect_args: + connect_args["keepalives"] = 1 + engine_args["connect_args"] = connect_args + + if disable_pooling: + engine_args["poolclass"] = NullPool + else: + engine_args["pool_pre_ping"] = True + engine_args["pool_size"] = pool_size + engine_args["max_overflow"] = max_overflow + + return create_engine(database_uri, **engine_args) def get_db_session( diff --git a/src/fides/api/tasks/__init__.py b/src/fides/api/tasks/__init__.py index 621c2ea0d2..e757052ac1 100644 --- a/src/fides/api/tasks/__init__.py +++ b/src/fides/api/tasks/__init__.py @@ -2,7 +2,15 @@ from celery import Celery, Task from loguru import logger +from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session +from tenacity import ( + RetryCallState, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from fides.api.db.session import get_db_engine, get_db_session from fides.api.util.logger import setup as setup_logging @@ -11,6 +19,7 @@ MESSAGING_QUEUE_NAME = "fidesops.messaging" PRIVACY_PREFERENCES_QUEUE_NAME = "fides.privacy_preferences" # This queue is used in Fidesplus for saving privacy preferences and notices served +NEW_SESSION_RETRIES = 5 autodiscover_task_locations: 
List[str] = [ "fides.api.tasks", @@ -20,10 +29,29 @@ ] +def log_retry_attempt(retry_state: RetryCallState) -> None: + """Log database connection retry attempts.""" + + logger.warning( + "Database connection attempt {} failed. Retrying in {} seconds...", + retry_state.attempt_number, + retry_state.next_action.sleep, # type: ignore[union-attr] + ) + + class DatabaseTask(Task): # pylint: disable=W0223 _task_engine = None _sessionmaker = None + # This retry will attempt to connect 5 times with an exponential backoff (2, 4, 8, 16 seconds between each attempt). + # The original error will be re-raised if all the retries are unsuccessful. All attempts are shown in the logs. + @retry( + stop=stop_after_attempt(NEW_SESSION_RETRIES), + wait=wait_exponential(multiplier=1, min=1), + retry=retry_if_exception_type(OperationalError), + reraise=True, + before_sleep=log_retry_attempt, + ) def get_new_session(self) -> ContextManager[Session]: """ Creates a new Session to be used for each task invocation. @@ -36,8 +64,10 @@ def get_new_session(self) -> ContextManager[Session]: if self._task_engine is None: self._task_engine = get_db_engine( config=CONFIG, - pool_size=CONFIG.database.task_engine_pool_size, - max_overflow=CONFIG.database.task_engine_max_overflow, + keepalives_idle=CONFIG.database.task_engine_keepalives_idle, + keepalives_interval=CONFIG.database.task_engine_keepalives_interval, + keepalives_count=CONFIG.database.task_engine_keepalives_count, + disable_pooling=True, ) # same for the sessionmaker diff --git a/src/fides/config/database_settings.py b/src/fides/config/database_settings.py index 9698050cb1..9d0b1a4258 100644 --- a/src/fides/config/database_settings.py +++ b/src/fides/config/database_settings.py @@ -31,6 +31,18 @@ class DatabaseSettings(FidesSettings): default=50, description="Number of additional 'overflow' concurrent database connections Fides will use for API requests if the pool reaches the limit. 
These overflow connections are discarded afterwards and not maintained.", ) + api_engine_keepalives_idle: int = Field( + default=30, + description="Number of seconds of inactivity before the client sends a TCP keepalive packet to verify the database connection is still alive.", + ) + api_engine_keepalives_interval: int = Field( + default=10, + description="Number of seconds between TCP keepalive retries if the initial keepalive packet receives no response. These are client-side retries.", + ) + api_engine_keepalives_count: int = Field( + default=5, + description="Maximum number of TCP keepalive retries before the client considers the connection dead and closes it.", + ) db: str = Field( default="default_db", description="The name of the application database." ) @@ -61,6 +73,18 @@ class DatabaseSettings(FidesSettings): default=50, description="Number of additional 'overflow' concurrent database connections Fides will use for executing privacy request tasks, either locally or on each worker, if the pool reaches the limit. These overflow connections are discarded afterwards and not maintained.", ) + task_engine_keepalives_idle: int = Field( + default=30, + description="Number of seconds of inactivity before the client sends a TCP keepalive packet to verify the database connection is still alive.", + ) + task_engine_keepalives_interval: int = Field( + default=10, + description="Number of seconds between TCP keepalive retries if the initial keepalive packet receives no response. These are client-side retries.", + ) + task_engine_keepalives_count: int = Field( + default=5, + description="Maximum number of TCP keepalive retries before the client considers the connection dead and closes it.", + ) test_db: str = Field( default="default_test_db", description="Used instead of the 'db' value when the FIDES_TEST_MODE environment variable is set to True. 
Avoids overwriting production data.", diff --git a/tests/ops/tasks/test_celery.py b/tests/ops/tasks/test_celery.py index eff9117fe5..8f53a026be 100644 --- a/tests/ops/tasks/test_celery.py +++ b/tests/ops/tasks/test_celery.py @@ -1,24 +1,7 @@ -# pylint: disable=protected-access -import pytest -from sqlalchemy.engine import Engine -from sqlalchemy.orm import Session -from sqlalchemy.pool import QueuePool - -from fides.api.tasks import DatabaseTask, _create_celery +from fides.api.tasks import _create_celery from fides.config import CONFIG, CelerySettings, get_config -@pytest.fixture -def mock_config_changed_db_engine_settings(): - pool_size = CONFIG.database.task_engine_pool_size - CONFIG.database.task_engine_pool_size = pool_size + 5 - max_overflow = CONFIG.database.task_engine_max_overflow - CONFIG.database.task_engine_max_overflow = max_overflow + 5 - yield - CONFIG.database.task_engine_pool_size = pool_size - CONFIG.database.task_engine_max_overflow = max_overflow - - def test_create_task(celery_session_app, celery_session_worker): @celery_session_app.task def multiply(x, y): @@ -70,21 +53,3 @@ def test_celery_config_override() -> None: celery_app = _create_celery(config=config) assert celery_app.conf["event_queue_prefix"] == "overridden_fides_worker" assert celery_app.conf["task_default_queue"] == "overridden_fides" - - -@pytest.mark.parametrize( - "config_fixture", [None, "mock_config_changed_db_engine_settings"] -) -def test_get_task_session(config_fixture, request): - if config_fixture is not None: - request.getfixturevalue( - config_fixture - ) # used to invoke config fixture if provided - pool_size = CONFIG.database.task_engine_pool_size - max_overflow = CONFIG.database.task_engine_max_overflow - t = DatabaseTask() - session: Session = t.get_new_session() - engine: Engine = session.get_bind() - pool: QueuePool = engine.pool - assert pool.size() == pool_size - assert pool._max_overflow == max_overflow diff --git a/tests/ops/tasks/test_database_task.py 
b/tests/ops/tasks/test_database_task.py new file mode 100644 index 0000000000..0076f16f84 --- /dev/null +++ b/tests/ops/tasks/test_database_task.py @@ -0,0 +1,78 @@ +# pylint: disable=protected-access + +from unittest import mock + +import pytest +from sqlalchemy.engine import Engine +from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import Session +from sqlalchemy.pool import NullPool + +from fides.api.tasks import NEW_SESSION_RETRIES, DatabaseTask +from fides.config import CONFIG + + +class TestDatabaseTask: + @pytest.fixture + def mock_config_changed_db_engine_settings(self): + pool_size = CONFIG.database.task_engine_pool_size + CONFIG.database.task_engine_pool_size = pool_size + 5 + max_overflow = CONFIG.database.task_engine_max_overflow + CONFIG.database.task_engine_max_overflow = max_overflow + 5 + yield + CONFIG.database.task_engine_pool_size = pool_size + CONFIG.database.task_engine_max_overflow = max_overflow + + @pytest.fixture + def recovering_session_maker(self): + """Fixture that fails twice then succeeds""" + mock_session = mock.Mock() + mock_maker = mock.Mock() + mock_maker.side_effect = [ + OperationalError("connection failed", None, None), + OperationalError("connection failed", None, None), + mock_session, + ] + return mock_maker, mock_session + + @pytest.fixture + def always_failing_session_maker(self): + """Fixture that always fails with OperationalError""" + mock_maker = mock.Mock() + mock_maker.side_effect = OperationalError("connection failed", None, None) + return mock_maker + + @pytest.mark.parametrize( + "config_fixture", [None, "mock_config_changed_db_engine_settings"] + ) + def test_get_task_session(self, config_fixture, request): + if config_fixture is not None: + request.getfixturevalue( + config_fixture + ) # used to invoke config fixture if provided + pool_size = CONFIG.database.task_engine_pool_size + max_overflow = CONFIG.database.task_engine_max_overflow + t = DatabaseTask() + session: Session = t.get_new_session() 
+ engine: Engine = session.get_bind() + assert isinstance(engine.pool, NullPool) + + def test_retry_on_operational_error(self, recovering_session_maker): + """Test that session creation retries on OperationalError""" + + mock_maker, mock_session = recovering_session_maker + + task = DatabaseTask() + with mock.patch.object(task, "_sessionmaker", mock_maker): + session = task.get_new_session() + assert session == mock_session + assert mock_maker.call_count == 3 + + def test_max_retries_exceeded(self, always_failing_session_maker): + """Test that retries stop after max attempts""" + task = DatabaseTask() + with mock.patch.object(task, "_sessionmaker", always_failing_session_maker): + with pytest.raises(OperationalError): + with task.get_new_session(): + pass + assert always_failing_session_maker.call_count == NEW_SESSION_RETRIES