Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slackbot optimization #3696

Merged
merged 11 commits into from
Jan 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 152 additions & 42 deletions backend/onyx/onyxbot/slack/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
Expand Down Expand Up @@ -122,6 +123,9 @@ def __init__(self) -> None:
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}

# Store Redis lock objects here so we can release them properly
self.redis_locks: Dict[str | None, Lock] = {}

self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
Expand Down Expand Up @@ -159,10 +163,15 @@ def acquire_tenants_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()

# After we finish acquiring and managing Slack bots,
# set the gauge to the number of active tenants (those with Slack bots).
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
len(self.tenant_ids)
)
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
logger.debug(
f"Current active tenants with Slack bots: {len(self.tenant_ids)}"
)
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
Expand All @@ -171,25 +180,31 @@ def heartbeat_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
logger.debug(
f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
)
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)

def _manage_clients_per_tenant(
self, db_session: Session, tenant_id: str | None, bot: SlackBot
) -> None:
"""
- If the tokens are missing or empty, close the socket client and remove them.
- If the tokens have changed, close the existing socket client and reconnect.
- If the tokens are new, warm up the model and start a new socket client.
"""
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token,
app_token=bot.app_token,
)
tenant_bot_pair = (tenant_id, bot.id)

# If the tokens are not set, we need to close the socket client and delete the tokens
# for the tenant and app
# If the tokens are missing or empty, close the socket client and remove them.
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}"
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
)
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
Expand All @@ -204,9 +219,10 @@ def _manage_clients_per_tenant(
if not tokens_exist or tokens_changed:
if tokens_exist:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting"
f"Slack Bot tokens changed for tenant={tenant_id}, bot {bot.id}; reconnecting"
)
else:
# Warm up the model if needed
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
Expand All @@ -217,77 +233,168 @@ def _manage_clients_per_tenant(

self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens

# Close any existing connection first
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())

self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)

def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()

for tenant_id in tenant_ids:
"""
pablonyx marked this conversation as resolved.
Show resolved Hide resolved
- Attempt to acquire a Redis lock for each tenant.
- If acquired, check if that tenant actually has Slack bots.
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()

# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
):
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
logger.debug(f"Tenant {tenant_id} is disallowed; skipping.")
continue

# Already acquired in a previous loop iteration?
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue

# Respect max tenant limit per pod
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more."
)
break

redis_client = get_redis_client(tenant_id=tenant_id)
pod_id = self.pod_id
acquired = redis_client.set(
OnyxRedisLocks.SLACK_BOT_LOCK,
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
# Acquire a Redis lock (non-blocking)
rlock = redis_client.lock(
OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
)
if not acquired and not DEV_MODE:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue
lock_acquired = rlock.acquire(blocking=False)

logger.debug(f"Acquired lock for tenant {tenant_id}")
if not lock_acquired and not DEV_MODE:
logger.debug(
f"Another pod holds the lock for tenant {tenant_id}, skipping."
)
continue

self.tenant_ids.add(tenant_id)
if lock_acquired:
logger.debug(f"Acquired lock for tenant {tenant_id}.")
self.redis_locks[tenant_id] = rlock
else:
# DEV_MODE will skip the lock acquisition guard
logger.debug(
f"Running in DEV_MODE. Not enforcing lock for {tenant_id}."
)

for tenant_id in self.tenant_ids:
# Now check if this tenant actually has Slack bots
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
bots: list[SlackBot] = []
try:
bots = fetch_slack_bots(db_session=db_session)
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass
pass
except Exception as e:
logger.exception(
f"Error fetching Slack bots for tenant {tenant_id}: {e}"
)

if bots:
# Mark as active tenant
self.tenant_ids.add(tenant_id)
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
else:
# If no Slack bots, release lock immediately (unless in DEV_MODE)
if lock_acquired and not DEV_MODE:
rlock.release()
del self.redis_locks[tenant_id]
logger.debug(
f"No Slack bots for tenant {tenant_id}; lock released (if held)."
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

# 2) Make sure tenants we're handling still have Slack bots
for tenant_id in list(self.tenant_ids):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
redis_client = get_redis_client(tenant_id=tenant_id)

try:
with get_session_with_tenant(tenant_id) as db_session:
# Attempt to fetch Slack bots
try:
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if (tenant_id, bot.id) in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id, bot.id].close())
del self.socket_clients[tenant_id, bot.id]
del self.slack_bot_tokens[tenant_id, bot.id]
# No Slackbot tokens, pass (and remove below)
bots = []
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
bots = []

if not bots:
logger.info(
f"Tenant {tenant_id} no longer has Slack bots. Removing."
)
self._remove_tenant(tenant_id)

# NOTE: We release the lock here (in the same scope it was acquired)
if tenant_id in self.redis_locks and not DEV_MODE:
try:
self.redis_locks[tenant_id].release()
del self.redis_locks[tenant_id]
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Error releasing lock for tenant {tenant_id}: {e}"
)
else:
# Manage or reconnect Slack bot sockets
for bot in bots:
self._manage_clients_per_tenant(
pablonyx marked this conversation as resolved.
Show resolved Hide resolved
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

def _remove_tenant(self, tenant_id: str | None) -> None:
pablonyx marked this conversation as resolved.
Show resolved Hide resolved
"""
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
(Lock release now happens in `acquire_tenants()`, not here.)
"""
# Close all socket clients for this tenant
for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
if t_id == tenant_id:
asyncio.run(client.close())
del self.socket_clients[(t_id, slack_bot_id)]
del self.slack_bot_tokens[(t_id, slack_bot_id)]
logger.info(
f"Stopped SocketModeClient for tenant: {t_id}, app: {slack_bot_id}"
)

# Remove from active set
if tenant_id in self.tenant_ids:
self.tenant_ids.remove(tenant_id)

def send_heartbeats(self) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
Expand Down Expand Up @@ -315,14 +422,15 @@ def start_socket_client(
)
socket_client.connect()
self.socket_clients[tenant_id, slack_bot_id] = socket_client
# Ensure tenant is tracked as active
self.tenant_ids.add(tenant_id)
logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
)

def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for (tenant_id, slack_bot_id), client in self.socket_clients.items():
for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
asyncio.run(client.close())
logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
Expand All @@ -340,17 +448,19 @@ def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()

# Release locks for all tenants
# Release locks for all tenants we currently hold
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
try:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(OnyxRedisLocks.SLACK_BOT_LOCK)
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")

# Wait for background threads to finish (with timeout)
for tenant_id in list(self.tenant_ids):
if tenant_id in self.redis_locks:
try:
self.redis_locks[tenant_id].release()
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
finally:
del self.redis_locks[tenant_id]

# Wait for background threads to finish (with a timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5)
Expand Down
Loading