Skip to content

Commit

Permalink
fix: move references to amplitude key to after .env loading (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sparkier authored Sep 18, 2023
1 parent 6755d9e commit 8948955
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
30 changes: 30 additions & 0 deletions backend/zeno_backend/classes/amplitude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Amplitude handler for tracking after env has been set up with a singleton object."""
import os

from amplitude import Amplitude, BaseEvent


class AmplitudeHandler:
"""Class to handle amplitude events.
Attributes:
_client (Amplitude | None): amplitude client singleton object.
"""

_client = None

def __new__(cls):
"""Create a new amplitude handler if one doesn't exist."""
if "AMPLITUDE_API_KEY" in os.environ:
if cls._client is None:
cls._client = Amplitude(os.environ["AMPLITUDE_API_KEY"])
return cls._client

def track(self, event: BaseEvent):
"""Track an amplitude event.
Args:
event (BaseEvent): the event to track.
"""
if self._client is not None:
self._client.track(event)
10 changes: 4 additions & 6 deletions backend/zeno_backend/routers/sdk.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""FastAPI server endpoints for the Zeno SDK."""
import io
import os
import uuid

import pandas as pd
from amplitude import Amplitude, BaseEvent
from amplitude import BaseEvent
from fastapi import (
APIRouter,
Depends,
Expand All @@ -17,11 +16,10 @@
)
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from zeno_backend.classes.amplitude import AmplitudeHandler
from zeno_backend.classes.project import Project
from zeno_backend.database import insert, select

amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"])


class APIKeyBearer(HTTPBearer):
"""API key bearer authentication scheme."""
Expand Down Expand Up @@ -97,7 +95,7 @@ def create_project(project: Project, api_key=Depends(APIKeyBearer())):
)

project.uuid = str(uuid.uuid4())
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="Project Created",
user_id="00000" + str(user_id),
Expand Down Expand Up @@ -197,7 +195,7 @@ def upload_system(
detail=("ERROR: Unable to create system table: " + str(e)),
) from e

amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="System Uploaded",
user_id="00000" + str(user_id),
Expand Down
13 changes: 6 additions & 7 deletions backend/zeno_backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pandas as pd
import uvicorn
from amplitude import Amplitude, BaseEvent
from amplitude import BaseEvent
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -17,6 +17,7 @@
import zeno_backend.database.select as select
import zeno_backend.database.update as update
import zeno_backend.util as util
from zeno_backend.classes.amplitude import AmplitudeHandler
from zeno_backend.classes.base import (
GroupMetric,
ZenoColumn,
Expand Down Expand Up @@ -45,8 +46,6 @@

from .routers import sdk

amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"])


def get_server() -> FastAPI:
"""Provide the FastAPI server and specifies its inputs.
Expand Down Expand Up @@ -352,7 +351,7 @@ def get_project(owner_name: str, project_name: str, request: Request):
return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
if not util.access_valid(uuid, request):
return Response(status_code=401)
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="Project Viewed",
user_id="ProjectViewedUser",
Expand Down Expand Up @@ -417,7 +416,7 @@ def get_projects(current_user=Depends(auth.claim())):
tags=["zeno"],
)
def get_public_projects():
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="Home Viewed",
user_id="HomeViewedUser",
Expand Down Expand Up @@ -579,7 +578,7 @@ def login(name: str):
try:
user = User(id=-1, name=name, admin=None)
user_id = insert.user(user)
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="User Registered",
user_id="00000" + str(user_id) if user_id else "",
Expand All @@ -593,7 +592,7 @@ def login(name: str):
detail=str(exc),
) from exc
else:
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="User Logged In",
user_id="00000" + str(fetched_user.id),
Expand Down

0 comments on commit 8948955

Please sign in to comment.