Implement Kolmogorov Smirnov Test in SQL-only (#28)
## Change
This PR implements the Kolmogorov-Smirnov test in pure SQL, which is then run directly on the database.
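
For orientation, the constraint is added through the existing requirements API. A minimal sketch (assumptions: the `BetweenRequirement.from_tables` construction, the connection string, and all table/column names are illustrative; only `add_ks_2sample_constraint` and its `column1`, `column2`, and `significance_level` parameters are touched by this PR):

```python
import sqlalchemy as sa

from datajudge import BetweenRequirement  # assumed public import path

# Hypothetical connection and table names -- adjust to your setup.
engine = sa.create_engine("postgresql://user:password@localhost/tempdb")

# Assumption: a requirement comparing a column of one table against a column
# of another table; the KS constraint below is what this PR implements in SQL.
req = BetweenRequirement.from_tables(
    "tempdb", "dbo", "table_a",
    "tempdb", "dbo", "table_b",
)
req.add_ks_2sample_constraint(
    column1="value",
    column2="value",
    significance_level=0.05,  # must lie in (0.0, 1.0]
)
```

The constraint's `test` method then builds and runs the KS query directly against the passed engine, as shown in the diff below.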

## Commits
* first integration: ks-test in database functionality

* integrate sql query with data refs

* formatting

* formatting

* refactoring: call `test` directly to access sql-result of KS test

* fix row count retrieval

* fix acceptance level domain error

* fix alpha adjustment

* fix type hints for python<3.10

* update sql query for postgres: all tables need to have an alias assigned to them

* fix: typo

* update query for mssql server

* add check for column names

* alternative way of getting table name, incl. hot fix for mssql quotation marks in table reference

* don't accept zero alphas since in practice they don't make much sense

* update variable naming and doc-strings

* update data retrieval

* include query nesting brackets

* better formatting for understandability

* better formatting for understandability

* update query for better readability with more WITH statements

* new option of passing values to the TestResult to compare these

* separate implementation testing from use case testing

* make independent of numpy

* update tests: new distributions, no scipy and numpy dependency, random numbers generated from seed for reproducibility

* update comment

* optional accuracy through scipy

* refactoring, clean up and formatting

* update comment and type hints

* update type hints for older python versions

* fix type hint: Tuple instead of tuple

* update changelog and include comment about scipy calculation
YYYasin19 authored and kklein committed Jul 25, 2022
1 parent 1234eee commit 60815b3
Showing 6 changed files with 305 additions and 42 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
@@ -7,6 +7,13 @@
Changelog
=========

1.1.1 - 2022.06.30
------------------

**New: SQL implementation for KS-test**

- The Kolmogorov-Smirnov test is now implemented in pure SQL, shifting the computation to the database engine and improving performance tremendously.

1.1.0 - 2022.06.01
------------------

122 changes: 92 additions & 30 deletions src/datajudge/constraints/stats.py
@@ -1,10 +1,13 @@
from typing import Any, Collection, Optional, Tuple
import math
import warnings
from typing import Optional, Tuple, Union

import sqlalchemy as sa
from sqlalchemy.sql import Selectable

from .. import db_access
from ..db_access import DataReference
from .base import Constraint, OptionalSelections
from .base import Constraint, TestResult


class KolmogorovSmirnov2Sample(Constraint):
@@ -15,43 +18,102 @@ def __init__(
super().__init__(ref, ref2=ref2)

@staticmethod
def calculate_2sample_ks_test(data: Collection, data2: Collection) -> float:
def approximate_p_value(
d: float, n_samples: int, m_samples: int
) -> Optional[float]:
"""
For two given lists of values calculates the Kolmogorov-Smirnov test.
Read more here: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kstest.html
Calculates the approximate p-value according to
'A procedure to find exact critical values of Kolmogorov-Smirnov Test', Silvia Fachinetti, 2009
Note: For environments with `scipy` installed, this method will return a quasi-exact p-value.
"""

# approximation does not work for small sample sizes
samples = min(n_samples, m_samples)
if samples < 35:
warnings.warn(
"Approximating the p-value is not accurate enough for sample size < 35"
)
return None

# if scipy is installed, accurately calculate the p_value using the full distribution
try:
from scipy.stats import ks_2samp
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Calculating the Kolmogorov-Smirnov test relies on scipy."
"Therefore, please install scipy before using this test."
from scipy.stats.distributions import kstwo

approx_p = kstwo.sf(
d, round((n_samples * m_samples) / (n_samples + m_samples))
)
except ModuleNotFoundError:
d_alpha = d * math.sqrt(samples)
approx_p = 2 * math.exp(-(d_alpha**2))

# clamp value to [0, 1]
return 1.0 if approx_p > 1.0 else 0.0 if approx_p < 0.0 else approx_p

@staticmethod
def check_acceptance(
d_statistic: float, n_samples: int, m_samples: int, accepted_level: float
) -> bool:
"""
For a given test statistic, d, and the respective sample sizes `n` and `m`, this function
checks whether the null hypothesis can be rejected for an accepted significance level.
For more information, check out the `Wikipedia entry <https://w.wiki/5May>`_.
"""

# Currently, the calculation will be performed locally through scipy
# In future versions, an implementation where either the database engine
# (1) calculates the CDF
# or even (2) calculates the KS test
# can be expected
statistic, p_value = ks_2samp(data, data2)
def c(alpha: float):
return math.sqrt(-math.log(alpha / 2.0 + 1e-10) * 0.5)

return p_value
return d_statistic <= c(accepted_level) * math.sqrt(
(n_samples + m_samples) / (n_samples * m_samples)
)

@staticmethod
def calculate_statistic(
engine,
table1_def: Tuple[Union[Selectable, str], str],
table2_def: Tuple[Union[Selectable, str], str],
) -> Tuple[float, Optional[float], int, int]:

# retrieve test statistic d, as well as sample sizes m and n
d_statistic, n_samples, m_samples = db_access.get_ks_2sample(
engine, table1=table1_def, table2=table2_def
)

# calculate approximate p-value
p_value = KolmogorovSmirnov2Sample.approximate_p_value(
d_statistic, n_samples, m_samples
)

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[Any, OptionalSelections]:
return db_access.get_column(engine, ref)
return d_statistic, p_value, n_samples, m_samples

def compare(
self, value_factual: Any, value_target: Any
) -> Tuple[bool, Optional[str]]:
def test(self, engine: sa.engine.Engine) -> TestResult:

# get query selections and column names for target columns
selection1 = self.ref.data_source.get_clause(engine)
column1 = self.ref.get_column(engine)
selection2 = self.ref2.data_source.get_clause(engine)
column2 = self.ref2.get_column(engine)

d_statistic, p_value, n_samples, m_samples = self.calculate_statistic(
engine, (selection1, column1), (selection2, column2)
)

# calculate test acceptance
result = self.check_acceptance(
d_statistic, n_samples, m_samples, self.significance_level
)

p_value = self.calculate_2sample_ks_test(value_factual, value_target)
result = p_value >= self.significance_level
assertion_text = (
f"2-Sample Kolmogorov-Smirnov between {self.ref.get_string()} and {self.target_prefix}"
f"has p-value {p_value} < {self.significance_level}"
f"{self.condition_string}"
f"Null hypothesis (H0) for the 2-sample Kolmogorov-Smirnov test was rejected, i.e., "
f"the two samples ({self.ref.get_string()} and {self.target_prefix})"
f" do not originate from the same distribution."
f"The test results are d={d_statistic}"
)
if p_value is not None:
assertion_text += f"and {p_value=}"

if not result:
return TestResult.failure(assertion_text)

return result, assertion_text
return TestResult.success()
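
As a standalone reference for the scipy-free branch of `approximate_p_value` above, the asymptotic approximation can be written as follows (a sketch for illustration only; the printed values are approximate):

```python
import math


def approx_ks_p_value(d: float, n_samples: int, m_samples: int) -> float:
    """Mirror of the fallback branch: asymptotic p-value from the KS statistic d."""
    samples = min(n_samples, m_samples)
    # The constraint only uses this approximation for sample sizes >= 35.
    d_alpha = d * math.sqrt(samples)
    p = 2 * math.exp(-(d_alpha**2))
    return min(max(p, 0.0), 1.0)  # clamp to [0, 1]


# d = 0.05 is highly significant with 10_000 samples per side ...
print(approx_ks_p_value(0.05, 10_000, 10_000))  # ~2.8e-11
# ... but not with 500 samples per side.
print(approx_ks_p_value(0.05, 500, 500))  # ~0.57
```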
86 changes: 86 additions & 0 deletions src/datajudge/db_access.py
@@ -902,3 +902,89 @@ def get_column_array_agg(
for t in result
]
return result, selections


def get_ks_2sample(
engine: sa.engine.Engine, table1: tuple, table2: tuple
) -> tuple[float, int, int]:
"""
Runs the query for the two-sample Kolmogorov-Smirnov test and returns the test statistic d together with the sample sizes n and m.
"""

# make sure we have a string representation here
table1_selection, col1 = str(table1[0]), str(table1[1])
table2_selection, col2 = str(table2[0]), str(table2[1])

if is_mssql(engine): # "tempdb.dbo".table_name -> tempdb.dbo.table_name
table1_selection = table1_selection.replace('"', "")
table2_selection = table2_selection.replace('"', "")

# for RawQueryDataSource this could be a whole subquery and will therefore need to be wrapped
if "SELECT" in table1_selection:
table1_selection = f"({table1_selection})"
table2_selection = f"({table2_selection})"

# for a more extensive explanation, see:
# https://github.com/Quantco/datajudge/pull/28#issuecomment-1165587929
ks_query_string = f"""
WITH
tab1 AS ( -- Step 0: Prepare data source and value column
SELECT {col1} as val FROM {table1_selection}
),
tab2 AS (
SELECT {col2} as val FROM {table2_selection}
),
tab1_cdf AS ( -- Step 1: Calculate the CDF over the value column
SELECT val, cume_dist() over (order by val) as cdf
FROM tab1
),
tab2_cdf AS (
SELECT val, cume_dist() over (order by val) as cdf
FROM tab2
),
tab1_grouped AS ( -- Step 2: Remove unnecessary values, s.t. we have (x, cdf(x)) rows only
SELECT val, MAX(cdf) as cdf
FROM tab1_cdf
GROUP BY val
),
tab2_grouped AS (
SELECT val, MAX(cdf) as cdf
FROM tab2_cdf
GROUP BY val
),
joined_cdf AS ( -- Step 3: combine the cdfs
SELECT coalesce(tab1_grouped.val, tab2_grouped.val) as v, tab1_grouped.cdf as cdf1, tab2_grouped.cdf as cdf2
FROM tab1_grouped FULL OUTER JOIN tab2_grouped ON tab1_grouped.val = tab2_grouped.val
),
-- Step 4: Create a grouper id based on the value count; this is just a helper for forward-filling
grouped_cdf AS (
SELECT v,
COUNT(cdf1) over (order by v) as _grp1,
cdf1,
COUNT(cdf2) over (order by v) as _grp2,
cdf2
FROM joined_cdf
),
-- Step 5: Forward-Filling: Select first non-null value per group (defined in the prev. step)
filled_cdf AS (
SELECT v,
first_value(cdf1) over (partition by _grp1 order by v) as cdf1_filled,
first_value(cdf2) over (partition by _grp2 order by v) as cdf2_filled
FROM grouped_cdf),
-- Step 6: Replace NULL values (at the beginning) with 0 to calculate difference
replaced_nulls AS (
SELECT coalesce(cdf1_filled, 0) as cdf1, coalesce(cdf2_filled, 0) as cdf2
FROM filled_cdf)
-- Step 7: Calculate final statistic as max. distance
SELECT MAX(ABS(cdf1 - cdf2)) FROM replaced_nulls;
"""

d_statistic = engine.execute(ks_query_string).scalar()
n_samples = engine.execute(
f"SELECT COUNT(*) FROM {table1_selection} as n_table"
).scalar()
m_samples = engine.execute(
f"SELECT COUNT(*) FROM {table2_selection} as m_table"
).scalar()

return d_statistic, n_samples, m_samples
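
For intuition, the query above computes the maximum absolute difference between the two empirical CDFs, evaluated over the union of observed values; the in-memory sketch below does the same thing (illustration only — the point of this PR is to keep the computation inside the database):

```python
from typing import Sequence


def ks_statistic(sample1: Sequence[float], sample2: Sequence[float]) -> float:
    """Max absolute distance between the two empirical CDFs (what the SQL query returns)."""
    n, m = len(sample1), len(sample2)
    s1, s2 = sorted(sample1), sorted(sample2)
    d = 0.0
    i = j = 0
    for v in sorted(set(s1) | set(s2)):
        # advance each pointer past all entries <= v; i/n and j/m are the CDF values at v
        while i < n and s1[i] <= v:
            i += 1
        while j < m and s2[j] <= v:
            j += 1
        d = max(d, abs(i / n - j / m))
    return d


print(ks_statistic([1, 2, 3, 4], [3, 4, 5, 6]))  # 0.5
```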
9 changes: 7 additions & 2 deletions src/datajudge/requirements.py
@@ -1268,9 +1268,14 @@ def add_ks_2sample_constraint(
The significance_level must be a value between 0.0 and 1.0.
"""

if significance_level < 0.0 or significance_level > 1.0:
if not column1 or not column2:
raise ValueError(
"The requested significance level has to be between 0.0 and 1.0. Default is 0.05."
"Column names have to be given for this test's functionality."
)

if significance_level <= 0.0 or significance_level > 1.0:
raise ValueError(
"The requested significance level has to be in `(0.0, 1.0]`. Default is 0.05."
)

ref = DataReference(self.data_source, [column1], condition=condition1)
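
As a short illustration of why a significance level of exactly 0.0 is now rejected: the critical-value factor used in `check_acceptance`, roughly `c(alpha) = sqrt(-ln(alpha / 2) / 2)`, grows without bound as `alpha` approaches zero, so the null hypothesis could never be rejected (printed values are approximate):

```python
import math


def c(alpha: float) -> float:
    # same scaling factor as in KolmogorovSmirnov2Sample.check_acceptance
    return math.sqrt(-math.log(alpha / 2.0 + 1e-10) * 0.5)


for alpha in (0.1, 0.05, 0.01, 0.001):
    print(alpha, round(c(alpha), 3))
# 0.1 -> 1.224, 0.05 -> 1.358, 0.01 -> 1.628, 0.001 -> 1.949
# c(alpha) keeps growing as alpha -> 0, so alpha == 0.0 is disallowed.
```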
32 changes: 32 additions & 0 deletions tests/integration/conftest.py
@@ -1,6 +1,7 @@
import datetime
import itertools
import os
import random
import urllib.parse

import pytest
@@ -661,6 +662,37 @@ def groupby_aggregation_table_incorrect(engine, metadata):
return TEST_DB_NAME, SCHEMA, table_name


@pytest.fixture(scope="module")
def random_normal_table(engine, metadata):
"""
Table containing 10_000 normally distributed values per column, with means 0, 0.05, 0.2, and 1 and standard deviation 1.
"""
table_name = "random_normal_table"
columns = [
sa.Column("value_0_1", sa.Float()),
sa.Column("value_005_1", sa.Float()),
sa.Column("value_02_1", sa.Float()),
sa.Column("value_1_1", sa.Float()),
]
row_size = 10_000
random.seed(0)
rand1 = [random.gauss(0, 1) for _ in range(row_size)]
rand2 = [random.gauss(0.05, 1) for _ in range(row_size)]
rand3 = [random.gauss(0.2, 1) for _ in range(row_size)]
rand4 = [random.gauss(1, 1) for _ in range(row_size)]
data = [
{
"value_0_1": rand1[idx],
"value_005_1": rand2[idx],
"value_02_1": rand3[idx],
"value_1_1": rand4[idx],
}
for idx in range(row_size)
]
_handle_table(engine, metadata, table_name, columns, data)
return TEST_DB_NAME, SCHEMA, table_name


@pytest.fixture(scope="module")
def capitalization_table(engine, metadata):
table_name = "capitalization"