Skip to content

Commit

Permalink
feat: fast histogram calculation w/ sql (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
cabreraalex authored Sep 12, 2023
1 parent c48b35a commit 311926d
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 51 deletions.
151 changes: 126 additions & 25 deletions backend/zeno_backend/processing/histogram_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@
HistogramColumnRequest,
HistogramRequest,
)
from zeno_backend.database.select import column, project_from_uuid
from zeno_backend.processing.filtering import bucket_filter, table_filter
from zeno_backend.processing.metrics.map import metric_map
from zeno_backend.database.database import Database
from zeno_backend.database.select import (
column,
column_id_from_name_and_model,
project_from_uuid,
)
from zeno_backend.processing.filtering import table_filter


def histogram_bucket(project_uuid: str, col: ZenoColumn, num_bins: int | str):
Expand Down Expand Up @@ -89,11 +93,10 @@ def histogram_buckets(
def histogram_metric_task(
request: HistogramRequest,
col_request: HistogramColumnRequest,
bucket: HistogramBucket,
project_uuid: str,
filter_sql: sql.Composed | None,
) -> HistogramBucket:
"""Calculate the metric for a single bucket.
) -> list[HistogramBucket]:
"""Calculate the metric and count for a column.
Args:
request (HistogramRequest): the request object.
Expand All @@ -106,21 +109,124 @@ def histogram_metric_task(
Returns:
HistogramBucket: the bucket with the metric added.
"""
filter_bucket = bucket_filter(col_request.column, bucket)
final_filter = filter_sql
if filter_bucket is not None:
if final_filter is None:
final_filter = filter_bucket
else:
final_filter = final_filter + sql.SQL(" AND ") + filter_bucket
metric = metric_map(request.metric, project_uuid, request.model, final_filter)
return HistogramBucket(
bucket=bucket.bucket,
bucket_end=bucket.bucket_end,
size=metric.size,
metric=metric.metric,
col = col_request.column
if request.metric is None or request.model is None:
return []

col_id = column_id_from_name_and_model(project_uuid, col.name, request.model)
metric_col_id = column_id_from_name_and_model(
project_uuid, request.metric.columns[0], request.model
)

with Database() as db:
if col.data_type == MetadataType.NOMINAL:
unique = db.execute_return(
sql.SQL("SELECT COUNT(DISTINCT {}) FROM {}").format(
sql.Identifier(col_id),
sql.Identifier(project_uuid),
)
)
if len(unique) > 0 and unique[0][0] > 30:
return []
else:
statement = sql.SQL("SELECT {}, AVG({}), COUNT(*) FROM {}").format(
sql.Identifier(col_id),
sql.Identifier(metric_col_id),
sql.Identifier(project_uuid),
)
if filter_sql:
statement = sql.SQL("{} WHERE {} GROUP BY {}").format(
statement, filter_sql, sql.Identifier(col_id)
)
else:
statement = sql.SQL("{} GROUP BY {}").format(
statement, sql.Identifier(col_id)
)
db_res = db.execute_return(statement)
results_map = {r[0]: (r[1], r[2]) for r in db_res}

return [
HistogramBucket(
bucket=b.bucket,
metric=results_map[b.bucket][0]
if b.bucket in results_map
else 0,
size=results_map[b.bucket][1] if b.bucket in results_map else 0,
)
for b in col_request.buckets
]

elif col.data_type == MetadataType.CONTINUOUS:
case_statement = sql.SQL("CASE ")
for i, b in enumerate(col_request.buckets):
case_statement += sql.SQL("WHEN {} >= {} AND {} < {} THEN {} ").format(
sql.Identifier(metric_col_id),
sql.Literal(b.bucket),
sql.Identifier(metric_col_id),
sql.Literal(b.bucket_end),
sql.Literal(i),
)
case_statement += sql.SQL("END AS bucket")
statement = sql.SQL("SELECT {}, AVG({}), COUNT(*) FROM {}").format(
case_statement,
sql.Identifier(metric_col_id),
sql.Identifier(project_uuid),
)

if filter_sql:
statement = sql.SQL("{} WHERE {} GROUP BY bucket").format(
statement, filter_sql
)
else:
statement = sql.SQL("{} GROUP BY bucket").format(statement)
db_res = db.execute_return(statement)
results_map = {int(r[0]): (r[1], r[2]) for r in db_res if r[0] is not None}

return [
HistogramBucket(
bucket=b.bucket,
bucket_end=b.bucket_end,
metric=results_map[i][0] if i in results_map else 0,
size=results_map[i][1] if i in results_map else 0,
)
for i, b in enumerate(col_request.buckets)
]

elif col.data_type == MetadataType.BOOLEAN:
statement = sql.SQL(
"SELECT CASE WHEN {} = TRUE THEN 0 WHEN {} = FALSE"
" THEN 1 END AS bucket, AVG({}), COUNT(*) FROM {}"
).format(
sql.Identifier(col_id),
sql.Identifier(col_id),
sql.Identifier(metric_col_id),
sql.Identifier(project_uuid),
)
if filter_sql:
statement = sql.SQL("{} WHERE {} GROUP BY bucket").format(
statement, filter_sql
)
else:
statement = sql.SQL("{} GROUP BY bucket").format(statement)
res = db.execute_return(statement)

true_res = [r for r in res if r[0] == 0]
false_res = [r for r in res if r[0] == 1]
return [
HistogramBucket(
bucket=True,
size=true_res[0][2] if len(true_res) > 0 else 0,
metric=true_res[0][1] if len(true_res) > 0 else 0,
),
HistogramBucket(
bucket=False,
size=false_res[0][2] if len(false_res) > 0 else 0,
metric=false_res[0][1] if len(false_res) > 0 else 0,
),
]
else:
return []


def histogram_count(
request: HistogramRequest,
Expand Down Expand Up @@ -198,12 +304,7 @@ def histogram_count(
else:
return []
else:
res = []
for b in col_request.buckets:
res.append(
histogram_metric_task(request, col_request, b, project_uuid, filter_sql)
)
return res
return histogram_metric_task(request, col_request, project_uuid, filter_sql)


def histogram_counts(
Expand Down
1 change: 0 additions & 1 deletion frontend/src/lib/api/metadata.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ export async function calculateHistograms(
});
return histograms;
} catch (e) {
requestingHistogramCounts.set(false);
return histograms;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
export let compare: boolean;
let selected = 'list';
let currentResult: Promise<GroupMetric[] | undefined>;
let modelAResult: Promise<GroupMetric[] | undefined>;
let modelBResult: Promise<GroupMetric[] | undefined>;
let currentResult: Promise<GroupMetric[] | undefined> = new Promise(() => undefined);
let modelAResult: Promise<GroupMetric[] | undefined> = new Promise(() => undefined);
let modelBResult: Promise<GroupMetric[] | undefined> = new Promise(() => undefined);
let numberOfInstances = 0;
let viewOptions: Record<string, unknown> | undefined = undefined;
Expand Down
6 changes: 1 addition & 5 deletions frontend/src/lib/components/metadata/SelectionBar.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import { editTag, metric, project } from '$lib/stores';
import type { GroupMetric } from '$lib/zenoapi';
import Button, { Group } from '@smui/button';
import CircularProgress from '@smui/circular-progress';
import ChipsWrapper from './ChipsWrapper.svelte';
export let currentResult: Promise<GroupMetric[] | undefined>;
Expand All @@ -28,13 +27,10 @@
class="flex flex-wrap justify-between w-full items-center py-2.5 border-b border-grey-lighter"
>
<div class="flex">
<div />
<span class="text-grey-dark mr-3">
{$metric ? $metric.name + ':' : ''}
</span>
{#await currentResult}
<CircularProgress style="height: 20px; width: 20px; margin-right:20px" indeterminate />
{:then res}
{#await currentResult then res}
{#if res !== undefined && res.length > 0}
{#if res[0].metric !== undefined && res[0].metric !== null}
<span class="text-primary mr-3">
Expand Down
32 changes: 15 additions & 17 deletions frontend/src/routes/(app)/project/[owner]/[project]/+layout.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,21 @@
export let data;
if ($project === undefined) {
project.set(data.project);
rowsPerPage.set(data.project.samplesPerPage ?? 10);
slices.set(data.slices);
columns.set(data.columns);
models.set(data.models);
metrics.set(data.metrics);
folders.set(data.folders);
tags.set(data.tags);
model.set(data.model);
metric.set(data.metric);
comparisonModel.set(data.comparisonModel);
comparisonColumn.set(data.comparisonColumn);
compareSort.set(data.compareSort);
metricRange.set(data.metricRange);
selections.set(data.selections);
}
project.set(data.project);
rowsPerPage.set(data.project.samplesPerPage ?? 10);
slices.set(data.slices);
columns.set(data.columns);
models.set(data.models);
metrics.set(data.metrics);
folders.set(data.folders);
tags.set(data.tags);
model.set(data.model);
metric.set(data.metric);
comparisonModel.set(data.comparisonModel);
comparisonColumn.set(data.comparisonColumn);
compareSort.set(data.compareSort);
metricRange.set(data.metricRange);
selections.set(data.selections);
model.subscribe((mod) => {
// URL parameters set by selection subscription.
Expand Down

0 comments on commit 311926d

Please sign in to comment.