diff --git a/backend/zeno_backend/classes/amplitude.py b/backend/zeno_backend/classes/amplitude.py new file mode 100644 index 000000000..5ae6872fa --- /dev/null +++ b/backend/zeno_backend/classes/amplitude.py @@ -0,0 +1,25 @@ +"""Amplitude handler for tracking after env has been set up with a singleton object.""" +import os + +from amplitude import Amplitude, BaseEvent + + +class AmplitudeHandler: + """Class to handle amplitude events. + + Attributes: + _client (Amplitude | None): amplitude client singleton object. + """ + + _client = None + + def track(self, event: BaseEvent): + """Track an amplitude event. + + Args: + event (BaseEvent): the event to track. + """ + if "AMPLITUDE_API_KEY" in os.environ: + if self._client is None: + self._client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) + self._client.track(event) diff --git a/backend/zeno_backend/routers/sdk.py b/backend/zeno_backend/routers/sdk.py index c53d8522e..adc2bb54c 100644 --- a/backend/zeno_backend/routers/sdk.py +++ b/backend/zeno_backend/routers/sdk.py @@ -1,10 +1,9 @@ """FastAPI server endpoints for the Zeno SDK.""" import io -import os import uuid import pandas as pd -from amplitude import Amplitude, BaseEvent +from amplitude import BaseEvent from fastapi import ( APIRouter, Depends, @@ -17,6 +16,7 @@ ) from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from zeno_backend.classes.amplitude import AmplitudeHandler from zeno_backend.classes.project import Project from zeno_backend.database import insert, select @@ -95,8 +95,7 @@ def create_project(project: Project, api_key=Depends(APIKeyBearer())): ) project.uuid = str(uuid.uuid4()) - amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="Project Created", user_id="00000" + str(user_id), @@ -196,8 +195,7 @@ def upload_system( detail=("ERROR: Unable to create system table: " + str(e)), ) from e - amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="System Uploaded", user_id="00000" + str(user_id), diff --git a/backend/zeno_backend/server.py b/backend/zeno_backend/server.py index cf0bbcee2..f9fd64726 100644 --- a/backend/zeno_backend/server.py +++ b/backend/zeno_backend/server.py @@ -5,7 +5,7 @@ import pandas as pd import uvicorn -from amplitude import Amplitude, BaseEvent +from amplitude import BaseEvent from dotenv import load_dotenv from fastapi import Depends, FastAPI, HTTPException, Request, Response, status from fastapi.middleware.cors import CORSMiddleware @@ -16,6 +16,7 @@ import zeno_backend.database.select as select import zeno_backend.database.update as update import zeno_backend.util as util +from zeno_backend.classes.amplitude import AmplitudeHandler from zeno_backend.classes.base import ( GroupMetric, ZenoColumn, @@ -63,8 +64,6 @@ def get_server() -> FastAPI: if env_path.exists(): load_dotenv(env_path) - amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - # function to get the user from cognito auth = Cognito( region=os.environ["ZENO_USER_POOL_AUTH_REGION"], @@ -351,7 +350,7 @@ def get_project(owner_name: str, project_name: str, request: Request): return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) if not util.access_valid(uuid, request): return Response(status_code=401) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="Project Viewed", user_id="ProjectViewedUser", @@ -416,7 +415,7 @@ def get_projects(current_user=Depends(auth.claim())): tags=["zeno"], ) def get_public_projects(): - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="Home Viewed", user_id="HomeViewedUser", @@ -553,7 +552,7 @@ def login(name: str): try: user = User(id=-1, name=name, admin=None) user_id = insert.user(user) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="User Registered", user_id="00000" + str(user_id) if user_id else "", @@ -567,7 +566,7 @@ def login(name: str): detail=str(exc), ) from exc else: - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="User Logged In", user_id="00000" + str(fetched_user.id),