From 89489552dd8c849d45f7ebdc75bafe19b995c567 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20B=C3=A4uerle?= Date: Mon, 18 Sep 2023 17:07:22 +0200 Subject: [PATCH] fix: move references to amplitude key to after .env loading (#177) --- backend/zeno_backend/classes/amplitude.py | 30 +++++++++++++++++++++++ backend/zeno_backend/routers/sdk.py | 10 +++----- backend/zeno_backend/server.py | 13 +++++----- 3 files changed, 40 insertions(+), 13 deletions(-) create mode 100644 backend/zeno_backend/classes/amplitude.py diff --git a/backend/zeno_backend/classes/amplitude.py b/backend/zeno_backend/classes/amplitude.py new file mode 100644 index 00000000..5eed0496 --- /dev/null +++ b/backend/zeno_backend/classes/amplitude.py @@ -0,0 +1,30 @@ +"""Amplitude handler for tracking after env has been set up with a singleton object.""" +import os + +from amplitude import Amplitude, BaseEvent + + +class AmplitudeHandler: + """Class to handle amplitude events. + + Attributes: + _client (Amplitude | None): amplitude client singleton object. + """ + + _client = None + + def __new__(cls): + """Create a new amplitude handler if one doesn't exist.""" + if "AMPLITUDE_API_KEY" in os.environ: + if cls._client is None: + cls._client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) + return cls._client + + def track(self, event: BaseEvent): + """Track an amplitude event. + + Args: + event (BaseEvent): the event to track. + """ + if self._client is not None: + self._client.track(event) diff --git a/backend/zeno_backend/routers/sdk.py b/backend/zeno_backend/routers/sdk.py index f5c638c2..67cf28e5 100644 --- a/backend/zeno_backend/routers/sdk.py +++ b/backend/zeno_backend/routers/sdk.py @@ -1,10 +1,9 @@ """FastAPI server endpoints for the Zeno SDK.""" import io -import os import uuid import pandas as pd -from amplitude import Amplitude, BaseEvent +from amplitude import BaseEvent from fastapi import ( APIRouter, Depends, @@ -17,11 +16,10 @@ ) from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from zeno_backend.classes.amplitude import AmplitudeHandler from zeno_backend.classes.project import Project from zeno_backend.database import insert, select -amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - class APIKeyBearer(HTTPBearer): """API key bearer authentication scheme.""" @@ -97,7 +95,7 @@ def create_project(project: Project, api_key=Depends(APIKeyBearer())): ) project.uuid = str(uuid.uuid4()) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="Project Created", user_id="00000" + str(user_id), @@ -197,7 +195,7 @@ def upload_system( detail=("ERROR: Unable to create system table: " + str(e)), ) from e - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="System Uploaded", user_id="00000" + str(user_id), diff --git a/backend/zeno_backend/server.py b/backend/zeno_backend/server.py index 163f45cf..6ac51581 100644 --- a/backend/zeno_backend/server.py +++ b/backend/zeno_backend/server.py @@ -6,7 +6,7 @@ import pandas as pd import uvicorn -from amplitude import Amplitude, BaseEvent +from amplitude import BaseEvent from dotenv import load_dotenv from fastapi import Depends, FastAPI, HTTPException, Request, Response, status from fastapi.middleware.cors import CORSMiddleware @@ -17,6 +17,7 @@ import zeno_backend.database.select as select import zeno_backend.database.update as update import zeno_backend.util as util +from zeno_backend.classes.amplitude import AmplitudeHandler from zeno_backend.classes.base import ( GroupMetric, ZenoColumn, @@ -45,8 +46,6 @@ from .routers import sdk -amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"]) - def get_server() -> FastAPI: """Provide the FastAPI server and specifies its inputs. @@ -352,7 +351,7 @@ def get_project(owner_name: str, project_name: str, request: Request): return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) if not util.access_valid(uuid, request): return Response(status_code=401) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="Project Viewed", user_id="ProjectViewedUser", @@ -417,7 +416,7 @@ def get_projects(current_user=Depends(auth.claim())): tags=["zeno"], ) def get_public_projects(): - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="Home Viewed", user_id="HomeViewedUser", @@ -579,7 +578,7 @@ def login(name: str): try: user = User(id=-1, name=name, admin=None) user_id = insert.user(user) - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="User Registered", user_id="00000" + str(user_id) if user_id else "", @@ -593,7 +592,7 @@ def login(name: str): detail=str(exc), ) from exc else: - amplitude_client.track( + AmplitudeHandler().track( BaseEvent( event_type="User Logged In", user_id="00000" + str(fetched_user.id),