Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Kolmogorov Smirnov Test in SQL-only #28

Merged
merged 32 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
8d80f9f
first integration: ks-test in database functionality
YYYasin19 Jun 23, 2022
8c7f25a
integrate sql query with data refs
YYYasin19 Jun 23, 2022
ba802d9
formatting
YYYasin19 Jun 24, 2022
1127da0
formatting
YYYasin19 Jun 24, 2022
61e06b0
refactoring: call `test` directly to access sql-result of KS test
YYYasin19 Jun 24, 2022
b847613
fix row count retrieval
YYYasin19 Jun 24, 2022
1688550
fix acceptance level domain error
YYYasin19 Jun 24, 2022
7b8b7dc
fix alpha adjustment
YYYasin19 Jun 24, 2022
806e2dd
fix type hints for python<3.10
YYYasin19 Jun 24, 2022
e5349ec
update sql query for postgres: all tables need to have an alias assig…
YYYasin19 Jun 24, 2022
c9bf5cb
fix: typo
YYYasin19 Jun 24, 2022
6866429
update query for mssql server
YYYasin19 Jun 24, 2022
b11fcf0
add check for column names
YYYasin19 Jun 25, 2022
a3ff0a6
alternative way of getting table name, incl. hot fix for mssql quotat…
YYYasin19 Jun 25, 2022
38f7dd6
don't accept zero alphas since in practice they don't make much sense
YYYasin19 Jun 27, 2022
b5307c4
update variable naming and doc-strings
YYYasin19 Jun 27, 2022
4ecf804
update data retrieval
YYYasin19 Jun 28, 2022
ecfbd8f
include query nesting brackets
YYYasin19 Jun 28, 2022
989dc99
better formatting for understandibility
YYYasin19 Jun 28, 2022
27d7604
better formatting for understandibility
YYYasin19 Jun 28, 2022
f43e69e
update query for better readibility with more WITH statements
YYYasin19 Jun 28, 2022
370514f
new option of passing values to the TestResult to compare these
YYYasin19 Jun 28, 2022
7fb7106
seperate implementation testing from use case testing
YYYasin19 Jun 29, 2022
395b411
make independent of numpy
YYYasin19 Jun 29, 2022
c1e01ab
update tests: new distributions, no scipy and numpy dependency, rando…
YYYasin19 Jun 29, 2022
8c12e83
update comment
YYYasin19 Jun 29, 2022
b0631f9
optional accuracy through scipy
YYYasin19 Jun 29, 2022
158ed7b
refactoring, clean up and formatting
YYYasin19 Jun 29, 2022
26594f8
update comment and type hints
YYYasin19 Jun 29, 2022
6524fb8
update tpye hints for older python versions
YYYasin19 Jun 29, 2022
10f689c
fix type hint: Tuple instead of tuple
YYYasin19 Jun 29, 2022
023bfad
update changelog and include comment about scipy calculation
YYYasin19 Jun 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 83 additions & 31 deletions src/datajudge/constraints/stats.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, Collection, Optional, Tuple
import math
import warnings
from typing import Any, Optional, Tuple

import sqlalchemy as sa

from .. import db_access
from ..db_access import DataReference
from .base import Constraint, OptionalSelections
from .base import Constraint, OptionalSelections, TestResult


