diff --git a/backend/onyx/key_value_store/factory.py b/backend/onyx/key_value_store/factory.py index c53f7ebac6c..77f8ea79fe7 100644 --- a/backend/onyx/key_value_store/factory.py +++ b/backend/onyx/key_value_store/factory.py @@ -2,7 +2,7 @@ from onyx.key_value_store.store import PgRedisKVStore -def get_kv_store() -> KeyValueStore: +def get_kv_store(tenant_id: str | None = None) -> KeyValueStore: # In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in # It's read from the global thread level variable - return PgRedisKVStore() + return PgRedisKVStore(tenant_id=tenant_id) diff --git a/backend/onyx/key_value_store/store.py b/backend/onyx/key_value_store/store.py index b252c17dc62..6db1b6ce10b 100644 --- a/backend/onyx/key_value_store/store.py +++ b/backend/onyx/key_value_store/store.py @@ -31,27 +31,27 @@ class PgRedisKVStore(KeyValueStore): def __init__( self, redis_client: Redis | None = None, tenant_id: str | None = None ) -> None: + self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get() + # If no redis_client is provided, fall back to the context var if redis_client is not None: self.redis_client = redis_client else: - tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get() - self.redis_client = get_redis_client(tenant_id=tenant_id) + self.redis_client = get_redis_client(tenant_id=self.tenant_id) @contextmanager - def get_session(self) -> Iterator[Session]: + def _get_session(self) -> Iterator[Session]: engine = get_sqlalchemy_engine() with Session(engine, expire_on_commit=False) as session: if MULTI_TENANT: - tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() - if tenant_id == POSTGRES_DEFAULT_SCHEMA: + if self.tenant_id == POSTGRES_DEFAULT_SCHEMA: raise HTTPException( status_code=401, detail="User must authenticate" ) - if not is_valid_schema_name(tenant_id): + if not is_valid_schema_name(self.tenant_id): raise HTTPException(status_code=400, detail="Invalid tenant ID") # Set the search_path to the tenant's schema - session.execute(text(f'SET search_path = "{tenant_id}"')) + session.execute(text(f'SET search_path = "{self.tenant_id}"')) yield session def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None: @@ -66,7 +66,7 @@ def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None: encrypted_val = val if encrypt else None plain_val = val if not encrypt else None - with self.get_session() as session: + with self._get_session() as session: obj = session.query(KVStore).filter_by(key=key).first() if obj: obj.value = plain_val @@ -88,7 +88,7 @@ def load(self, key: str) -> JSON_ro: except Exception as e: logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}") - with self.get_session() as session: + with self._get_session() as session: obj = session.query(KVStore).filter_by(key=key).first() if not obj: raise KvKeyNotFoundError @@ -113,7 +113,7 @@ def delete(self, key: str) -> None: except Exception as e: logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}") - with self.get_session() as session: + with self._get_session() as session: result = session.query(KVStore).filter_by(key=key).delete() # type: ignore if result == 0: raise KvKeyNotFoundError diff --git a/backend/onyx/main.py b/backend/onyx/main.py index c2917c0e41e..05150fee470 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -212,7 +212,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: if not MULTI_TENANT: # We cache this at the beginning so there is no delay in the first telemetry - get_or_generate_uuid() + get_or_generate_uuid(tenant_id=None) # If we are multi-tenant, we need to only set up initial public tables with Session(engine) as db_session: diff --git a/backend/onyx/utils/telemetry.py b/backend/onyx/utils/telemetry.py index b0c38da2db4..36898efb591 100644 --- a/backend/onyx/utils/telemetry.py +++ b/backend/onyx/utils/telemetry.py @@ -11,7 +11,7 @@ from onyx.configs.constants import KV_CUSTOMER_UUID_KEY from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY from onyx.configs.constants import MilestoneRecordType -from onyx.db.engine import get_sqlalchemy_engine +from onyx.db.engine import get_session_with_tenant from onyx.db.milestone import create_milestone_if_not_exists from onyx.db.models import User from onyx.key_value_store.factory import get_kv_store @@ -41,7 +41,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str: return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id)) -def get_or_generate_uuid(tenant_id: str | None = None) -> str: +def get_or_generate_uuid(tenant_id: str | None) -> str: # TODO: split out the whole "instance UUID" generation logic into a separate # utility function. Telemetry should not be aware at all of how the UUID is # generated/stored. @@ -52,7 +52,7 @@ def get_or_generate_uuid(tenant_id: str | None = None) -> str: if _CACHED_UUID is not None: return _CACHED_UUID - kv_store = get_kv_store() + kv_store = get_kv_store(tenant_id=tenant_id) try: _CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY)) @@ -63,18 +63,18 @@ def get_or_generate_uuid(tenant_id: str | None = None) -> str: return _CACHED_UUID -def _get_or_generate_instance_domain() -> str | None: # +def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: # global _CACHED_INSTANCE_DOMAIN if _CACHED_INSTANCE_DOMAIN is not None: return _CACHED_INSTANCE_DOMAIN - kv_store = get_kv_store() + kv_store = get_kv_store(tenant_id=tenant_id) try: _CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY)) except KvKeyNotFoundError: - with Session(get_sqlalchemy_engine()) as db_session: + with get_session_with_tenant(tenant_id=tenant_id) as db_session: first_user = db_session.query(User).first() if first_user: _CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1] @@ -94,16 +94,16 @@ def optional_telemetry( if DISABLE_TELEMETRY: return + tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get() + try: def telemetry_logic() -> None: try: customer_uuid = ( - _get_or_generate_customer_id_mt( - tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get() - ) + _get_or_generate_customer_id_mt(tenant_id) if MULTI_TENANT - else get_or_generate_uuid() + else get_or_generate_uuid(tenant_id) ) payload = { "data": data, @@ -115,7 +115,9 @@ def telemetry_logic() -> None: "is_cloud": MULTI_TENANT, } if ENTERPRISE_EDITION_ENABLED: - payload["instance_domain"] = _get_or_generate_instance_domain() + payload["instance_domain"] = _get_or_generate_instance_domain( + tenant_id + ) requests.post( _DANSWER_TELEMETRY_ENDPOINT, headers={"Content-Type": "application/json"},