diff --git a/backend/onyx/background/celery/celery_redis.py b/backend/onyx/background/celery/celery_redis.py index d438e5957ff..213388ac7c4 100644 --- a/backend/onyx/background/celery/celery_redis.py +++ b/backend/onyx/background/celery/celery_redis.py @@ -3,12 +3,54 @@ from typing import Any from typing import cast +from celery import Celery from redis import Redis from onyx.background.celery.configs.base import CELERY_SEPARATOR from onyx.configs.constants import OnyxCeleryPriority +def celery_get_unacked_length(r: Redis) -> int: + """Checking the unacked queue is useful because a non-zero length tells us there + may be prefetched tasks. + + There can be other tasks in here besides indexing tasks, so this is mostly useful + just to see if the task count is non zero. + + ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html + """ + length = cast(int, r.hlen("unacked")) + return length + + +def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]: + """Gets the set of task id's matching the given queue in the unacked hash. + + Unacked entries belonging to the indexing queue are "prefetched", so this gives + us crucial visibility as to what tasks are in that state. + """ + tasks: set[str] = set() + + for _, v in r.hscan_iter("unacked"): + v_bytes = cast(bytes, v) + v_str = v_bytes.decode("utf-8") + task = json.loads(v_str) + + task_description = task[0] + task_queue = task[2] + + if task_queue != queue: + continue + + task_id = task_description.get("headers", {}).get("id") + if not task_id: + continue + + # if the queue matches and we see the task_id, add it + tasks.add(task_id) + return tasks + + def celery_get_queue_length(queue: str, r: Redis) -> int: """This is a redis specific way to get the length of a celery queue. 
It is priority aware and knows how to count across the multiple redis lists @@ -47,3 +89,74 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int: return True return False + + +def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]: + """Returns a list of current workers containing name_filter, or all workers if + name_filter is None. + + We've empirically discovered that the celery inspect API is potentially unstable + and may hang or return empty results when celery is under load. Suggest using this + more to debug and troubleshoot than in production code. + """ + worker_names: list[str] = [] + + # create an inspect object for all workers; name filtering is applied below + inspect = app.control.inspect() + workers: dict[str, Any] = inspect.ping() # type: ignore + if workers: + for worker_name in list(workers.keys()): + # if the name filter is not set, return all worker names + if not name_filter: + worker_names.append(worker_name) + continue + + # if the name filter is set, return only worker names that contain the name filter + if name_filter not in worker_name: + continue + + worker_names.append(worker_name) + + return worker_names + + +def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]: + """Returns the set of reserved task ids on the specified workers. + + We've empirically discovered that the celery inspect API is potentially unstable + and may hang or return empty results when celery is under load. Suggest using this + more to debug and troubleshoot than in production code. 
+ """ + reserved_task_ids: set[str] = set() + + inspect = app.control.inspect(destination=worker_names) + + # get the list of reserved tasks + reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore + if reserved_tasks: + for _, task_list in reserved_tasks.items(): + for task in task_list: + reserved_task_ids.add(task["id"]) + + return reserved_task_ids + + +def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]: + """Returns the set of active task ids on the specified workers. + + We've empirically discovered that the celery inspect API is potentially unstable + and may hang or return empty results when celery is under load. Suggest using this + more to debug and troubleshoot than in production code. + """ + active_task_ids: set[str] = set() + + inspect = app.control.inspect(destination=worker_names) + + # get the list of active tasks + active_tasks: dict[str, list] | None = inspect.active() # type: ignore + if active_tasks: + for _, task_list in active_tasks.items(): + for task in task_list: + active_task_ids.add(task["id"]) + + return active_task_ids diff --git a/backend/onyx/background/celery/configs/indexing.py b/backend/onyx/background/celery/configs/indexing.py index 1c6a2b662f5..dcfb2c17037 100644 --- a/backend/onyx/background/celery/configs/indexing.py +++ b/backend/onyx/background/celery/configs/indexing.py @@ -16,6 +16,11 @@ task_default_priority = shared_config.task_default_priority task_acks_late = shared_config.task_acks_late +# Indexing worker specific ... 
this lets us track the transition to STARTED in redis +# We don't currently rely on this but it has the potential to be useful and +# indexing tasks are not high volume +task_track_started = True + worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY worker_pool = "threads" worker_prefetch_multiplier = 1 diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index 26ff7d21127..e2e2e5631fb 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -3,7 +3,6 @@ from datetime import timezone from http import HTTPStatus from time import sleep -from typing import Any import redis import sentry_sdk @@ -18,6 +17,7 @@ from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import celery_find_task +from onyx.background.celery.celery_redis import celery_get_unacked_task_ids from onyx.background.indexing.job_client import SimpleJobClient from onyx.background.indexing.run_indexing import run_indexing_entrypoint from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP @@ -29,6 +29,7 @@ from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisLocks +from onyx.configs.constants import OnyxRedisSignals from onyx.db.connector import mark_ccpair_with_indexing_trigger from onyx.db.connector_credential_pair import fetch_connector_credential_pairs from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id @@ -175,7 +176,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: # we need to use celery's redis client to access its redis data # (which lives on a different db number) - # redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore + redis_client_celery: Redis = self.app.broker_connection().channel().client # 
type: ignore lock_beat: RedisLock = redis_client.lock( OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK, @@ -318,23 +319,20 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: attempt.id, db_session, failure_reason=failure_reason ) - # rkuo: The following code logically appears to work, but the celery inspect code may be unstable - # turning off for the moment to see if it helps cloud stability - # we want to run this less frequently than the overall task - # if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES): - # # clear any indexing fences that don't have associated celery tasks in progress - # # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker), - # # or be currently executing - # try: - # task_logger.info("Validating indexing fences...") - # validate_indexing_fences( - # tenant_id, self.app, redis_client, redis_client_celery, lock_beat - # ) - # except Exception: - # task_logger.exception("Exception while validating indexing fences") - - # redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60) + if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES): + # clear any indexing fences that don't have associated celery tasks in progress + # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker), + # or be currently executing + try: + task_logger.info("Validating indexing fences...") + validate_indexing_fences( + tenant_id, self.app, redis_client, redis_client_celery, lock_beat + ) + except Exception: + task_logger.exception("Exception while validating indexing fences") + + redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60) except SoftTimeLimitExceeded: task_logger.info( @@ -353,7 +351,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: ) time_elapsed = time.monotonic() - time_start - task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}") + 
task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}") return tasks_created @@ -364,46 +362,9 @@ def validate_indexing_fences( r_celery: Redis, lock_beat: RedisLock, ) -> None: - reserved_indexing_tasks: set[str] = set() - active_indexing_tasks: set[str] = set() - indexing_worker_names: list[str] = [] - - # filter for and create an indexing specific inspect object - inspect = celery_app.control.inspect() - workers: dict[str, Any] = inspect.ping() # type: ignore - if not workers: - raise ValueError("No workers found!") - - for worker_name in list(workers.keys()): - if "indexing" in worker_name: - indexing_worker_names.append(worker_name) - - if len(indexing_worker_names) == 0: - raise ValueError("No indexing workers found!") - - inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names) - - # NOTE: each dict entry is a map of worker name to a list of tasks - # we want sets for reserved task and active task id's to optimize - # subsequent validation lookups - - # get the list of reserved tasks - reserved_tasks: dict[str, list] | None = inspect_indexing.reserved() # type: ignore - if reserved_tasks is None: - raise ValueError("inspect_indexing.reserved() returned None!") - - for _, task_list in reserved_tasks.items(): - for task in task_list: - reserved_indexing_tasks.add(task["id"]) - - # get the list of active tasks - active_tasks: dict[str, list] | None = inspect_indexing.active() # type: ignore - if active_tasks is None: - raise ValueError("inspect_indexing.active() returned None!") - - for _, task_list in active_tasks.items(): - for task in task_list: - active_indexing_tasks.add(task["id"]) + reserved_indexing_tasks = celery_get_unacked_task_ids( + "connector_indexing", r_celery + ) # validate all existing indexing jobs for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"): @@ -413,7 +374,6 @@ def validate_indexing_fences( tenant_id, key_bytes, reserved_indexing_tasks, - active_indexing_tasks, r_celery, 
db_session, ) @@ -424,7 +384,6 @@ def validate_indexing_fence( tenant_id: str | None, key_bytes: bytes, reserved_tasks: set[str], - active_tasks: set[str], r_celery: Redis, db_session: Session, ) -> None: @@ -434,11 +393,15 @@ def validate_indexing_fence( gives the help. How this works: - 1. Active signal is renewed with a 5 minute TTL - 1.1 When the fence is created + 1. This function renews the active signal with a 5 minute TTL under the following conditions 1.2. When the task is seen in the redis queue - 1.3. When the task is seen in the reserved or active list for a worker - 2. The TTL allows us to get through the transitions on fence startup + 1.3. When the task is seen in the reserved / prefetched list + + 2. Externally, the active signal is renewed when: + 2.1. The fence is created + 2.2. The indexing watchdog checks the spawned task. + + 3. The TTL allows us to get through the transitions on fence startup and when the task starts executing. More TTL clarification: it is seemingly impossible to exactly query Celery for @@ -466,6 +429,8 @@ def validate_indexing_fence( redis_connector = RedisConnector(tenant_id, cc_pair_id) redis_connector_index = redis_connector.new_index(search_settings_id) + + # check to see if the fence/payload exists if not redis_connector_index.fenced: return @@ -501,18 +466,14 @@ def validate_indexing_fence( redis_connector_index.set_active() return - if payload.celery_task_id in active_tasks: - # the celery task is active (aka currently executing) - redis_connector_index.set_active() - return - # we may want to enable this check if using the active task list somehow isn't good enough # if redis_connector_index.generator_locked(): # logger.info(f"{payload.celery_task_id} is currently executing.") - # we didn't find any direct indication that associated celery tasks exist, but they still might be there - # due to gaps in our ability to check states during transitions - # Rely on the active signal (which has a duration that allows us to 
bridge those gaps) + # if we get here, we didn't find any direct indication that the associated celery tasks exist, + # but they still might be there due to gaps in our ability to check states during transitions + # Checking the active signal safeguards us against these transition periods + # (which has a duration that allows us to bridge those gaps) if redis_connector_index.active(): return @@ -795,6 +756,52 @@ def connector_indexing_proxy_task( while True: sleep(5) + # renew active signal + redis_connector_index.set_active() + + # if the job is done, clean up and break + if job.done(): + if job.status == "error": + ignore_exitcode = False + + exit_code: int | None = None + if job.process: + exit_code = job.process.exitcode + + # seeing odd behavior where spawned tasks usually return exit code 1 in the cloud, + # even though logging clearly indicates that they completed successfully + # to work around this, we ignore the job error state if the completion signal is OK + status_int = redis_connector_index.get_completion() + if status_int: + status_enum = HTTPStatus(status_int) + if status_enum == HTTPStatus.OK: + ignore_exitcode = True + + if ignore_exitcode: + task_logger.warning( + "Indexing watchdog - spawned task has non-zero exit code " + "but completion signal is OK. 
Continuing...: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id} " + f"exit_code={exit_code}" + ) + else: + task_logger.error( + "Indexing watchdog - spawned task exceptioned: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id} " + f"exit_code={exit_code} " + f"error={job.exception()}" + ) + + job.release() + break + + # if a termination signal is detected, clean up and break if self.request.id and redis_connector_index.terminating(self.request.id): task_logger.warning( "Indexing watchdog - termination signal detected: " @@ -821,75 +828,33 @@ def connector_indexing_proxy_task( f"search_settings={search_settings_id}" ) - job.cancel() - + job.cancel() break - if not job.done(): - # if the spawned task is still running, restart the check once again - # if the index attempt is not in a finished status - try: - with get_session_with_tenant(tenant_id) as db_session: - index_attempt = get_index_attempt( - db_session=db_session, index_attempt_id=index_attempt_id - ) - - if not index_attempt: - continue - - if not index_attempt.is_finished(): - continue - except Exception: - # if the DB exceptioned, just restart the check. 
- # polling the index attempt status doesn't need to be strongly consistent - logger.exception( - "Indexing watchdog - transient exception looking up index attempt: " - f"attempt={index_attempt_id} " - f"tenant={tenant_id} " - f"cc_pair={cc_pair_id} " - f"search_settings={search_settings_id}" - ) - continue - - if job.status == "error": - ignore_exitcode = False - - exit_code: int | None = None - if job.process: - exit_code = job.process.exitcode - - # seeing odd behavior where spawned tasks usually return exit code 1 in the cloud, - # even though logging clearly indicates that they completed successfully - # to work around this, we ignore the job error state if the completion signal is OK - status_int = redis_connector_index.get_completion() - if status_int: - status_enum = HTTPStatus(status_int) - if status_enum == HTTPStatus.OK: - ignore_exitcode = True - - if ignore_exitcode: - task_logger.warning( - "Indexing watchdog - spawned task has non-zero exit code " - "but completion signal is OK. Continuing...: " - f"attempt={index_attempt_id} " - f"tenant={tenant_id} " - f"cc_pair={cc_pair_id} " - f"search_settings={search_settings_id} " - f"exit_code={exit_code}" - ) - else: - task_logger.error( - "Indexing watchdog - spawned task exceptioned: " - f"attempt={index_attempt_id} " - f"tenant={tenant_id} " - f"cc_pair={cc_pair_id} " - f"search_settings={search_settings_id} " - f"exit_code={exit_code} " - f"error={job.exception()}" + # if the spawned task is still running, restart the check once again + # if the index attempt is not in a finished status + try: + with get_session_with_tenant(tenant_id) as db_session: + index_attempt = get_index_attempt( + db_session=db_session, index_attempt_id=index_attempt_id ) - job.release() - break + if not index_attempt: + continue + + if not index_attempt.is_finished(): + continue + except Exception: + # if the DB exceptioned, just restart the check. 
+ # polling the index attempt status doesn't need to be strongly consistent + logger.exception( + "Indexing watchdog - transient exception looking up index attempt: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + continue task_logger.info( f"Indexing watchdog - finished: attempt={index_attempt_id} " diff --git a/backend/onyx/background/celery/tasks/vespa/tasks.py b/backend/onyx/background/celery/tasks/vespa/tasks.py index ba59ff4b11a..3c3316b904b 100644 --- a/backend/onyx/background/celery/tasks/vespa/tasks.py +++ b/backend/onyx/background/celery/tasks/vespa/tasks.py @@ -20,6 +20,7 @@ from onyx.access.access import get_access_for_document from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import celery_get_queue_length +from onyx.background.celery.celery_redis import celery_get_unacked_task_ids from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT @@ -766,31 +767,34 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery ) + prefetched = celery_get_unacked_task_ids( + OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery + ) + task_logger.info( f"Queue lengths: celery={n_celery} " f"indexing={n_indexing} " + f"indexing_prefetched={len(prefetched)} " f"sync={n_sync} " f"deletion={n_deletion} " f"pruning={n_pruning} " f"permissions_sync={n_permissions_sync} " ) + # scan and monitor activity to completion lock_beat.reacquire() if r.exists(RedisConnectorCredentialPair.get_fence_key()): monitor_connector_taskset(r) - lock_beat.reacquire() for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"): lock_beat.reacquire() monitor_connector_deletion_taskset(tenant_id, key_bytes, r) - 
lock_beat.reacquire() for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): lock_beat.reacquire() with get_session_with_tenant(tenant_id) as db_session: monitor_document_set_taskset(tenant_id, key_bytes, r, db_session) - lock_beat.reacquire() for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): lock_beat.reacquire() monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback( @@ -801,28 +805,21 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: with get_session_with_tenant(tenant_id) as db_session: monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session) - lock_beat.reacquire() for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"): lock_beat.reacquire() with get_session_with_tenant(tenant_id) as db_session: monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session) - lock_beat.reacquire() for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"): lock_beat.reacquire() with get_session_with_tenant(tenant_id) as db_session: monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session) - lock_beat.reacquire() for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"): lock_beat.reacquire() with get_session_with_tenant(tenant_id) as db_session: monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session) - # uncomment for debugging if needed - # r_celery = celery_app.broker_connection().channel().client - # length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery) - # task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}") except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." 
@@ -832,7 +829,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: lock_beat.release() time_elapsed = time.monotonic() - time_start - task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}") + task_logger.debug(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}") return True