From 00dba4705209950c0c0195ec56f88e384c988f7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Ba=CC=88uerle?= Date: Sat, 16 Sep 2023 15:49:58 +0200 Subject: [PATCH 1/2] fix: move references to amplitude key to after .env loading Previously, the environment variable containing the amplitude API key was referenced before loading the .env file. This has been changed so that the key does not have to be set outside the .env file. --- backend/zeno_backend/routers/sdk.py | 4 ++-- backend/zeno_backend/server.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/zeno_backend/routers/sdk.py b/backend/zeno_backend/routers/sdk.py index ee63ee50..c53d8522 100644 --- a/backend/zeno_backend/routers/sdk.py +++ b/backend/zeno_backend/routers/sdk.py @@ -20,8 +20,6 @@ from zeno_backend.classes.project import Project from zeno_backend.database import insert, select -amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - class APIKeyBearer(HTTPBearer): """API key bearer authentication scheme.""" @@ -97,6 +95,7 @@ def create_project(project: Project, api_key=Depends(APIKeyBearer())): ) project.uuid = str(uuid.uuid4()) + amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) amplitude_client.track( BaseEvent( event_type="Project Created", @@ -197,6 +196,7 @@ def upload_system( detail=("ERROR: Unable to create system table: " + str(e)), ) from e + amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) amplitude_client.track( BaseEvent( event_type="System Uploaded", diff --git a/backend/zeno_backend/server.py b/backend/zeno_backend/server.py index 5f8431cd..cf0bbcee 100644 --- a/backend/zeno_backend/server.py +++ b/backend/zeno_backend/server.py @@ -44,8 +44,6 @@ from .routers import sdk -amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - def get_server() -> FastAPI: """Provide the FastAPI server and specifies its inputs. @@ -65,6 +63,8 @@ def get_server() -> FastAPI: if env_path.exists(): load_dotenv(env_path) + amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) + # function to get the user from cognito auth = Cognito( region=os.environ["ZENO_USER_POOL_AUTH_REGION"], From 80d97e0f6b7fd5b0f455f31190c3c9653c859990 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Ba=CC=88uerle?= Date: Sun, 17 Sep 2023 15:03:12 +0200 Subject: [PATCH 2/2] feat: use singleton for amplitude --- backend/zeno_backend/classes/amplitude.py | 30 +++++++++++++++++++++++ backend/zeno_backend/routers/sdk.py | 10 +++----- backend/zeno_backend/server.py | 13 +++++----- 3 files changed, 40 insertions(+), 13 deletions(-) create mode 100644 backend/zeno_backend/classes/amplitude.py diff --git a/backend/zeno_backend/classes/amplitude.py b/backend/zeno_backend/classes/amplitude.py new file mode 100644 index 00000000..5eed0496 --- /dev/null +++ b/backend/zeno_backend/classes/amplitude.py @@ -0,0 +1,30 @@ +"""Amplitude handler for tracking after env has been set up with a singleton object.""" +import os + +from amplitude import Amplitude, BaseEvent + + +class AmplitudeHandler: + """Class to handle amplitude events. + + Attributes: + _client (Amplitude | None): amplitude client singleton object. + """ + + _client = None + + def __new__(cls): + """Create a new amplitude handler if one doesn't exist.""" + if "AMPLITUDE_API_KEY" in os.environ: + if cls._client is None: + cls._client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) + return cls._client + + def track(self, event: BaseEvent): + """Track an amplitude event. + + Args: + event (BaseEvent): the event to track. + """ + if self._client is not None: + self._client.track(event) diff --git a/backend/zeno_backend/routers/sdk.py b/backend/zeno_backend/routers/sdk.py index c53d8522..adc2bb54 100644 --- a/backend/zeno_backend/routers/sdk.py +++ b/backend/zeno_backend/routers/sdk.py @@ -1,10 +1,9 @@ """FastAPI server endpoints for the Zeno SDK.""" import io -import os import uuid import pandas as pd -from amplitude import Amplitude, BaseEvent +from amplitude import BaseEvent from fastapi import ( APIRouter, Depends, @@ -17,6 +16,7 @@ ) from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from zeno_backend.classes.amplitude import AmplitudeHandler from zeno_backend.classes.project import Project from zeno_backend.database import insert, select @@ -95,8 +95,7 @@ def create_project(project: Project, api_key=Depends(APIKeyBearer())): ) project.uuid = str(uuid.uuid4()) - amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="Project Created", user_id="00000" + str(user_id), @@ -196,8 +195,7 @@ def upload_system( detail=("ERROR: Unable to create system table: " + str(e)), ) from e - amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="System Uploaded", user_id="00000" + str(user_id), diff --git a/backend/zeno_backend/server.py b/backend/zeno_backend/server.py index cf0bbcee..f9fd6472 100644 --- a/backend/zeno_backend/server.py +++ b/backend/zeno_backend/server.py @@ -5,7 +5,7 @@ import pandas as pd import uvicorn -from amplitude import Amplitude, BaseEvent +from amplitude import BaseEvent from dotenv import load_dotenv from fastapi import Depends, FastAPI, HTTPException, Request, Response, status from fastapi.middleware.cors import CORSMiddleware @@ -16,6 +16,7 @@ import zeno_backend.database.select as select import zeno_backend.database.update as update import zeno_backend.util as util +from zeno_backend.classes.amplitude import AmplitudeHandler from zeno_backend.classes.base import ( GroupMetric, ZenoColumn, @@ -63,8 +64,6 @@ def get_server() -> FastAPI: if env_path.exists(): load_dotenv(env_path) - amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - # function to get the user from cognito auth = Cognito( region=os.environ["ZENO_USER_POOL_AUTH_REGION"], @@ -351,7 +350,7 @@ def get_project(owner_name: str, project_name: str, request: Request): return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) if not util.access_valid(uuid, request): return Response(status_code=401) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="Project Viewed", user_id="ProjectViewedUser", @@ -416,7 +415,7 @@ def get_projects(current_user=Depends(auth.claim())): tags=["zeno"], ) def get_public_projects(): - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="Home Viewed", user_id="HomeViewedUser", @@ -553,7 +552,7 @@ def login(name: str): try: user = User(id=-1, name=name, admin=None) user_id = insert.user(user) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="User Registered", user_id="00000" + str(user_id) if user_id else "", @@ -567,7 +566,7 @@ def login(name: str): detail=str(exc), ) from exc else: - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="User Logged In", user_id="00000" + str(fetched_user.id),