Skip to content

Commit

Permalink
fix: make histograms work with metric columns that are boolean
Browse files Browse the repository at this point in the history
  • Loading branch information
Sparkier committed Oct 9, 2023
1 parent c8b93b1 commit 729dd43
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions backend/zeno_backend/processing/histogram_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,28 @@ async def histogram_metric_and_count(
col_id = col_id[0]

metric_col_id = None
metric_col_type = None
if calculate_histograms and request.metric is not None:
await db.execute(
sql.SQL(
"SELECT column_id FROM {} "
"SELECT column_id, data_type FROM {} "
"WHERE name = %s AND (model = %s OR model IS NULL);"
).format(sql.Identifier(f"{project_uuid}_column_map")),
[request.metric.columns[0], request.model],
)
metric_col_id = await db.fetchone()
if metric_col_id is None:
return []
metric_col_type = metric_col_id[1]
metric_col_id = metric_col_id[0]

if col.data_type == MetadataType.NOMINAL:
if calculate_histograms and metric_col_id is not None:
statement = sql.SQL("SELECT {}, COUNT(*), AVG({}) FROM {}").format(
sql.Identifier(col_id),
sql.Identifier(metric_col_id),
sql.Identifier(metric_col_id)
if metric_col_type != MetadataType.BOOLEAN
else sql.Identifier(metric_col_id) + sql.SQL("::int"),
sql.Identifier(project_uuid),
)
else:
Expand Down Expand Up @@ -228,7 +232,9 @@ async def histogram_metric_and_count(
if calculate_histograms and metric_col_id is not None:
statement = sql.SQL("SELECT {}, COUNT(*), AVG({}) FROM {}").format(
case_statement,
sql.Identifier(metric_col_id),
sql.Identifier(metric_col_id)
if metric_col_type != MetadataType.BOOLEAN
else sql.Identifier(metric_col_id) + sql.SQL("::int"),
sql.Identifier(project_uuid),
)
else:
Expand Down Expand Up @@ -278,7 +284,9 @@ async def histogram_metric_and_count(
).format(
sql.Identifier(col_id),
sql.Identifier(col_id),
sql.Identifier(metric_col_id),
sql.Identifier(metric_col_id)
if metric_col_type != MetadataType.BOOLEAN
else sql.Identifier(metric_col_id) + sql.SQL("::int"),
sql.Identifier(project_uuid),
)
else:
Expand Down

0 comments on commit 729dd43

Please sign in to comment.