diff --git a/backend/onyx/onyxbot/slack/listener.py b/backend/onyx/onyxbot/slack/listener.py index 7624b35c84d..106891d419e 100644 --- a/backend/onyx/onyxbot/slack/listener.py +++ b/backend/onyx/onyxbot/slack/listener.py @@ -14,6 +14,7 @@ from prometheus_client import Gauge from prometheus_client import start_http_server +from redis.lock import Lock from slack_sdk import WebClient from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.response import SocketModeResponse @@ -122,6 +123,9 @@ def __init__(self) -> None: self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {} self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {} + # Store Redis lock objects here so we can release them properly + self.redis_locks: Dict[str | None, Lock] = {} + self.running = True self.pod_id = self.get_pod_id() self._shutdown_event = Event() @@ -159,10 +163,15 @@ def acquire_tenants_loop(self) -> None: while not self._shutdown_event.is_set(): try: self.acquire_tenants() + + # After we finish acquiring and managing Slack bots, + # set the gauge to the number of active tenants (those with Slack bots). active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set( len(self.tenant_ids) ) - logger.debug(f"Current active tenants: {len(self.tenant_ids)}") + logger.debug( + f"Current active tenants with Slack bots: {len(self.tenant_ids)}" + ) except Exception as e: logger.exception(f"Error in Slack acquisition: {e}") self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL) @@ -171,7 +180,9 @@ def heartbeat_loop(self) -> None: while not self._shutdown_event.is_set(): try: self.send_heartbeats() - logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants") + logger.debug( + f"Sent heartbeats for {len(self.tenant_ids)} active tenants" + ) except Exception as e: logger.exception(f"Error in heartbeat loop: {e}") self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL) @@ -179,17 +190,21 @@ def heartbeat_loop(self) -> None: def _manage_clients_per_tenant( self, db_session: Session, tenant_id: str | None, bot: SlackBot ) -> None: + """ + - If the tokens are missing or empty, close the socket client and remove them. + - If the tokens have changed, close the existing socket client and reconnect. + - If the tokens are new, warm up the model and start a new socket client. + """ slack_bot_tokens = SlackBotTokens( bot_token=bot.bot_token, app_token=bot.app_token, ) tenant_bot_pair = (tenant_id, bot.id) - # If the tokens are not set, we need to close the socket client and delete the tokens - # for the tenant and app + # If the tokens are missing or empty, close the socket client and remove them. if not slack_bot_tokens: logger.debug( - f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}" + f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}" ) if tenant_bot_pair in self.socket_clients: asyncio.run(self.socket_clients[tenant_bot_pair].close()) @@ -204,9 +219,10 @@ def _manage_clients_per_tenant( if not tokens_exist or tokens_changed: if tokens_exist: logger.info( - f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting" + f"Slack Bot tokens changed for tenant={tenant_id}, bot {bot.id}; reconnecting" ) else: + # Warm up the model if needed search_settings = get_current_search_settings(db_session) embedding_model = EmbeddingModel.from_db_model( search_settings=search_settings, @@ -217,77 +233,168 @@ def _manage_clients_per_tenant( self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens + # Close any existing connection first if tenant_bot_pair in self.socket_clients: asyncio.run(self.socket_clients[tenant_bot_pair].close()) self.start_socket_client(bot.id, tenant_id, slack_bot_tokens) def acquire_tenants(self) -> None: - tenant_ids = get_all_tenant_ids() - - for tenant_id in tenant_ids: + """ + - Attempt to acquire a Redis lock for each tenant. + - If acquired, check if that tenant actually has Slack bots. + - If yes, store them in self.tenant_ids and manage the socket connections. + - If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope). + """ + all_tenants = get_all_tenant_ids() + + # 1) Try to acquire locks for new tenants + for tenant_id in all_tenants: if ( DISALLOWED_SLACK_BOT_TENANT_LIST is not None and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST ): - logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping") + logger.debug(f"Tenant {tenant_id} is disallowed; skipping.") continue + # Already acquired in a previous loop iteration? if tenant_id in self.tenant_ids: - logger.debug(f"Tenant {tenant_id} already in self.tenant_ids") continue + # Respect max tenant limit per pod if len(self.tenant_ids) >= MAX_TENANTS_PER_POD: logger.info( - f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants" + f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more." ) break redis_client = get_redis_client(tenant_id=tenant_id) - pod_id = self.pod_id - acquired = redis_client.set( - OnyxRedisLocks.SLACK_BOT_LOCK, - pod_id, - nx=True, - ex=TENANT_LOCK_EXPIRATION, + # Acquire a Redis lock (non-blocking) + rlock = redis_client.lock( + OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION ) - if not acquired and not DEV_MODE: - logger.debug(f"Another pod holds the lock for tenant {tenant_id}") - continue + lock_acquired = rlock.acquire(blocking=False) - logger.debug(f"Acquired lock for tenant {tenant_id}") + if not lock_acquired and not DEV_MODE: + logger.debug( + f"Another pod holds the lock for tenant {tenant_id}, skipping." + ) + continue - self.tenant_ids.add(tenant_id) + if lock_acquired: + logger.debug(f"Acquired lock for tenant {tenant_id}.") + self.redis_locks[tenant_id] = rlock + else: + # DEV_MODE will skip the lock acquisition guard + logger.debug( + f"Running in DEV_MODE. Not enforcing lock for {tenant_id}." + ) - for tenant_id in self.tenant_ids: + # Now check if this tenant actually has Slack bots token = CURRENT_TENANT_ID_CONTEXTVAR.set( tenant_id or POSTGRES_DEFAULT_SCHEMA ) try: with get_session_with_tenant(tenant_id) as db_session: + bots: list[SlackBot] = [] try: - bots = fetch_slack_bots(db_session=db_session) + bots = list(fetch_slack_bots(db_session=db_session)) + except KvKeyNotFoundError: + # No Slackbot tokens, pass + pass + except Exception as e: + logger.exception( + f"Error fetching Slack bots for tenant {tenant_id}: {e}" + ) + + if bots: + # Mark as active tenant + self.tenant_ids.add(tenant_id) for bot in bots: self._manage_clients_per_tenant( db_session=db_session, tenant_id=tenant_id, bot=bot, ) + else: + # If no Slack bots, release lock immediately (unless in DEV_MODE) + if lock_acquired and not DEV_MODE: + rlock.release() + del self.redis_locks[tenant_id] + logger.debug( + f"No Slack bots for tenant {tenant_id}; lock released (if held)." + ) + finally: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + + # 2) Make sure tenants we're handling still have Slack bots + for tenant_id in list(self.tenant_ids): + token = CURRENT_TENANT_ID_CONTEXTVAR.set( + tenant_id or POSTGRES_DEFAULT_SCHEMA + ) + redis_client = get_redis_client(tenant_id=tenant_id) + try: + with get_session_with_tenant(tenant_id) as db_session: + # Attempt to fetch Slack bots + try: + bots = list(fetch_slack_bots(db_session=db_session)) except KvKeyNotFoundError: - logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}") - if (tenant_id, bot.id) in self.socket_clients: - asyncio.run(self.socket_clients[tenant_id, bot.id].close()) - del self.socket_clients[tenant_id, bot.id] - del self.slack_bot_tokens[tenant_id, bot.id] + # No Slackbot tokens, pass (and remove below) + bots = [] except Exception as e: logger.exception(f"Error handling tenant {tenant_id}: {e}") + bots = [] + + if not bots: + logger.info( + f"Tenant {tenant_id} no longer has Slack bots. Removing." + ) + self._remove_tenant(tenant_id) + + # NOTE: We release the lock here (in the same scope it was acquired) + if tenant_id in self.redis_locks and not DEV_MODE: + try: + self.redis_locks[tenant_id].release() + del self.redis_locks[tenant_id] + logger.info(f"Released lock for tenant {tenant_id}") + except Exception as e: + logger.error( + f"Error releasing lock for tenant {tenant_id}: {e}" + ) + else: + # Manage or reconnect Slack bot sockets + for bot in bots: + self._manage_clients_per_tenant( + db_session=db_session, + tenant_id=tenant_id, + bot=bot, + ) finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + def _remove_tenant(self, tenant_id: str | None) -> None: + """ + Helper to remove a tenant from `self.tenant_ids` and close any socket clients. + (Lock release now happens in `acquire_tenants()`, not here.) + """ + # Close all socket clients for this tenant + for (t_id, slack_bot_id), client in list(self.socket_clients.items()): + if t_id == tenant_id: + asyncio.run(client.close()) + del self.socket_clients[(t_id, slack_bot_id)] + del self.slack_bot_tokens[(t_id, slack_bot_id)] + logger.info( + f"Stopped SocketModeClient for tenant: {t_id}, app: {slack_bot_id}" + ) + + # Remove from active set + if tenant_id in self.tenant_ids: + self.tenant_ids.remove(tenant_id) + def send_heartbeats(self) -> None: current_time = int(time.time()) - logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants") + logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants") for tenant_id in self.tenant_ids: redis_client = get_redis_client(tenant_id=tenant_id) heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}" @@ -315,6 +422,7 @@ def start_socket_client( ) socket_client.connect() self.socket_clients[tenant_id, slack_bot_id] = socket_client + # Ensure tenant is tracked as active self.tenant_ids.add(tenant_id) logger.info( f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}" @@ -322,7 +430,7 @@ def start_socket_client( def stop_socket_clients(self) -> None: logger.info(f"Stopping {len(self.socket_clients)} socket clients") - for (tenant_id, slack_bot_id), client in self.socket_clients.items(): + for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()): asyncio.run(client.close()) logger.info( f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}" @@ -340,17 +448,19 @@ def shutdown(self, signum: int | None, frame: FrameType | None) -> None: logger.info(f"Stopping {len(self.socket_clients)} socket clients") self.stop_socket_clients() - # Release locks for all tenants + # Release locks for all tenants we currently hold logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants") - for tenant_id in self.tenant_ids: - try: - redis_client = get_redis_client(tenant_id=tenant_id) - redis_client.delete(OnyxRedisLocks.SLACK_BOT_LOCK) - logger.info(f"Released lock for tenant {tenant_id}") - except Exception as e: - logger.error(f"Error releasing lock for tenant {tenant_id}: {e}") - - # Wait for background threads to finish (with timeout) + for tenant_id in list(self.tenant_ids): + if tenant_id in self.redis_locks: + try: + self.redis_locks[tenant_id].release() + logger.info(f"Released lock for tenant {tenant_id}") + except Exception as e: + logger.error(f"Error releasing lock for tenant {tenant_id}: {e}") + finally: + del self.redis_locks[tenant_id] + + # Wait for background threads to finish (with a timeout) logger.info("Waiting for background threads to finish...") self.acquire_thread.join(timeout=5) self.heartbeat_thread.join(timeout=5)