Skip to content

Commit

Permalink
fix: speed up routes
Browse files Browse the repository at this point in the history
  • Loading branch information
cabreraalex committed Sep 11, 2023
1 parent d6a1ecd commit 27e3bca
Show file tree
Hide file tree
Showing 26 changed files with 741 additions and 412 deletions.
12 changes: 12 additions & 0 deletions backend/zeno_backend/classes/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,15 @@ def default(
object: a dict to be encoded by a JSON encoder and saved into the database.
"""
return o.__dict__


class ChartResponse(CamelModel):
"""Chart specification and data.
Parameters:
chart (Chart): The chart specification.
chart_data (str): The chart data in JSON string.
"""

chart: Chart
chart_data: str
27 changes: 26 additions & 1 deletion backend/zeno_backend/classes/project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""Types for Zeno projects."""

from zeno_backend.classes.base import CamelModel
from zeno_backend.classes.base import CamelModel, ZenoColumn
from zeno_backend.classes.folder import Folder
from zeno_backend.classes.metric import Metric
from zeno_backend.classes.slice import Slice
from zeno_backend.classes.tag import Tag


class Project(CamelModel):
Expand Down Expand Up @@ -45,3 +48,25 @@ class ProjectStats(CamelModel):
num_instances: int
num_charts: int
num_models: int


class ProjectState(CamelModel):
"""State variables for a Zeno project.
Attributes:
project (Project): The project object with project metadata.
models (list[str]): The names of the models in the project.
metrics (list[Metric]): The metrics to calculate for the project.
columns (list[ZenoColumn]): The columns in the project.
slices (list[Slice]): The slices in the project.
tags (list[Tag]): The tags in the project.
folders (list[Folder]): The folders in the project.
"""

project: Project
models: list[str]
metrics: list[Metric]
columns: list[ZenoColumn]
slices: list[Slice]
tags: list[Tag]
folders: list[Folder]
157 changes: 156 additions & 1 deletion backend/zeno_backend/database/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from zeno_backend.classes.folder import Folder
from zeno_backend.classes.metadata import StringFilterRequest
from zeno_backend.classes.metric import Metric
from zeno_backend.classes.project import Project, ProjectStats
from zeno_backend.classes.project import Project, ProjectState, ProjectStats
from zeno_backend.classes.slice import Slice
from zeno_backend.classes.slice_finder import SQLTable
from zeno_backend.classes.tag import Tag
Expand Down Expand Up @@ -380,6 +380,131 @@ def project_from_uuid(project_uuid: str) -> Project | None:
)


def project_state(project_uuid: str, project: Project) -> ProjectState | None:
"""Get the state variables of a project.
Args:
project_uuid (str): the uuid of the project to be fetched.
project (Project): the project object with project metadata.
Returns:
ProjectState | None: state variables of the requested project.
"""
with Database() as db:
metric_results = db.execute_return(
"SELECT id, name, type, columns FROM metrics WHERE project_uuid = %s;",
[project_uuid],
)
metrics = list(
map(
lambda metric: Metric(
id=metric[0],
name=metric[1],
type=metric[2],
columns=metric[3],
),
metric_results,
)
)

slice_results = db.execute_return(
"SELECT id, name, folder_id, filter FROM slices WHERE project_uuid = %s;",
[
project_uuid,
],
)
slices = list(
map(
lambda slice: Slice(
id=slice[0],
slice_name=slice[1],
folder_id=slice[2],
filter_predicates=FilterPredicateGroup(
predicates=json.loads(slice[3])["predicates"],
join=Join.OMITTED,
),
),
slice_results,
)
)

folder_results = db.execute_return(
"SELECT id, name, project_uuid FROM folders WHERE project_uuid = %s;",
[
project_uuid,
],
)
folders = list(
map(
lambda folder: Folder(id=folder[0], name=folder[1]),
folder_results,
)
)

tags_result = db.execute_return(
"SELECT id, name, folder_id FROM tags WHERE project_uuid = %s",
[
project_uuid,
],
)
tags: list[Tag] = []
for tag_result in tags_result:
data_results = db.execute_return(
sql.SQL("SELECT data_id FROM {} WHERE tag_id = %s").format(
sql.Identifier(f"{project_uuid}_tags_datapoints")
),
[
tag_result[0],
],
)
tags.append(
Tag(
id=tag_result[0],
tag_name=tag_result[1],
folder_id=tag_result[2],
data_ids=[]
if len(data_results) == 0
else [d[0] for d in data_results],
)
)

column_results = db.execute_return(
sql.SQL("SELECT column_id, name, type, model, data_type FROM {};").format(
sql.Identifier(f"{project_uuid}_column_map")
),
)

columns = list(
map(
lambda column: ZenoColumn(
id=column[0],
name=column[1],
column_type=column[2],
model=column[3],
data_type=column[4],
),
column_results,
)
)

