Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass in tenant_id to kv_store in monitoring job #3726

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions backend/onyx/key_value_store/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from onyx.key_value_store.store import PgRedisKVStore


def get_kv_store() -> KeyValueStore:
def get_kv_store(tenant_id: str | None = None) -> KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable
return PgRedisKVStore()
return PgRedisKVStore(tenant_id=tenant_id)
20 changes: 10 additions & 10 deletions backend/onyx/key_value_store/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,27 @@ class PgRedisKVStore(KeyValueStore):
def __init__(
self, redis_client: Redis | None = None, tenant_id: str | None = None
) -> None:
self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
Comment on lines 31 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly wary of default None to the current tenant ID. Elsewhere, we treat it as a valid tenant ID / as POSTGRES_DEFAULT_SCHEMA. Had a similar debate around get_session_with_tenant (default to a sentinel if nothing passed in?, etc.) and ended up deciding on different entrypoint (get_session_with_default_tenant)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
self.redis_client = get_redis_client(tenant_id=self.tenant_id)

@contextmanager
def get_session(self) -> Iterator[Session]:
def _get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise HTTPException(
status_code=401, detail="User must authenticate"
)
if not is_valid_schema_name(tenant_id):
if not is_valid_schema_name(self.tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
yield session

def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
Expand All @@ -66,7 +66,7 @@ def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:

encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self.get_session() as session:
with self._get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
Expand All @@ -88,7 +88,7 @@ def load(self, key: str) -> JSON_ro:
except Exception as e:
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")

with self.get_session() as session:
with self._get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise KvKeyNotFoundError
Expand All @@ -113,7 +113,7 @@ def delete(self, key: str) -> None:
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")

with self.get_session() as session:
with self._get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise KvKeyNotFoundError
Expand Down
2 changes: 1 addition & 1 deletion backend/onyx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:

if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()
get_or_generate_uuid(tenant_id=None)

# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:
Expand Down
24 changes: 13 additions & 11 deletions backend/onyx/utils/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from onyx.configs.constants import KV_CUSTOMER_UUID_KEY
from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.engine import get_session_with_tenant
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
Expand Down Expand Up @@ -41,7 +41,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id))


def get_or_generate_uuid(tenant_id: str | None = None) -> str:
def get_or_generate_uuid(tenant_id: str | None) -> str:
# TODO: split out the whole "instance UUID" generation logic into a separate
# utility function. Telemetry should not be aware at all of how the UUID is
# generated/stored.
Expand All @@ -52,7 +52,7 @@ def get_or_generate_uuid(tenant_id: str | None = None) -> str:
if _CACHED_UUID is not None:
return _CACHED_UUID

kv_store = get_kv_store()
kv_store = get_kv_store(tenant_id=tenant_id)

try:
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
Expand All @@ -63,18 +63,18 @@ def get_or_generate_uuid(tenant_id: str | None = None) -> str:
return _CACHED_UUID


def _get_or_generate_instance_domain() -> str | None: #
def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: #
global _CACHED_INSTANCE_DOMAIN

if _CACHED_INSTANCE_DOMAIN is not None:
return _CACHED_INSTANCE_DOMAIN

kv_store = get_kv_store()
kv_store = get_kv_store(tenant_id=tenant_id)

try:
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
except KvKeyNotFoundError:
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
first_user = db_session.query(User).first()
if first_user:
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
Expand All @@ -94,16 +94,16 @@ def optional_telemetry(
if DISABLE_TELEMETRY:
return

tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment here


try:

def telemetry_logic() -> None:
try:
customer_uuid = (
_get_or_generate_customer_id_mt(
tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
)
_get_or_generate_customer_id_mt(tenant_id)
if MULTI_TENANT
else get_or_generate_uuid()
else get_or_generate_uuid(tenant_id)
)
payload = {
"data": data,
Expand All @@ -115,7 +115,9 @@ def telemetry_logic() -> None:
"is_cloud": MULTI_TENANT,
}
if ENTERPRISE_EDITION_ENABLED:
payload["instance_domain"] = _get_or_generate_instance_domain()
payload["instance_domain"] = _get_or_generate_instance_domain(
tenant_id
)
requests.post(
_DANSWER_TELEMETRY_ENDPOINT,
headers={"Content-Type": "application/json"},
Expand Down
Loading