diff --git a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py index 799644adb17..8c7647468d0 100644 --- a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py @@ -44,11 +44,11 @@ def check_for_connector_deletion_task( timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) - try: - # these tasks should never overlap - if not lock_beat.acquire(blocking=False): - return None + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return None + try: # collect cc_pair_ids cc_pair_ids: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: diff --git a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py index 028a9e45df0..20ad0a07565 100644 --- a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py @@ -102,11 +102,11 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) - try: - # these tasks should never overlap - if not lock_beat.acquire(blocking=False): - return None + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return None + try: # get all cc pairs that need to be synced cc_pair_ids_to_sync: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: diff --git a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py index bad23c120ff..238e147c9af 100644 --- a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py @@ -102,11 +102,11 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) - try: - # these tasks should never overlap - if not lock_beat.acquire(blocking=False): - return None + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return None + try: cc_pair_ids_to_sync: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_all_auto_sync_cc_pairs(db_session) diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index 366f8cc23a4..8443bb1f079 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -63,6 +63,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndexPayload from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import redis_lock_dump +from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import global_version from shared_configs.configs import INDEXING_MODEL_SERVER_HOST @@ -204,6 +205,10 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: """a lightweight task used to kick off indexing tasks. 
Occasionally does some validation of existing state to clear up error conditions""" + debug_tenants = { + "tenant_i-043470d740845ec56", + "tenant_82b497ce-88aa-4fbd-841a-92cae43529c8", + } time_start = time.monotonic() tasks_created = 0 @@ -219,11 +224,11 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) - try: - # these tasks should never overlap - if not lock_beat.acquire(blocking=False): - return None + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return None + try: locked = True # check for search settings swap @@ -246,15 +251,25 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: ) # gather cc_pair_ids + lock_beat.reacquire() cc_pair_ids: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: - lock_beat.reacquire() cc_pairs = fetch_connector_credential_pairs(db_session) for cc_pair_entry in cc_pairs: cc_pair_ids.append(cc_pair_entry.id) # kick off index attempts for cc_pair_id in cc_pair_ids: + # debugging logic - remove after we're done + if tenant_id in debug_tenants: + ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK) + task_logger.info( + f"check_for_indexing cc_pair lock: " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"ttl={ttl}" + ) + lock_beat.reacquire() redis_connector = RedisConnector(tenant_id, cc_pair_id) @@ -331,14 +346,33 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: ) tasks_created += 1 + # debugging logic - remove after we're done + if tenant_id in debug_tenants: + ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK) + task_logger.info( + f"check_for_indexing unfenced lock: " + f"tenant={tenant_id} " + f"ttl={ttl}" + ) + + lock_beat.reacquire() + # Fail any index attempts in the DB that don't have fences # This shouldn't ever happen! 
with get_session_with_tenant(tenant_id) as db_session: - lock_beat.reacquire() unfenced_attempt_ids = get_unfenced_index_attempt_ids( db_session, redis_client ) for attempt_id in unfenced_attempt_ids: + # debugging logic - remove after we're done + if tenant_id in debug_tenants: + ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK) + task_logger.info( + f"check_for_indexing unfenced attempt id lock: " + f"tenant={tenant_id} " + f"ttl={ttl}" + ) + lock_beat.reacquire() attempt = get_index_attempt(db_session, attempt_id) @@ -356,9 +390,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: attempt.id, db_session, failure_reason=failure_reason ) + # debugging logic - remove after we're done + if tenant_id in debug_tenants: + ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK) + task_logger.info( + f"check_for_indexing validate fences lock: " + f"tenant={tenant_id} " + f"ttl={ttl}" + ) + + lock_beat.reacquire() # we want to run this less frequently than the overall task if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES): - lock_beat.reacquire() # clear any indexing fences that don't have associated celery tasks in progress # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker), # or be currently executing @@ -370,7 +413,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: task_logger.exception("Exception while validating indexing fences") redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60) - except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." @@ -405,7 +447,9 @@ def validate_indexing_fences( ) # validate all existing indexing jobs - for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"): + for key_bytes in r.scan_iter( + RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT + ): lock_beat.reacquire() with get_session_with_tenant(tenant_id) as db_session: validate_indexing_fence( diff --git a/backend/onyx/background/celery/tasks/pruning/tasks.py b/backend/onyx/background/celery/tasks/pruning/tasks.py index 920bc44cdf6..a1e891365f3 100644 --- a/backend/onyx/background/celery/tasks/pruning/tasks.py +++ b/backend/onyx/background/celery/tasks/pruning/tasks.py @@ -89,11 +89,11 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None: timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) - try: - # these tasks should never overlap - if not lock_beat.acquire(blocking=False): - return None + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return None + try: cc_pair_ids: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_connector_credential_pairs(db_session) diff --git a/backend/onyx/background/celery/tasks/vespa/tasks.py b/backend/onyx/background/celery/tasks/vespa/tasks.py index 90a5f23702a..c00bac354f3 100644 --- a/backend/onyx/background/celery/tasks/vespa/tasks.py +++ b/backend/onyx/background/celery/tasks/vespa/tasks.py @@ -4,6 +4,7 @@ from datetime import datetime from datetime import timezone from http import HTTPStatus +from typing import Any from typing import cast import httpx @@ -26,6 +27,7 @@ from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT from onyx.configs.app_configs import JOB_TIMEOUT +from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS from onyx.configs.constants import 
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask @@ -70,6 +72,7 @@ from onyx.redis.redis_document_set import RedisDocumentSet from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import redis_lock_dump +from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT from onyx.redis.redis_usergroup import RedisUserGroup from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation @@ -103,14 +106,14 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) - try: - # these tasks should never overlap - if not lock_beat.acquire(blocking=False): - return None + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return None + try: with get_session_with_tenant(tenant_id) as db_session: try_generate_stale_document_sync_tasks( - self.app, db_session, r, lock_beat, tenant_id + self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id ) # region document set scan @@ -185,6 +188,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No def try_generate_stale_document_sync_tasks( celery_app: Celery, + max_tasks: int, db_session: Session, r: Redis, lock_beat: RedisLock, @@ -215,11 +219,16 @@ def try_generate_stale_document_sync_tasks( # rkuo: we could technically sync all stale docs in one big pass. # but I feel it's more understandable to group the docs by cc_pair total_tasks_generated = 0 + tasks_remaining = max_tasks cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: + lock_beat.reacquire() + rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id) rc.set_skip_docs(docs_to_skip) - result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id) + result = rc.generate_tasks( + tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id + ) if result is None: continue @@ -233,10 +242,19 @@ def try_generate_stale_document_sync_tasks( ) total_tasks_generated += result[0] + tasks_remaining -= result[0] + if tasks_remaining <= 0: + break - task_logger.info( - f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}" - ) + if tasks_remaining <= 0: + task_logger.info( + f"RedisConnector.generate_tasks reached the task generation limit: " + f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}" + ) + else: + task_logger.info( + f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}" + ) r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated) return total_tasks_generated @@ -275,7 +293,9 @@ def try_generate_document_set_sync_tasks( ) # Add all documents that need to be updated into the queue - result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id) + result = rds.generate_tasks( + VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id + ) if result is None: return None @@ -330,7 +350,9 @@ def try_generate_user_group_sync_tasks( task_logger.info( f"RedisUserGroup.generate_tasks starting. 
usergroup_id={usergroup.id}" ) - result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id) + result = rug.generate_tasks( + VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id + ) if result is None: return None @@ -752,7 +774,7 @@ def monitor_ccpair_indexing_taskset( @shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True) -def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: +def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None: """This is a celery beat task that monitors and finalizes metadata sync tasksets. It scans for fence values and then gets the counts of any associated tasksets. If the count is 0, that means all tasks finished and we should clean up. @@ -766,7 +788,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: time_start = time.monotonic() - timings: dict[str, float] = {} + timings: dict[str, Any] = {} timings["start"] = time_start r = get_redis_client(tenant_id=tenant_id) @@ -776,16 +798,15 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) - try: - # prevent overlapping tasks - if not lock_beat.acquire(blocking=False): - task_logger.info("monitor_vespa_sync exiting due to overlap") - return False + # prevent overlapping tasks + if not lock_beat.acquire(blocking=False): + return None + try: # print current queue lengths phase_start = time.monotonic() # we don't need every tenant polling redis for this info. - if not MULTI_TENANT or random.randint(1, 100) == 100: + if not MULTI_TENANT or random.randint(1, 10) == 10: r_celery = self.app.broker_connection().channel().client # type: ignore n_celery = celery_get_queue_length("celery", r_celery) n_indexing = celery_get_queue_length( @@ -826,6 +847,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: f"permissions_upsert={n_permissions_upsert} " ) timings["queues"] = time.monotonic() - phase_start + timings["queues_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) # scan and monitor activity to completion phase_start = time.monotonic() @@ -833,24 +855,37 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: if r.exists(RedisConnectorCredentialPair.get_fence_key()): monitor_connector_taskset(r) timings["connector"] = time.monotonic() - phase_start + timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) phase_start = time.monotonic() - for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"): - lock_beat.reacquire() + lock_beat.reacquire() + for key_bytes in r.scan_iter( + RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT + ): monitor_connector_deletion_taskset(tenant_id, key_bytes, r) + lock_beat.reacquire() timings["connector_deletion"] = time.monotonic() - phase_start + timings["connector_deletion_ttl"] = r.ttl( + OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK + ) phase_start = time.monotonic() - for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): - lock_beat.reacquire() + lock_beat.reacquire() + for key_bytes in r.scan_iter( + RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT + ): with get_session_with_tenant(tenant_id) as db_session: monitor_document_set_taskset(tenant_id, key_bytes, r, db_session) - timings["document_set"] = time.monotonic() - phase_start + lock_beat.reacquire() + timings["documentset"] = time.monotonic() - phase_start + timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) 
phase_start = time.monotonic() - for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): - lock_beat.reacquire() + lock_beat.reacquire() + for key_bytes in r.scan_iter( + RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT + ): monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback( "onyx.background.celery.tasks.vespa.tasks", "monitor_usergroup_taskset", @@ -858,29 +893,45 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: ) with get_session_with_tenant(tenant_id) as db_session: monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session) + lock_beat.reacquire() timings["usergroup"] = time.monotonic() - phase_start + timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) phase_start = time.monotonic() - for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"): - lock_beat.reacquire() + lock_beat.reacquire() + for key_bytes in r.scan_iter( + RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT + ): with get_session_with_tenant(tenant_id) as db_session: monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session) + lock_beat.reacquire() timings["pruning"] = time.monotonic() - phase_start + timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) phase_start = time.monotonic() - for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"): - lock_beat.reacquire() + lock_beat.reacquire() + for key_bytes in r.scan_iter( + RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT + ): with get_session_with_tenant(tenant_id) as db_session: monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session) + lock_beat.reacquire() timings["indexing"] = time.monotonic() - phase_start + timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) phase_start = time.monotonic() - for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"): - lock_beat.reacquire() + lock_beat.reacquire() + for key_bytes in r.scan_iter( + RedisConnectorPermissionSync.FENCE_PREFIX + "*", + count=SCAN_ITER_COUNT_DEFAULT, + ): with get_session_with_tenant(tenant_id) as db_session: monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session) + lock_beat.reacquire() timings["permissions"] = time.monotonic() - phase_start + timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) + except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." 
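The hunks above all converge on the same housekeeping pattern: take the beat lock before entering the try block (so the finally clause never releases a lock that was never acquired), call reacquire() around each scan_iter loop, which now passes a COUNT hint, and record per-phase timings together with the remaining lock TTL. The sketch below is illustrative only and is not part of the diff; it uses plain redis-py, and names such as LOCK_KEY, FENCE_PREFIX, and example_beat_task are placeholders rather than Onyx identifiers.

```python
# Minimal sketch of the lock/scan pattern, assuming plain redis-py.
# LOCK_KEY, FENCE_PREFIX, and the constants below are hypothetical stand-ins.
import time
from typing import Any

import redis

LOCK_KEY = "examplelock:beat"       # hypothetical lock key
FENCE_PREFIX = "examplefence_"      # hypothetical fence key prefix
LOCK_TIMEOUT = 120                  # stand-in for CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
SCAN_ITER_COUNT = 4096              # mirrors SCAN_ITER_COUNT_DEFAULT


def example_beat_task(r: redis.Redis) -> None:
    lock = r.lock(LOCK_KEY, timeout=LOCK_TIMEOUT)

    # Acquire *before* the try block: if another run holds the lock we return
    # immediately and the finally clause never touches a lock we don't own.
    if not lock.acquire(blocking=False):
        return

    timings: dict[str, Any] = {}
    try:
        phase_start = time.monotonic()

        # Refresh the lock before a potentially long scan, and pass a COUNT
        # hint so each SCAN round trip returns a larger batch of keys.
        lock.reacquire()
        for key in r.scan_iter(FENCE_PREFIX + "*", count=SCAN_ITER_COUNT):
            # monitor / clean up the fenced work for this key here
            lock.reacquire()  # keep extending the lock while iterating

        timings["fence_scan"] = time.monotonic() - phase_start
        timings["fence_scan_ttl"] = r.ttl(LOCK_KEY)  # lock time left after the phase
    finally:
        # Only release a lock we still own; otherwise report the timings so the
        # slow phase that let the lock expire can be identified.
        if lock.owned():
            lock.release()
        else:
            print(f"lock not owned on completion: timings={timings}")
```
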
@@ -889,18 +940,10 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: if lock_beat.owned(): lock_beat.release() else: - t = timings task_logger.error( "monitor_vespa_sync - Lock not owned on completion: " f"tenant={tenant_id} " - f"queues={t.get('queues')} " - f"connector={t.get('connector')} " - f"connector_deletion={t.get('connector_deletion')} " - f"document_set={t.get('document_set')} " - f"usergroup={t.get('usergroup')} " - f"pruning={t.get('pruning')} " - f"indexing={t.get('indexing')} " - f"permissions={t.get('permissions')}" + f"timings={timings}" ) redis_lock_dump(lock_beat, r) diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index e79378c63a7..c37384217b4 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -280,6 +280,11 @@ except ValueError: CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT +# The maximum number of tasks that can be queued up to sync to Vespa in a single pass +VESPA_SYNC_MAX_TASKS = 1024 + +DB_YIELD_PER_DEFAULT = 64 + ##### # Connector Configs ##### diff --git a/backend/onyx/redis/redis_connector_credential_pair.py b/backend/onyx/redis/redis_connector_credential_pair.py index 5c0501b9d75..0d53c2dc806 100644 --- a/backend/onyx/redis/redis_connector_credential_pair.py +++ b/backend/onyx/redis/redis_connector_credential_pair.py @@ -7,6 +7,7 @@ from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session +from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues @@ -65,12 +66,20 @@ def make_redis_syncing_key(doc_id: str) -> str: def generate_tasks( self, + max_tasks: int, celery_app: Celery, db_session: Session, redis_client: Redis, lock: RedisLock, tenant_id: str | None, ) -> tuple[int, int] | None: + """We can limit the number of tasks generated here, which is useful to prevent + one tenant from overwhelming the sync queue. + + This works because the dirty state of a document is in the DB, so more docs + get picked up after the limited set of tasks is complete. 
+ """ + last_lock_time = time.monotonic() async_results = [] @@ -84,7 +93,7 @@ def generate_tasks( num_docs = 0 - for doc in db_session.scalars(stmt).yield_per(1): + for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): doc = cast(Document, doc) current_time = time.monotonic() if current_time - last_lock_time >= ( @@ -132,4 +141,7 @@ def generate_tasks( async_results.append(result) self.skip_docs.add(doc.id) + if len(async_results) >= max_tasks: + break + return len(async_results), num_docs diff --git a/backend/onyx/redis/redis_connector_delete.py b/backend/onyx/redis/redis_connector_delete.py index 1afe01e2696..b5285fb71ad 100644 --- a/backend/onyx/redis/redis_connector_delete.py +++ b/backend/onyx/redis/redis_connector_delete.py @@ -9,6 +9,7 @@ from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session +from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues @@ -98,7 +99,7 @@ def generate_tasks( stmt = construct_document_select_for_connector_credential_pair( cc_pair.connector_id, cc_pair.credential_id ) - for doc_temp in db_session.scalars(stmt).yield_per(1): + for doc_temp in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): doc: DbDocument = doc_temp current_time = time.monotonic() if current_time - last_lock_time >= ( diff --git a/backend/onyx/redis/redis_connector_doc_perm_sync.py b/backend/onyx/redis/redis_connector_doc_perm_sync.py index 846050a6ffe..62c98a66fad 100644 --- a/backend/onyx/redis/redis_connector_doc_perm_sync.py +++ b/backend/onyx/redis/redis_connector_doc_perm_sync.py @@ -13,6 +13,7 @@ from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask +from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT class RedisConnectorPermissionSyncPayload(BaseModel): @@ -68,7 +69,10 @@ def get_remaining(self) -> int: def get_active_task_count(self) -> int: """Count of active permission sync tasks""" count = 0 - for _ in self.redis.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"): + for _ in self.redis.scan_iter( + RedisConnectorPermissionSync.FENCE_PREFIX + "*", + count=SCAN_ITER_COUNT_DEFAULT, + ): count += 1 return count diff --git a/backend/onyx/redis/redis_connector_ext_group_sync.py b/backend/onyx/redis/redis_connector_ext_group_sync.py index bbe539c3954..4d29ab5956a 100644 --- a/backend/onyx/redis/redis_connector_ext_group_sync.py +++ b/backend/onyx/redis/redis_connector_ext_group_sync.py @@ -7,6 +7,8 @@ from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session +from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT + class RedisConnectorExternalGroupSyncPayload(BaseModel): started: datetime | None @@ -63,7 +65,8 @@ def get_active_task_count(self) -> int: """Count of active external group syncing tasks""" count = 0 for _ in self.redis.scan_iter( - RedisConnectorExternalGroupSync.FENCE_PREFIX + "*" + RedisConnectorExternalGroupSync.FENCE_PREFIX + "*", + count=SCAN_ITER_COUNT_DEFAULT, ): count += 1 return count diff --git a/backend/onyx/redis/redis_connector_prune.py b/backend/onyx/redis/redis_connector_prune.py index 2e7f8214e1c..bbecc1b8cbf 100644 --- a/backend/onyx/redis/redis_connector_prune.py +++ b/backend/onyx/redis/redis_connector_prune.py @@ -12,6 +12,7 @@ from onyx.configs.constants import OnyxCeleryQueues from 
onyx.configs.constants import OnyxCeleryTask from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id +from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT class RedisConnectorPrune: @@ -63,7 +64,9 @@ def get_remaining(self) -> int: def get_active_task_count(self) -> int: """Count of active pruning tasks""" count = 0 - for key in self.redis.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"): + for key in self.redis.scan_iter( + RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT + ): count += 1 return count diff --git a/backend/onyx/redis/redis_document_set.py b/backend/onyx/redis/redis_document_set.py index 1433cb04aed..aa219d6dd0d 100644 --- a/backend/onyx/redis/redis_document_set.py +++ b/backend/onyx/redis/redis_document_set.py @@ -8,6 +8,7 @@ from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session +from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues @@ -50,17 +51,21 @@ def payload(self) -> int | None: def generate_tasks( self, + max_tasks: int, celery_app: Celery, db_session: Session, redis_client: Redis, lock: RedisLock, tenant_id: str | None, ) -> tuple[int, int] | None: + """Max tasks is ignored for now until we can build the logic to mark the + document set up to date over multiple batches. + """ last_lock_time = time.monotonic() async_results = [] stmt = construct_document_select_by_docset(int(self._id), current_only=False) - for doc in db_session.scalars(stmt).yield_per(1): + for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): doc = cast(Document, doc) current_time = time.monotonic() if current_time - last_lock_time >= ( diff --git a/backend/onyx/redis/redis_object_helper.py b/backend/onyx/redis/redis_object_helper.py index 4c90236573b..34b301882f6 100644 --- a/backend/onyx/redis/redis_object_helper.py +++ b/backend/onyx/redis/redis_object_helper.py @@ -82,6 +82,7 @@ def get_id_from_task_id(task_id: str) -> str | None: @abstractmethod def generate_tasks( self, + max_tasks: int, celery_app: Celery, db_session: Session, redis_client: Redis, diff --git a/backend/onyx/redis/redis_pool.py b/backend/onyx/redis/redis_pool.py index 83f7d010376..e3617127a6e 100644 --- a/backend/onyx/redis/redis_pool.py +++ b/backend/onyx/redis/redis_pool.py @@ -29,6 +29,8 @@ logger = setup_logger() +SCAN_ITER_COUNT_DEFAULT = 4096 + class TenantRedis(redis.Redis): def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None: @@ -116,6 +118,7 @@ def __getattribute__(self, item: str) -> Any: "hexists", "hset", "hdel", + "ttl", ] # Regular methods that need simple prefixing if item == "scan_iter": diff --git a/backend/onyx/redis/redis_usergroup.py b/backend/onyx/redis/redis_usergroup.py index f7ee6199a3b..00981e9a1df 100644 --- a/backend/onyx/redis/redis_usergroup.py +++ b/backend/onyx/redis/redis_usergroup.py @@ -8,6 +8,7 @@ from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session +from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues @@ -51,12 +52,16 @@ def payload(self) -> int | None: def generate_tasks( self, + max_tasks: int, celery_app: Celery, db_session: Session, redis_client: Redis, lock: RedisLock, tenant_id: str | None, ) -> 
tuple[int, int] | None: + """Max tasks is ignored for now until we can build the logic to mark the + user group up to date over multiple batches. + """ last_lock_time = time.monotonic() async_results = [] @@ -73,7 +78,7 @@ def generate_tasks( return 0, 0 stmt = construct_document_select_by_usergroup(int(self._id)) - for doc in db_session.scalars(stmt).yield_per(1): + for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): doc = cast(Document, doc) current_time = time.monotonic() if current_time - last_lock_time >= (