model_results = db.execute_return(
sql.SQL("SELECT DISTINCT model FROM {} WHERE model IS NOT NULL;").format(
sql.Identifier(f"{project_uuid}_column_map")
),
)
models = [m[0] for m in model_results]

return ProjectState(
project=project,
metrics=metrics,
folders=folders,
columns=columns,
slices=slices,
tags=tags,
models=models,
)


def project_stats(project: str) -> ProjectStats | None:
"""Get statistics for a specified project.
Expand Down Expand Up @@ -567,6 +692,36 @@ def slices(project: str, ids: list[int] | None = None) -> list[Slice]:
)


def chart(project_uuid: str, chart_id: int):
"""Get a project chart by its ID.
Args:
project_uuid (str): the project the user is currently working with.
chart_id (int): the ID of the chart to be fetched.
Returns:
Chart | None: the requested chart.
"""
db = Database()
chart_result = db.connect_execute_return(
"SELECT id, name, type, parameters FROM "
"charts WHERE id = %s AND project_uuid = %s;",
[
chart_id,
project_uuid,
],
)
if len(chart_result) == 0:
return None
return Chart(
id=chart_result[0][0],
name=chart_result[0][1],
type=chart_result[0][2],
parameters=json.loads(chart_result[0][3]),
)


def charts(project: str) -> list[Chart]:
"""Get a list of all charts created in the project.
Expand Down
71 changes: 59 additions & 12 deletions backend/zeno_backend/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
GroupMetric,
ZenoColumn,
)
from zeno_backend.classes.chart import Chart
from zeno_backend.classes.chart import Chart, ChartResponse
from zeno_backend.classes.folder import Folder
from zeno_backend.classes.metadata import HistogramBucket, StringFilterRequest
from zeno_backend.classes.metric import Metric, MetricRequest
from zeno_backend.classes.project import Project, ProjectStats
from zeno_backend.classes.project import Project, ProjectState, ProjectStats
from zeno_backend.classes.slice import Slice
from zeno_backend.classes.slice_finder import SliceFinderRequest, SliceFinderReturn
from zeno_backend.classes.table import TableRequest
Expand Down Expand Up @@ -175,14 +175,19 @@ def get_slices(project: str, request: Request):
return select.slices(project)

@api_app.get(
"/charts/{project}",
"/charts/{owner}/{project}",
response_model=list[Chart],
tags=["zeno"],
)
def get_charts(project: str, request: Request):
if not util.access_valid(project, request):
def get_charts(owner_name: str, project_name: str, request: Request):
project_uuid = select.project_uuid(owner_name, project_name)
if project_uuid is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
)
if not util.access_valid(project_uuid, request):
return Response(status_code=401)
return select.charts(project)
return select.charts(project_uuid)

@api_app.get(
"/columns/{project}",
Expand Down Expand Up @@ -260,15 +265,30 @@ def get_filtered_table(req: TableRequest, project_uuid: str, request: Request):
)
return filt_df.to_json(orient="records")

@api_app.post(
"/chart-data/{project}",
@api_app.get(
"/chart/{owner}/{project}/{chart_uuid}",
response_model=ChartResponse,
tags=["zeno"],
response_model=str,
)
def get_chart_data(chart: Chart, project: str, request: Request):
if not util.access_valid(project, request):
def get_chart(owner_name: str, project_name: str, chart_id: int, request: Request):
project_uuid = select.project_uuid(owner_name, project_name)
if project_uuid is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
)
project = select.project_from_uuid(project_uuid)
if project is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
)
if not util.access_valid(project_uuid, request):
return Response(status_code=401)
return chart_data(chart, project)
chart = select.chart(project_uuid, chart_id)
if chart is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Chart not found"
)
return ChartResponse(chart=chart, chart_data=chart_data(chart, project_uuid))

@api_app.post("/organizations", tags=["zeno"], response_model=list[Organization])
def get_organizations(current_user=Depends(auth.claim())):
Expand All @@ -281,6 +301,33 @@ def get_organizations(current_user=Depends(auth.claim())):
def is_project_public(project_uuid: str):
return select.project_public(project_uuid)

@api_app.get(
"/project-state/{owner}/{project}", response_model=ProjectState, tags=["zeno"]
)
def get_project_state(
owner_name: str,
project_name: str,
request: Request,
current_user=Depends(auth.claim()),
):
project_uuid = select.project_uuid(owner_name, project_name)
if project_uuid is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
)
project = select.project_from_uuid(project_uuid)
if project is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Project not found"
)
if not util.access_valid(project_uuid, request):
return Response(status_code=401)
user = select.user(current_user["username"])
if user is not None:
if user.name == project.owner_name:
project.editor = True
return select.project_state(project_uuid, project)

@api_app.post("/project/{owner}/{project}", response_model=Project, tags=["zeno"])
def get_project(owner_name: str, project_name: str, request: Request):
uuid = select.project_uuid(owner_name, project_name)
Expand Down
Loading

0 comments on commit 27e3bca

Please sign in to comment.