class KolmogorovSmirnov2Sample(Constraint):
Expand All @@ -14,44 +16,94 @@ def __init__(
self.significance_level = significance_level
super().__init__(ref, ref2=ref2)

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[Any, OptionalSelections]:
sel = ref.get_selection(engine) # table selection incl. WHERE condition
col = ref.get_column(engine) # column name
return sel, col

@staticmethod
def calculate_2sample_ks_test(data: Collection, data2: Collection) -> float:
def approximate_p_value(
d: float, n_samples: int, m_samples: int
) -> Optional[float]:
"""
For two given lists of values calculates the Kolmogorov-Smirnov test.
Read more here: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kstest.html
Calculates the approximate p-value according to
'A procedure to find exact critical values of Kolmogorov-Smirnov Test', Silvia Fachinetti, 2009
"""
try:
from scipy.stats import ks_2samp
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Calculating the Kolmogorov-Smirnov test relies on scipy."
"Therefore, please install scipy before using this test."

# approximation does not work for small sample sizes
samples = min(n_samples, m_samples)
if samples < 35:
warnings.warn(
"Approximating the p-value is not accurate enough for sample size < 35"
)
return None

# Currently, the calculation will be performed locally through scipy
# In future versions, an implementation where either the database engine
# (1) calculates the CDF
# or even (2) calculates the KS test
# can be expected
statistic, p_value = ks_2samp(data, data2)
d_alpha = d * math.sqrt(samples)
approx_p = 2 * math.exp(-(d_alpha**2))

return p_value
# clamp value to [0, 1]
return 1.0 if approx_p > 1.0 else 0.0 if approx_p < 0.0 else approx_p

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[Any, OptionalSelections]:
return db_access.get_column(engine, ref)
@staticmethod
def check_acceptance(
d_statistic: float, n_samples: int, m_samples: int, accepted_level: float
):
"""
For a given test statistic, d, and the respective sample sizes `n` and `m`, this function
checks whether the null hypothesis can be rejected for an accepted significance level.

For more information, check out the `Wikipedia entry <https://w.wiki/5May>`.
YYYasin19 marked this conversation as resolved.
Show resolved Hide resolved
"""

def c(alpha: float):
return math.sqrt(-math.log(alpha / 2.0 + 1e-10) * 0.5)

def compare(
self, value_factual: Any, value_target: Any
) -> Tuple[bool, Optional[str]]:
return d_statistic <= c(accepted_level) * math.sqrt(
(n_samples + m_samples) / (n_samples * m_samples)
)

@staticmethod
def calculate_statistic(engine, table1, table2) -> Any:

# retrieve test statistic d, as well as sample sizes m and n
d_statistic, m, n = db_access.get_ks_2sample(
YYYasin19 marked this conversation as resolved.
Show resolved Hide resolved
engine, table1=table1, table2=table2
)

# calculate approximate p-value
p_value = KolmogorovSmirnov2Sample.approximate_p_value(d_statistic, m, n)

return d_statistic, p_value, n, m

def test(self, engine: sa.engine.Engine) -> TestResult:

# get query selections and column names for target columns
selection1 = str(self.ref.data_source.get_clause(engine))
YYYasin19 marked this conversation as resolved.
Show resolved Hide resolved
column1 = self.ref.get_column(engine)
selection2 = str(self.ref2.data_source.get_clause(engine))
column2 = self.ref2.get_column(engine)

d_statistic, p_value, n_samples, m_samples = self.calculate_statistic(
engine, (selection1, column1), (selection2, column2)
)

# calculate test acceptance
result = self.check_acceptance(
d_statistic, n_samples, m_samples, self.significance_level
)

p_value = self.calculate_2sample_ks_test(value_factual, value_target)
result = p_value >= self.significance_level
assertion_text = (
f"2-Sample Kolmogorov-Smirnov between {self.ref.get_string()} and {self.target_prefix}"
f"has p-value {p_value} < {self.significance_level}"
f"{self.condition_string}"
f"Null hypothesis (H0) for the 2-sample Kolmogorov-Smirnov test was rejected, i.e., "
f"the two samples ({self.ref.get_string()} and {self.target_prefix})"
f" do not originate from the same distribution."
)
if p_value:
YYYasin19 marked this conversation as resolved.
Show resolved Hide resolved
assertion_text += f"\n p-value: {p_value}"

# store values s.t. they can be checked later
if not result:
return TestResult.failure(assertion_text)

return result, assertion_text
return TestResult.success()
82 changes: 82 additions & 0 deletions src/datajudge/db_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,3 +902,85 @@ def get_column_array_agg(
for t in result
]
return result, selections


def get_ks_2sample(engine: sa.engine.Engine, table1: tuple, table2: tuple):
"""
Runs the query for the two-sample Kolmogorov-Smirnov test and returns the test statistic d.
"""
table1_selection, col1 = table1
table2_selection, col2 = table2

if is_mssql(engine):
table1_selection = str(table1_selection).replace(
'"', ""
) # tempdb.dbo.int_table
YYYasin19 marked this conversation as resolved.
Show resolved Hide resolved
table2_selection = str(table2_selection).replace(
'"', ""
) # "tempdb.dbo".int_table

# for RawQueryDataSource this could be a whole subquery and will therefore need to be wrapped
if "SELECT" in table1_selection:
table1_selection = f"({table1_selection})"
table2_selection = f"({table2_selection})"

# for a more extensive explanation, see:
# https://github.com/Quantco/datajudge/pull/28#issuecomment-1165587929
ks_query_string = f"""
YYYasin19 marked this conversation as resolved.
Show resolved Hide resolved
WITH
tab1 AS ( -- Step 0: Prepare data source and value column
SELECT {col1} as val FROM {table1_selection}
),
tab2 AS (
SELECT {col2} as val FROM {table2_selection}
),
tab1_cdf AS ( -- Step 1: Calculate the CDF over the value column
SELECT val, cume_dist() over (order by val) as cdf
FROM tab1
),
tab2_cdf AS (
SELECT val, cume_dist() over (order by val) as cdf
FROM tab2
),
tab1_grouped AS ( -- Step 2: Remove unnecessary values, s.t. we have (x, cdf(x)) rows only
SELECT val, MAX(cdf) as cdf
FROM tab1_cdf
GROUP BY val
),
tab2_grouped AS (
SELECT val, MAX(cdf) as cdf
FROM tab2_cdf
GROUP BY val
),
joined_cdf AS ( -- Step 3: combine the cdfs
SELECT coalesce(tab1_grouped.val, tab2_grouped.val) as v, tab1_grouped.cdf as cdf1, tab2_grouped.cdf as cdf2
FROM tab1_grouped FULL OUTER JOIN tab2_grouped ON tab1_grouped.val = tab2_grouped.val
),
-- Step 4: Create a grouper id based on the value count; this is just a helper for forward-filling
grouped_cdf AS (
SELECT v,
COUNT(cdf1) over (order by v) as _grp1,
cdf1,
COUNT(cdf2) over (order by v) as _grp2,
cdf2
FROM joined_cdf
),
-- Step 5: Forward-Filling: Select first non-null value per group (defined in the prev. step)
filled_cdf AS (
SELECT v,
first_value(cdf1) over (partition by _grp1 order by v) as cdf1_filled,
first_value(cdf2) over (partition by _grp2 order by v) as cdf2_filled
FROM grouped_cdf),
-- Step 6: Replace NULL values (at the beginning) with 0 to calculate difference
replaced_nulls AS (
SELECT coalesce(cdf1_filled, 0) as cdf1, coalesce(cdf2_filled, 0) as cdf2
FROM filled_cdf)
-- Step 7: Calculate final statistic as max. distance
SELECT MAX(ABS(cdf1 - cdf2)) FROM replaced_nulls;
"""

d_statistic = engine.execute(ks_query_string).scalar()
n = engine.execute(f"SELECT COUNT(*) FROM {table1_selection} as n_table").scalar()
m = engine.execute(f"SELECT COUNT(*) FROM {table2_selection} as m_table").scalar()

return d_statistic, n, m
9 changes: 7 additions & 2 deletions src/datajudge/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,9 +1268,14 @@ def add_ks_2sample_constraint(
The signifance_level must be a value between 0.0 and 1.0.
"""

if significance_level < 0.0 or significance_level > 1.0:
if not column1 or not column2:
raise ValueError(
"The requested significance level has to be between 0.0 and 1.0. Default is 0.05."
"Column names have to be given for this test's functionality."
)

if significance_level <= 0.0 or significance_level > 1.0:
raise ValueError(
"The requested significance level has to be in `(0.0, 1.0]`. Default is 0.05."
)

ref = DataReference(self.data_source, [column1], condition=condition1)
Expand Down
32 changes: 32 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import itertools
import os
import random
import urllib.parse

import pytest
Expand Down Expand Up @@ -661,6 +662,37 @@ def groupby_aggregation_table_incorrect(engine, metadata):
return TEST_DB_NAME, SCHEMA, table_name


@pytest.fixture(scope="module")
def random_normal_table(engine, metadata):
"""
Table containing 10_000 randomly distributed values with mean = 0 and std.dev = 1.
"""
table_name = "random_normal_table"
columns = [
sa.Column("value_0_1", sa.Float()),
sa.Column("value_005_1", sa.Float()),
sa.Column("value_02_1", sa.Float()),
sa.Column("value_1_1", sa.Float()),
]
row_size = 10_000
random.seed(0)
rand1 = [random.gauss(0, 1) for _ in range(row_size)]
rand2 = [random.gauss(0.05, 1) for _ in range(row_size)]
rand3 = [random.gauss(0.2, 1) for _ in range(row_size)]
rand4 = [random.gauss(1, 1) for _ in range(row_size)]
data = [
{
"value_0_1": rand1[idx],
"value_005_1": rand2[idx],
"value_02_1": rand3[idx],
"value_1_1": rand4[idx],
}
for idx in range(row_size)
]
_handle_table(engine, metadata, table_name, columns, data)
return TEST_DB_NAME, SCHEMA, table_name


@pytest.fixture(scope="module")
def capitalization_table(engine, metadata):
table_name = "capitalization"
Expand Down
Loading