Skip to content

Commit

Permalink
Reimplement Kolmogorov Smirnov query logic with sqlalchemy's Language…
Browse files Browse the repository at this point in the history
… Expression API (#44)
  • Loading branch information
kklein authored Aug 1, 2022
1 parent da059c5 commit 8adeefd
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 142 deletions.
48 changes: 31 additions & 17 deletions src/datajudge/constraints/stats.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import sqlalchemy as sa

Expand Down Expand Up @@ -63,59 +63,73 @@ def check_acceptance(
def c(alpha: float):
return math.sqrt(-math.log(alpha / 2.0 + 1e-10) * 0.5)

return d_statistic <= c(accepted_level) * math.sqrt(
threshold = c(accepted_level) * math.sqrt(
(n_samples + m_samples) / (n_samples * m_samples)
)
return d_statistic <= threshold

@staticmethod
def calculate_statistic(
    engine,
    ref1: DataReference,
    ref2: DataReference,
) -> Tuple[float, Optional[float], int, int, List]:
    """Compute the two-sample Kolmogorov-Smirnov statistic and its context.

    Returns a 5-tuple ``(d_statistic, p_value, n_samples, m_samples,
    selections)`` where ``n_samples`` / ``m_samples`` are the row counts of
    ``ref1`` / ``ref2`` and ``selections`` is the concatenation of the
    selections returned by both row-count queries and by the KS query, in
    that order.
    """
    # Retrieve the test statistic d together with the selection(s) that
    # produced it.
    d_statistic, ks_selections = db_access.get_ks_2sample(
        engine,
        ref1,
        ref2,
    )

    # Sample sizes n and m; the accompanying selections are kept so that a
    # failing test can report every query that was involved.
    n_samples, n_selections = db_access.get_row_count(engine, ref1)
    m_samples, m_selections = db_access.get_row_count(engine, ref2)

    # Calculate the approximate p-value.
    p_value = KolmogorovSmirnov2Sample.approximate_p_value(
        d_statistic, n_samples, m_samples
    )

    # NOTE(review): assumes ``get_row_count`` returns its selections as a
    # list, like ``get_ks_2sample`` does, so that ``+`` concatenates them.
    selections = n_selections + m_selections + ks_selections
    return d_statistic, p_value, n_samples, m_samples, selections

def test(self, engine: sa.engine.Engine) -> TestResult:
    """Run the two-sample Kolmogorov-Smirnov test against ``engine``.

    Returns ``TestResult.success()`` when ``check_acceptance`` accepts the
    computed statistic at ``self.significance_level``. Otherwise returns
    ``TestResult.failure`` carrying the assertion text, the constraint
    description and the compiled queries that were executed.
    """
    (
        d_statistic,
        p_value,
        n_samples,
        m_samples,
        selections,
    ) = self.calculate_statistic(
        engine,
        self.ref,
        self.ref2,
    )

    # Calculate test acceptance.
    result = self.check_acceptance(
        d_statistic, n_samples, m_samples, self.significance_level
    )

    assertion_text = (
        f"Null hypothesis (H0) for the 2-sample Kolmogorov-Smirnov test was rejected, i.e., "
        f"the two samples ({self.ref.get_string()} and {self.target_prefix}) "
        f"do not originate from the same distribution. "
        f"The test results are d={d_statistic}"
    )
    if p_value is not None:
        assertion_text += f" and {p_value=}"
    assertion_text += "."

    # Always bind ``queries``: previously it was only assigned when
    # ``selections`` was truthy, so the failure path below could hit an
    # unbound name.
    # NOTE(review): an empty list is used as the "no queries" value; confirm
    # this is what ``TestResult.failure`` expects.
    queries = []
    if selections:
        # Render each selection as SQL text with bound parameters inlined.
        queries = [
            str(selection.compile(engine, compile_kwargs={"literal_binds": True}))
            for selection in selections
        ]

    if not result:
        return TestResult.failure(
            assertion_text,
            self.get_description(),
            queries,
        )

    return TestResult.success()
216 changes: 151 additions & 65 deletions src/datajudge/db_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,13 @@ def get_column(self, engine):
f"Trying to access column of DataReference "
f"{self.get_string()} yet none is given."
)
return self.get_columns(engine)[0]
columns = self.get_columns(engine)
if len(columns) > 1:
raise ValueError(
"DataReference was expected to only have a single column but had multiple: "
f"{columns}"
)
return columns[0]

def get_columns(self, engine):
"""Fetch all relevant columns of a DataReference."""
Expand Down Expand Up @@ -904,77 +910,157 @@ def get_column_array_agg(
return result, selections


def _cdf_selection(engine, ref: DataReference, cdf_label: str, value_label: str):
    """Build a subquery of distinct values and their empirical CDF values.

    The returned subquery has two columns: the values of the single column
    of ``ref``, labeled ``value_label``, and the empirical cumulative
    distribution function evaluated at each value, labeled ``cdf_label``.
    Every value occurs in exactly one row.
    """
    column_name = ref.get_column(engine)
    source = ref.get_selection(engine).subquery()

    # Step 1: Calculate the CDF over the value column, one row per source row.
    value_column = source.c[column_name].label(value_label)
    cdf_column = sa.func.cume_dist().over(order_by=column_name).label(cdf_label)
    per_row_cdf = sa.select([value_column, cdf_column]).subquery()

    # Step 2: Aggregate rows s.t. every value occurs only once.
    distinct_value = per_row_cdf.c[value_label]
    max_cdf = sa.func.max(per_row_cdf.c[cdf_label]).label(cdf_label)
    return sa.select([distinct_value, max_cdf]).group_by(distinct_value).subquery()


def _cross_cdf_selection(
    engine, ref1: DataReference, ref2: DataReference, cdf_label: str, value_label: str
):
    """Create a cross cumulative distribution function selection for two samples.

    Both ``DataReference``s are expected to specify a single relevant column.
    The returned selection yields rows of the kind
    ``(value, cdf1(value), cdf2(value))``, where ``cdf1`` is the empirical
    cumulative distribution function of ``ref1`` and ``cdf2`` that of ``ref2``.

    E.g. if ``ref1`` refers to a column with values ``[1, 1, 3, 2]`` and
    ``ref2`` to a column with values ``[2, 5, 4]``, executing the returned
    selection should yield a table of the following kind:
    ``[(1, .5, 0), (2, .75, 1/3), (3, 1, 1/3), (4, 1, 2/3), (5, 1, 1)]``.

    Returns the (not yet executed) selection together with the labels of
    its two cdf columns: ``(selection, cdf_label1, cdf_label2)``.
    """
    cdf_label1 = cdf_label + "1"
    cdf_label2 = cdf_label + "2"
    group_label1 = "_grp1"
    group_label2 = "_grp2"

    # Steps 1 and 2 happen per sample inside ``_cdf_selection``.
    cdf1 = _cdf_selection(engine, ref1, cdf_label, value_label)
    cdf2 = _cdf_selection(engine, ref2, cdf_label, value_label)

    # Step 3: Combine the cdfs via a full outer join on the value, so that
    # values present in only one sample are kept (with NULL for the other).
    values1 = cdf1.c[value_label]
    values2 = cdf2.c[value_label]
    full_join = cdf1.join(cdf2, values1 == values2, isouter=True, full=True)
    cross_cdf = (
        sa.select(
            sa.func.coalesce(values1, values2).label(value_label),
            cdf1.c[cdf_label].label(cdf_label1),
            cdf2.c[cdf_label].label(cdf_label2),
        )
        .select_from(full_join)
        .subquery()
    )

    def _running_count(source_cdf_label, group_label):
        # Count of non-NULL cdf entries seen so far, ordered by value.
        return (
            sa.func.count(cross_cdf.c[source_cdf_label])
            .over(order_by=cross_cdf.c[value_label])
            .label(group_label)
        )

    # Step 4: Create a grouper id based on the value count; this is just a
    # helper for forward-filling. In other words, rows are pointed to their
    # most recent present value - per sample. This is necessary due to the
    # nature of the full outer join.
    indexed = sa.select(
        [
            cross_cdf.c[value_label],
            _running_count(cdf_label1, group_label1),
            cross_cdf.c[cdf_label1],
            _running_count(cdf_label2, group_label2),
            cross_cdf.c[cdf_label2],
        ]
    ).subquery()

    def _forward_filled(source_cdf_label, group_label):
        # Step 5: Forward-filling: select the first value per group (the
        # groups were defined in the previous step).
        first_in_group = sa.func.first_value(indexed.c[source_cdf_label]).over(
            partition_by=indexed.c[group_label], order_by=indexed.c[value_label]
        )
        # Step 6: Replace NULL values at the beginning with 0 to enable the
        # computation of a difference.
        return sa.func.coalesce(first_in_group, 0).label(source_cdf_label)

    filled = sa.select(
        [
            indexed.c[value_label],
            _forward_filled(cdf_label1, group_label1),
            _forward_filled(cdf_label2, group_label2),
        ]
    )
    return filled, cdf_label1, cdf_label2


def get_ks_2sample(
    engine: sa.engine.Engine,
    ref1: DataReference,
    ref2: DataReference,
):
    """
    Run the query for the two-sample Kolmogorov-Smirnov test and return the test statistic d.

    The statistic is the maximal absolute difference between the empirical
    cumulative distribution functions of the single columns of ``ref1`` and
    ``ref2``.

    Returns a tuple ``(d_statistic, selections)`` where ``selections`` is a
    one-element list holding the executed selection.

    For a raw-sql version of this query, please see this PR:
    https://github.com/Quantco/datajudge/pull/28/
    A more extensive explanation of the individual steps can be found at:
    https://github.com/Quantco/datajudge/pull/28#issuecomment-1165587929
    """
    cdf_label = "cdf"
    value_label = "val"
    # Steps 1-6 (per-sample cdfs, full outer join, forward-filling) are
    # built by the helper; it also reports the labels of both cdf columns.
    filled_cross_cdf_selection, cdf_label1, cdf_label2 = _cross_cdf_selection(
        engine, ref1, ref2, cdf_label, value_label
    )

    filled_cross_cdf = filled_cross_cdf_selection.subquery()

    # Step 7: Calculate final statistic: maximal distance.
    final_selection = sa.select(
        sa.func.max(
            sa.func.abs(filled_cross_cdf.c[cdf_label1] - filled_cross_cdf.c[cdf_label2])
        )
    )

    # Execute exactly once, on a connection that is closed afterwards.
    with engine.connect() as connection:
        d_statistic = connection.execute(final_selection).scalar()

    return d_statistic, [final_selection]


def get_regex_violations(engine, ref, aggregated, regex, n_counterexamples):
Expand Down
22 changes: 22 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,28 @@ def capitalization_table(engine, metadata):
return TEST_DB_NAME, SCHEMA, table_name, uppercase_column, lowercase_column


@pytest.fixture(scope="module")
def cross_cdf_table1(engine, metadata):
    """Module-scoped fixture for the first sample table of the cross-cdf tests.

    Hands a single integer column ``col_int`` with the values
    ``1, 1, 3, 2`` to ``_handle_table`` and returns the
    ``(database, schema, table)`` identifier triple.
    """
    table_name = "cross_cdf_table1"
    col_name = "col_int"
    column_definitions = [sa.Column(col_name, sa.Integer())]
    rows = [{col_name: value} for value in [1, 1, 3, 2]]
    _handle_table(engine, metadata, table_name, column_definitions, rows)
    return TEST_DB_NAME, SCHEMA, table_name


@pytest.fixture(scope="module")
def cross_cdf_table2(engine, metadata):
    """Module-scoped fixture for the second sample table of the cross-cdf tests.

    Hands a single integer column ``col_int`` with the values
    ``3, 5, 4, 5, 8`` to ``_handle_table`` and returns the
    ``(database, schema, table)`` identifier triple.
    """
    table_name = "cross_cdf_table2"
    col_name = "col_int"
    column_definitions = [sa.Column(col_name, sa.Integer())]
    rows = [{col_name: value} for value in [3, 5, 4, 5, 8]]
    _handle_table(engine, metadata, table_name, column_definitions, rows)
    return TEST_DB_NAME, SCHEMA, table_name


def pytest_addoption(parser):
parser.addoption(
"--backend",
Expand Down
Loading

0 comments on commit 8adeefd

Please sign in to comment.