
Commit a187aa5

use redis exclusively with active signal renewal in more places to perform indexing fence validation (#3517)

Co-authored-by: Richard Kuo (Danswer) <[email protected]>
rkuo-danswer and Richard Kuo (Danswer) authored Dec 20, 2024
1 parent aa4bfa2 commit a187aa5
Showing 4 changed files with 229 additions and 149 deletions.
113 changes: 113 additions & 0 deletions backend/onyx/background/celery/celery_redis.py
@@ -3,12 +3,54 @@
import json
from typing import Any
from typing import cast

from celery import Celery
from redis import Redis

from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.constants import OnyxCeleryPriority


def celery_get_unacked_length(r: Redis) -> int:
    """Checking the unacked queue is useful because a non-zero length tells us there
    may be prefetched tasks.
    There can be other tasks in here besides indexing tasks, so this is mostly useful
    just to see if the task count is non-zero.
    ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
    """
    length = cast(int, r.hlen("unacked"))
    return length
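
A minimal usage sketch (the Redis connection settings below are illustrative assumptions, not from this commit):

from redis import Redis

# hypothetical connection settings for illustration only
r = Redis(host="localhost", port=6379, db=0)

# a non-zero unacked length means workers may be holding prefetched tasks
if celery_get_unacked_length(r) > 0:
    print("prefetched tasks may exist")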


def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
    """Gets the set of task IDs matching the given queue in the unacked hash.
    Unacked entries belonging to the indexing queue are "prefetched", so this gives
    us crucial visibility into which tasks are in that state.
    """
    tasks: set[str] = set()

    for _, v in r.hscan_iter("unacked"):
        v_bytes = cast(bytes, v)
        v_str = v_bytes.decode("utf-8")
        task = json.loads(v_str)

        task_description = task[0]
        task_queue = task[2]

        if task_queue != queue:
            continue

        task_id = task_description.get("headers", {}).get("id")
        if not task_id:
            continue

        # if the queue matches and we see the task_id, add it
        tasks.add(task_id)
    return tasks
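
For reference, each value in the unacked hash decodes to a JSON list whose first element is the message (carrying the task id under its "headers" key) and whose third element is the destination queue — the shape the indexing above relies on. A sketch of one decoded entry (all field values are illustrative):

# decoded unacked entry: [message, exchange, routing_key]
example_entry = [
    {"headers": {"id": "0f5f2b9e-0000-0000-0000-000000000000"}},
    "",
    "connector_indexing",  # hypothetical queue name
]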


def celery_get_queue_length(queue: str, r: Redis) -> int:
    """This is a redis-specific way to get the length of a celery queue.
    It is priority aware and knows how to count across the multiple redis lists

@@ -47,3 +89,74 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> bool:
                return True

    return False
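
A sketch of how this helper might support fence validation (the queue name and task id are hypothetical; note it only scans the queue lists, so a prefetched task will not be found here):

# check whether the task backing an indexing fence is still queued
if not celery_find_task("some-task-id", "connector_indexing", r):
    print("task not in the queue; it may be prefetched, running, or lost")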


def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
    """Returns a list of current workers containing name_filter, or all workers if
    name_filter is None.
    We've empirically discovered that the celery inspect API is potentially unstable
    and may hang or return empty results when celery is under load. We suggest using
    this more for debugging and troubleshooting than in production code.
    """
    worker_names: list[str] = []

    # ping all workers and collect the names of the ones that respond
    inspect = app.control.inspect()
    workers: dict[str, Any] = inspect.ping()  # type: ignore
    if workers:
        for worker_name in list(workers.keys()):
            # if the name filter is not set, return all worker names
            if not name_filter:
                worker_names.append(worker_name)
                continue

            # if the name filter is set, return only worker names that contain it
            if name_filter not in worker_name:
                continue

            worker_names.append(worker_name)

    return worker_names
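
Example invocation (the "indexing" filter and the celery_app variable are assumptions based on typical worker naming, not taken from this diff):

# list workers whose names contain "indexing"; pass None to list all responders
indexing_workers = celery_inspect_get_workers("indexing", celery_app)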


def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
    """Returns the set of reserved task IDs on the specified workers.
    We've empirically discovered that the celery inspect API is potentially unstable
    and may hang or return empty results when celery is under load. We suggest using
    this more for debugging and troubleshooting than in production code.
    """
    reserved_task_ids: set[str] = set()

    inspect = app.control.inspect(destination=worker_names)

    # get the list of reserved tasks
    reserved_tasks: dict[str, list] | None = inspect.reserved()  # type: ignore
    if reserved_tasks:
        for _, task_list in reserved_tasks.items():
            for task in task_list:
                reserved_task_ids.add(task["id"])

    return reserved_task_ids


def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
    """Returns the set of active task IDs on the specified workers.
    We've empirically discovered that the celery inspect API is potentially unstable
    and may hang or return empty results when celery is under load. We suggest using
    this more for debugging and troubleshooting than in production code.
    """
    active_task_ids: set[str] = set()

    inspect = app.control.inspect(destination=worker_names)

    # get the list of active tasks
    active_tasks: dict[str, list] | None = inspect.active()  # type: ignore
    if active_tasks:
        for _, task_list in active_tasks.items():
            for task in task_list:
                active_task_ids.add(task["id"])

    return active_task_ids
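
Taken together, the two inspect helpers give a point-in-time view of the task ids a set of workers is holding. A debugging sketch (subject to the instability caveat in the docstrings; celery_app is an assumed application instance):

workers = celery_inspect_get_workers("indexing", celery_app)
reserved = celery_inspect_get_reserved(workers, celery_app)
active = celery_inspect_get_active(workers, celery_app)

# a task id in none of these sets (and not in the queue or unacked hash)
# has likely been dropped
in_flight = reserved | active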
5 changes: 5 additions & 0 deletions backend/onyx/background/celery/configs/indexing.py
@@ -16,6 +16,11 @@
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

# Indexing worker specific ... this lets us track the transition to STARTED in redis.
# We don't currently rely on this, but it has the potential to be useful, and
# indexing tasks are not high volume.
task_track_started = True

worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1
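
With task_track_started enabled, the result backend records a STARTED state when a worker begins executing a task, letting code that holds a task id distinguish "queued" from "running". A sketch (the task id and celery_app variable are hypothetical):

from celery.result import AsyncResult

result = AsyncResult("some-task-id", app=celery_app)
if result.state == "STARTED":
    print("a worker has picked this task up")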
