Skip to content

Commit

Permalink
feat: use singleton for amplitude
Browse files Browse the repository at this point in the history
  • Loading branch information
Sparkier committed Sep 17, 2023
1 parent 00dba47 commit ba9b038
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 13 deletions.
25 changes: 25 additions & 0 deletions backend/zeno_backend/classes/amplitude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Amplitude handler for tracking after env has been set up with a singleton object."""
import os

from amplitude import Amplitude, BaseEvent


class AmplitudeHandler:
"""Class to handle amplitude events.
Attributes:
_client (Amplitude | None): amplitude client singleton object.
"""

_client = None

def track(self, event: BaseEvent):
"""Track an amplitude event.
Args:
event (BaseEvent): the event to track.
"""
if "AMPLITUDE_API_KEY" in os.environ:
if self._client is None:
self._client = Amplitude(os.environ["AMPLITUDE_API_KEY"])
self._client.track(event)
10 changes: 4 additions & 6 deletions backend/zeno_backend/routers/sdk.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""FastAPI server endpoints for the Zeno SDK."""
import io
import os
import uuid

import pandas as pd
from amplitude import Amplitude, BaseEvent
from amplitude import BaseEvent
from fastapi import (
APIRouter,
Depends,
Expand All @@ -17,6 +16,7 @@
)
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from zeno_backend.classes.amplitude import AmplitudeHandler
from zeno_backend.classes.project import Project
from zeno_backend.database import insert, select

Expand Down Expand Up @@ -95,8 +95,7 @@ def create_project(project: Project, api_key=Depends(APIKeyBearer())):
)

project.uuid = str(uuid.uuid4())
amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"])
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="Project Created",
user_id="00000" + str(user_id),
Expand Down Expand Up @@ -196,8 +195,7 @@ def upload_system(
detail=("ERROR: Unable to create system table: " + str(e)),
) from e

amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"])
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="System Uploaded",
user_id="00000" + str(user_id),
Expand Down
13 changes: 6 additions & 7 deletions backend/zeno_backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pandas as pd
import uvicorn
from amplitude import Amplitude, BaseEvent
from amplitude import BaseEvent
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -16,6 +16,7 @@
import zeno_backend.database.select as select
import zeno_backend.database.update as update
import zeno_backend.util as util
from zeno_backend.classes.amplitude import AmplitudeHandler
from zeno_backend.classes.base import (
GroupMetric,
ZenoColumn,
Expand Down Expand Up @@ -63,8 +64,6 @@ def get_server() -> FastAPI:
if env_path.exists():
load_dotenv(env_path)

amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"])

# function to get the user from cognito
auth = Cognito(
region=os.environ["ZENO_USER_POOL_AUTH_REGION"],
Expand Down Expand Up @@ -351,7 +350,7 @@ def get_project(owner_name: str, project_name: str, request: Request):
return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
if not util.access_valid(uuid, request):
return Response(status_code=401)
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="Project Viewed",
user_id="ProjectViewedUser",
Expand Down Expand Up @@ -416,7 +415,7 @@ def get_projects(current_user=Depends(auth.claim())):
tags=["zeno"],
)
def get_public_projects():
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="Home Viewed",
user_id="HomeViewedUser",
Expand Down Expand Up @@ -553,7 +552,7 @@ def login(name: str):
try:
user = User(id=-1, name=name, admin=None)
user_id = insert.user(user)
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="User Registered",
user_id="00000" + str(user_id) if user_id else "",
Expand All @@ -567,7 +566,7 @@ def login(name: str):
detail=str(exc),
) from exc
else:
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="User Logged In",
user_id="00000" + str(fetched_user.id),
Expand Down

0 comments on commit ba9b038

Please sign in to comment.