Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: move references to amplitude key to after .env loading #177

Merged
merged 2 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions backend/zeno_backend/classes/amplitude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Amplitude handler for tracking after env has been set up with a singleton object."""
import os

from amplitude import Amplitude, BaseEvent


class AmplitudeHandler:
"""Class to handle amplitude events.

Attributes:
_client (Amplitude | None): amplitude client singleton object.
"""

_client = None

def __new__(cls):
"""Create a new amplitude handler if one doesn't exist."""
if "AMPLITUDE_API_KEY" in os.environ:
if cls._client is None:
cls._client = Amplitude(os.environ["AMPLITUDE_API_KEY"])
return cls._client

def track(self, event: BaseEvent):
"""Track an amplitude event.

Args:
event (BaseEvent): the event to track.
"""
if self._client is not None:
self._client.track(event)
10 changes: 4 additions & 6 deletions backend/zeno_backend/routers/sdk.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""FastAPI server endpoints for the Zeno SDK."""
import io
import os
import uuid

import pandas as pd
from amplitude import Amplitude, BaseEvent
from amplitude import BaseEvent
from fastapi import (
APIRouter,
Depends,
Expand All @@ -17,11 +16,10 @@
)
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from zeno_backend.classes.amplitude import AmplitudeHandler
from zeno_backend.classes.project import Project
from zeno_backend.database import insert, select

amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"])


class APIKeyBearer(HTTPBearer):
"""API key bearer authentication scheme."""
Expand Down Expand Up @@ -97,7 +95,7 @@ def create_project(project: Project, api_key=Depends(APIKeyBearer())):
)

project.uuid = str(uuid.uuid4())
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="Project Created",
user_id="00000" + str(user_id),
Expand Down Expand Up @@ -197,7 +195,7 @@ def upload_system(
detail=("ERROR: Unable to create system table: " + str(e)),
) from e

amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="System Uploaded",
user_id="00000" + str(user_id),
Expand Down
13 changes: 6 additions & 7 deletions backend/zeno_backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pandas as pd
import uvicorn
from amplitude import Amplitude, BaseEvent
from amplitude import BaseEvent
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -16,6 +16,7 @@
import zeno_backend.database.select as select
import zeno_backend.database.update as update
import zeno_backend.util as util
from zeno_backend.classes.amplitude import AmplitudeHandler
from zeno_backend.classes.base import (
GroupMetric,
ZenoColumn,
Expand Down Expand Up @@ -44,8 +45,6 @@

from .routers import sdk

amplitude_client = Amplitude(os.environ["AMPLITUDE_API_KEY"])


def get_server() -> FastAPI:
"""Provide the FastAPI server and specifies its inputs.
Expand Down Expand Up @@ -351,7 +350,7 @@ def get_project(owner_name: str, project_name: str, request: Request):
return Response(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
if not util.access_valid(uuid, request):
return Response(status_code=401)
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="Project Viewed",
user_id="ProjectViewedUser",
Expand Down Expand Up @@ -416,7 +415,7 @@ def get_projects(current_user=Depends(auth.claim())):
tags=["zeno"],
)
def get_public_projects():
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="Home Viewed",
user_id="HomeViewedUser",
Expand Down Expand Up @@ -553,7 +552,7 @@ def login(name: str):
try:
user = User(id=-1, name=name, admin=None)
user_id = insert.user(user)
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="User Registered",
user_id="00000" + str(user_id) if user_id else "",
Expand All @@ -567,7 +566,7 @@ def login(name: str):
detail=str(exc),
) from exc
else:
amplitude_client.track(
AmplitudeHandler().track(
BaseEvent(
event_type="User Logged In",
user_id="00000" + str(fetched_user.id),
Expand Down