From 8adeefd6c5343d94f84631cdac001804ad9f475f Mon Sep 17 00:00:00 2001 From: Kevin Klein <7267523+kklein@users.noreply.github.com> Date: Mon, 1 Aug 2022 20:33:01 +0200 Subject: [PATCH] Reimplement Kolmogorov Smirnov query logic with sqlalchemy's Language Expression API (#44) --- src/datajudge/constraints/stats.py | 48 ++++-- src/datajudge/db_access.py | 216 ++++++++++++++++++-------- tests/integration/conftest.py | 22 +++ tests/integration/test_integration.py | 85 +++------- tests/integration/test_stats.py | 64 ++++++++ 5 files changed, 293 insertions(+), 142 deletions(-) create mode 100644 tests/integration/test_stats.py diff --git a/src/datajudge/constraints/stats.py b/src/datajudge/constraints/stats.py index 30de222a..0ff3de4b 100644 --- a/src/datajudge/constraints/stats.py +++ b/src/datajudge/constraints/stats.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Optional, Tuple +from typing import List, Optional, Tuple import sqlalchemy as sa @@ -63,59 +63,73 @@ def check_acceptance( def c(alpha: float): return math.sqrt(-math.log(alpha / 2.0 + 1e-10) * 0.5) - return d_statistic <= c(accepted_level) * math.sqrt( + threshold = c(accepted_level) * math.sqrt( (n_samples + m_samples) / (n_samples * m_samples) ) + return d_statistic <= threshold @staticmethod def calculate_statistic( engine, ref1: DataReference, ref2: DataReference, - ) -> Tuple[float, Optional[float], int, int]: + ) -> Tuple[float, Optional[float], int, int, List]: # retrieve test statistic d, as well as sample sizes m and n - d_statistic = db_access.get_ks_2sample( + d_statistic, ks_selections = db_access.get_ks_2sample( engine, ref1, ref2, ) - n_samples, _ = db_access.get_row_count(engine, ref1) - m_samples, _ = db_access.get_row_count(engine, ref2) + n_samples, n_selections = db_access.get_row_count(engine, ref1) + m_samples, m_selections = db_access.get_row_count(engine, ref2) # calculate approximate p-value p_value = KolmogorovSmirnov2Sample.approximate_p_value( d_statistic, 
n_samples, m_samples ) - return d_statistic, p_value, n_samples, m_samples + selections = n_selections + m_selections + ks_selections + return d_statistic, p_value, n_samples, m_samples, selections def test(self, engine: sa.engine.Engine) -> TestResult: - - # get query selections and column names for target columns - - d_statistic, p_value, n_samples, m_samples = self.calculate_statistic( + ( + d_statistic, + p_value, + n_samples, + m_samples, + selections, + ) = self.calculate_statistic( engine, self.ref, self.ref2, ) - - # calculate test acceptance result = self.check_acceptance( d_statistic, n_samples, m_samples, self.significance_level ) assertion_text = ( f"Null hypothesis (H0) for the 2-sample Kolmogorov-Smirnov test was rejected, i.e., " - f"the two samples ({self.ref.get_string()} and {self.target_prefix})" - f" do not originate from the same distribution." + f"the two samples ({self.ref.get_string()} and {self.target_prefix}) " + f"do not originate from the same distribution. " f"The test results are d={d_statistic}" ) if p_value is not None: - assertion_text += f"and {p_value=}" + assertion_text += f" and {p_value=}" + assertion_text += "." + + if selections: + queries = [ + str(selection.compile(engine, compile_kwargs={"literal_binds": True})) + for selection in selections + ] if not result: - return TestResult.failure(assertion_text) + return TestResult.failure( + assertion_text, + self.get_description(), + queries, + ) return TestResult.success() diff --git a/src/datajudge/db_access.py b/src/datajudge/db_access.py index 9d0172d0..8bc229ba 100644 --- a/src/datajudge/db_access.py +++ b/src/datajudge/db_access.py @@ -288,7 +288,13 @@ def get_column(self, engine): f"Trying to access column of DataReference " f"{self.get_string()} yet none is given." 
) - return self.get_columns(engine)[0] + columns = self.get_columns(engine) + if len(columns) > 1: + raise ValueError( + "DataReference was expected to only have a single column but had multiple: " + f"{columns}" + ) + return columns[0] def get_columns(self, engine): """Fetch all relevant columns of a DataReference.""" @@ -904,77 +910,157 @@ def get_column_array_agg( return result, selections +def _cdf_selection(engine, ref: DataReference, cdf_label: str, value_label: str): + """Create a selection of empirical cumulative distribution function values. + + Concretely, create a selection with values from ``value_label`` as well as + the empirical cumulative distribution function values, labeled as + ``cdf_label``. + """ + col = ref.get_column(engine) + selection = ref.get_selection(engine).subquery() + + # Step 1: Calculate the CDF over the value column. + cdf_selection = sa.select( + [ + selection.c[col].label(value_label), + sa.func.cume_dist().over(order_by=col).label(cdf_label), + ] + ).subquery() + + # Step 2: Aggregate rows s.t. every value occurs only once. + grouped_cdf_selection = ( + sa.select( + [ + cdf_selection.c[value_label], + sa.func.max(cdf_selection.c[cdf_label]).label(cdf_label), + ] + ) + .group_by(cdf_selection.c[value_label]) + .subquery() + ) + return grouped_cdf_selection + + +def _cross_cdf_selection( + engine, ref1: DataReference, ref2: DataReference, cdf_label: str, value_label: str +): + """Create a cross cumulative distribution function selection given two samples. + + Concretely, both ``DataReference``s are expected to have specified a single relevant column. + This function will generate a selection with rows of the kind ``(value, cdf1(value), cdf2(value))``, + where ``cdf1`` is the cumulative distribution function of ``ref1`` and ``cdf2`` of ``ref2``. + + E.g. 
if ``ref1`` is a reference to a table's column with values ``[1, 1, 3, 2]``, and ``ref2`` is + a reference to a table's column with values ``[2, 5, 4]``, executing the returned selection should + yield a table of the following kind: ``[(1, .5, 0), (2, .75, 1/3), (3, 1, 1/3), (4, 1, 2/3), (5, 1, 1)]``. + """ + cdf_label1 = cdf_label + "1" + cdf_label2 = cdf_label + "2" + group_label1 = "_grp1" + group_label2 = "_grp2" + + cdf_selection1 = _cdf_selection(engine, ref1, cdf_label, value_label) + cdf_selection2 = _cdf_selection(engine, ref2, cdf_label, value_label) + + # Step 3: Combine the cdfs. + cross_cdf = ( + sa.select( + sa.func.coalesce( + cdf_selection1.c[value_label], cdf_selection2.c[value_label] + ).label(value_label), + cdf_selection1.c[cdf_label].label(cdf_label1), + cdf_selection2.c[cdf_label].label(cdf_label2), + ) + .select_from( + cdf_selection1.join( + cdf_selection2, + cdf_selection1.c[value_label] == cdf_selection2.c[value_label], + isouter=True, + full=True, + ) + ) + .subquery() + ) + + def _cdf_index_column(table, value_label, cdf_label, group_label): + return ( + sa.func.count(table.c[cdf_label]) + .over(order_by=table.c[value_label]) + .label(group_label) + ) + + # Step 4: Create a grouper id based on the value count; this is just a helper for forward-filling. + # In other words, we point rows to their most recent present value - per sample. This is necessary + # due to the nature of the full outer join. + indexed_cross_cdf = sa.select( + [ + cross_cdf.c[value_label], + _cdf_index_column(cross_cdf, value_label, cdf_label1, group_label1), + cross_cdf.c[cdf_label1], + _cdf_index_column(cross_cdf, value_label, cdf_label2, group_label2), + cross_cdf.c[cdf_label2], + ] + ).subquery() + + def _forward_filled_cdf_column(table, cdf_label, value_label, group_label): + return ( + # Step 6: Replace NULL values at the beginning with 0 to enable computation of difference. 
+ sa.func.coalesce( + ( + # Step 5: Forward-Filling: Select first non-NULL value per group (defined in the prev. step). + sa.func.first_value(table.c[cdf_label]).over( + partition_by=table.c[group_label], order_by=table.c[value_label] + ) + ), + 0, + ).label(cdf_label) + ) + + filled_cross_cdf = sa.select( + [ + indexed_cross_cdf.c[value_label], + _forward_filled_cdf_column( + indexed_cross_cdf, cdf_label1, value_label, group_label1 + ), + _forward_filled_cdf_column( + indexed_cross_cdf, cdf_label2, value_label, group_label2 + ), + ] + ) + return filled_cross_cdf, cdf_label1, cdf_label2 + + def get_ks_2sample( engine: sa.engine.Engine, ref1: DataReference, ref2: DataReference, -) -> float: - """ - Runs the query for the two-sample Kolmogorov-Smirnov test and returns the test statistic d. +): """ - # For mssql: "tempdb.dbo".table_name -> tempdb.dbo.table_name - table1_str = str(ref1.data_source.get_clause(engine)).replace('"', "") - col1 = ref1.get_column(engine) - table2_str = str(ref2.data_source.get_clause(engine)).replace('"', "") - col2 = ref2.get_column(engine) - - # for a more extensive explanation, see: - # https://github.com/Quantco/datajudge/pull/28#issuecomment-1165587929 - ks_query_string = f""" - WITH - tab1 AS ( -- Step 0: Prepare data source and value column - SELECT {col1} as val FROM {table1_str} - ), - tab2 AS ( - SELECT {col2} as val FROM {table2_str} - ), - tab1_cdf AS ( -- Step 1: Calculate the CDF over the value column - SELECT val, cume_dist() over (order by val) as cdf - FROM tab1 - ), - tab2_cdf AS ( - SELECT val, cume_dist() over (order by val) as cdf - FROM tab2 - ), - tab1_grouped AS ( -- Step 2: Remove unnecessary values, s.t. 
we have (x, cdf(x)) rows only - SELECT val, MAX(cdf) as cdf - FROM tab1_cdf - GROUP BY val - ), - tab2_grouped AS ( - SELECT val, MAX(cdf) as cdf - FROM tab2_cdf - GROUP BY val - ), - joined_cdf AS ( -- Step 3: combine the cdfs - SELECT coalesce(tab1_grouped.val, tab2_grouped.val) as v, tab1_grouped.cdf as cdf1, tab2_grouped.cdf as cdf2 - FROM tab1_grouped FULL OUTER JOIN tab2_grouped ON tab1_grouped.val = tab2_grouped.val - ), - -- Step 4: Create a grouper id based on the value count; this is just a helper for forward-filling - grouped_cdf AS ( - SELECT v, - COUNT(cdf1) over (order by v) as _grp1, - cdf1, - COUNT(cdf2) over (order by v) as _grp2, - cdf2 - FROM joined_cdf - ), - -- Step 5: Forward-Filling: Select first non-null value per group (defined in the prev. step) - filled_cdf AS ( - SELECT v, - first_value(cdf1) over (partition by _grp1 order by v) as cdf1_filled, - first_value(cdf2) over (partition by _grp2 order by v) as cdf2_filled - FROM grouped_cdf), - -- Step 6: Replace NULL values (at the beginning) with 0 to calculate difference - replaced_nulls AS ( - SELECT coalesce(cdf1_filled, 0) as cdf1, coalesce(cdf2_filled, 0) as cdf2 - FROM filled_cdf) - -- Step 7: Calculate final statistic as max. distance - SELECT MAX(ABS(cdf1 - cdf2)) FROM replaced_nulls; + Run the query for the two-sample Kolmogorov-Smirnov test and return the test statistic d. + + For a raw-sql version of this query, please see this PR: + https://github.com/Quantco/datajudge/pull/28/ """ + cdf_label = "cdf" + value_label = "val" + filled_cross_cdf_selection, cdf_label1, cdf_label2 = _cross_cdf_selection( + engine, ref1, ref2, cdf_label, value_label + ) + + filled_cross_cdf = filled_cross_cdf_selection.subquery() + + # Step 7: Calculate final statistic: maximal distance. 
+ final_selection = sa.select( + sa.func.max( + sa.func.abs(filled_cross_cdf.c[cdf_label1] - filled_cross_cdf.c[cdf_label2]) + ) + ) + + with engine.connect() as connection: + d_statistic = connection.execute(final_selection).scalar() - d_statistic = engine.execute(ks_query_string).scalar() - return d_statistic + return d_statistic, [final_selection] def get_regex_violations(engine, ref, aggregated, regex, n_counterexamples): diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 299953d9..8e63f5e7 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -715,6 +715,28 @@ def capitalization_table(engine, metadata): return TEST_DB_NAME, SCHEMA, table_name, uppercase_column, lowercase_column +@pytest.fixture(scope="module") +def cross_cdf_table1(engine, metadata): + table_name = "cross_cdf_table1" + col_name = "col_int" + columns = [sa.Column(col_name, sa.Integer())] + col_values = [1, 1, 3, 2] + data = [{col_name: col_value} for col_value in col_values] + _handle_table(engine, metadata, table_name, columns, data) + return TEST_DB_NAME, SCHEMA, table_name + + +@pytest.fixture(scope="module") +def cross_cdf_table2(engine, metadata): + table_name = "cross_cdf_table2" + col_name = "col_int" + columns = [sa.Column(col_name, sa.Integer())] + col_values = [3, 5, 4, 5, 8] + data = [{col_name: col_value} for col_value in col_values] + _handle_table(engine, metadata, table_name, columns, data) + return TEST_DB_NAME, SCHEMA, table_name + + def pytest_addoption(parser): parser.addoption( "--backend", diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index e340c5e9..490c758f 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -3,15 +3,7 @@ import pytest import datajudge.requirements as requirements -from datajudge.constraints.stats import KolmogorovSmirnov2Sample -from datajudge.db_access import ( - Condition, - DataReference, - 
TableDataSource, - is_mssql, - is_postgresql, - is_snowflake, -) +from datajudge.db_access import Condition, is_mssql, is_postgresql, is_snowflake def skip_if_mssql(engine): @@ -1890,10 +1882,6 @@ def test_groupby_aggregation_within_with_failures( assert operation(test_result.outcome), test_result.failure_message -def test_diff_average_between(): - return - - @pytest.mark.parametrize( "data", [ @@ -1902,8 +1890,8 @@ def test_diff_average_between(): identity, "col_int", "col_int", - Condition("col_int >= 3"), - Condition("col_int >= 3"), + Condition(raw_string="col_int >= 3"), + Condition(raw_string="col_int >= 3"), 1.0, ), ], @@ -1925,31 +1913,41 @@ def test_ks_2sample_constraint_perfect_between(engine, int_table1, data): assert operation(test_result.outcome), test_result.failure_message -# TODO: Enable this test once the bug is fixed. -@pytest.mark.skip(reason="This is a known bug and unintended behaviour.") @pytest.mark.parametrize( - "data", + "condition1, condition2", [ - (negation, "col_int", "col_int", None, Condition("col_int >= 10"), 1.0), + ( + None, + Condition(raw_string="col_int >= 10"), + ), + ( + Condition(raw_string="col_int >= 10"), + None, + ), + ( + Condition(raw_string="col_int >= 10"), + Condition(raw_string="col_int >= 3"), + ), ], ) -def test_ks_2sample_constraint_perfect_between_different_condition( - engine, int_table1, data +def test_ks_2sample_constraint_perfect_between_different_conditions( + engine, int_table1, condition1, condition2 ): """ - Test Kolmogorov-Smirnov for the same column -> p-value should be perfect 1.0. + Test Kolmogorov-Smirnov for the same column but different conditions. + As a consequence, since the data is distinct, the tests are expected + to fail for a very high significance level. 
""" - (operation, col_1, col_2, condition1, condition2, significance_level) = data req = requirements.BetweenRequirement.from_tables(*int_table1, *int_table1) req.add_ks_2sample_constraint( - column1=col_1, - column2=col_2, + column1="col_int", + column2="col_int", condition1=condition1, condition2=condition2, - significance_level=significance_level, + significance_level=1.0, ) test_result = req[0].test(engine) - assert operation(test_result.outcome), test_result.failure_message + assert negation(test_result.outcome), test_result.failure_message @pytest.mark.parametrize( @@ -1999,36 +1997,3 @@ def test_ks_2sample_random(engine, random_normal_table, configuration): ) test_result = req[0].test(engine) assert operation(test_result.outcome), test_result.failure_message - - -@pytest.mark.parametrize( - "configuration", - [ # these values were calculated using scipy.stats.ks_2samp on scipy=1.8.1 - ("value_0_1", "value_0_1", 0.0, 1.0), - ("value_0_1", "value_005_1", 0.0294, 0.00035221594346540835), - ("value_0_1", "value_02_1", 0.0829, 2.6408848561586672e-30), - ("value_0_1", "value_1_1", 0.3924, 0.0), - ], -) -def test_ks_2sample_implementation(engine, random_normal_table, configuration): - col_1, col_2, expected_d, expected_p = configuration - database, schema, table = random_normal_table - tds = TableDataSource(database, table, schema) - ref = DataReference(tds, columns=[col_1]) - ref2 = DataReference(tds, columns=[col_2]) - - ( - d_statistic, - p_value, - n_samples, - m_samples, - ) = KolmogorovSmirnov2Sample.calculate_statistic(engine, ref, ref2) - - assert ( - abs(d_statistic - expected_d) <= 1e-10 - ), f"The test statistic does not match: {expected_d} vs {d_statistic}" - - # 1e-05 should cover common p_values; if scipy is installed, a very accurate p_value is automatically calculated - assert ( - abs(p_value - expected_p) <= 1e-05 - ), f"The approx. 
p-value does not match: {expected_p} vs {p_value}" diff --git a/tests/integration/test_stats.py b/tests/integration/test_stats.py new file mode 100644 index 00000000..30d1fb1e --- /dev/null +++ b/tests/integration/test_stats.py @@ -0,0 +1,64 @@ +import pytest + +import datajudge +from datajudge.db_access import DataReference, TableDataSource + + +def test_cross_cdf_selection(engine, cross_cdf_table1, cross_cdf_table2): + database1, schema1, table1 = cross_cdf_table1 + database2, schema2, table2 = cross_cdf_table2 + tds1 = TableDataSource(database1, table1, schema1) + tds2 = TableDataSource(database2, table2, schema2) + ref1 = DataReference(tds1, columns=["col_int"]) + ref2 = DataReference(tds2, columns=["col_int"]) + selection, _, _ = datajudge.db_access._cross_cdf_selection( + engine, ref1, ref2, "cdf", "value" + ) + with engine.connect() as connection: + result = connection.execute(selection).fetchall() + assert result is not None and len(result) > 0 + expected_result = [ + (1, 2 / 4, 0), + (2, 3 / 4, 0), + (3, 1, 1 / 5), + (4, 1, 2 / 5), + (5, 1, 4 / 5), + (8, 1, 1), + ] + assert sorted(result) == expected_result + + +@pytest.mark.parametrize( + "configuration", + [ # these values were calculated using scipy.stats.ks_2samp on scipy=1.8.1 + ("value_0_1", "value_0_1", 0.0, 1.0), + ("value_0_1", "value_005_1", 0.0294, 0.00035221594346540835), + ("value_0_1", "value_02_1", 0.0829, 2.6408848561586672e-30), + ("value_0_1", "value_1_1", 0.3924, 0.0), + ], +) +def test_ks_2sample_calculate_statistic(engine, random_normal_table, configuration): + col_1, col_2, expected_d, expected_p = configuration + database, schema, table = random_normal_table + tds = TableDataSource(database, table, schema) + ref = DataReference(tds, columns=[col_1]) + ref2 = DataReference(tds, columns=[col_2]) + + ( + d_statistic, + p_value, + n_samples, + m_samples, + _, + ) = datajudge.constraints.stats.KolmogorovSmirnov2Sample.calculate_statistic( + engine, ref, ref2 + ) + + assert ( + 
abs(d_statistic - expected_d) <= 1e-10 + ), f"The test statistic does not match: {expected_d} vs {d_statistic}" + + # 1e-05 should cover common p_values; if scipy is installed, a very accurate p_value is automatically calculated + assert ( + abs(p_value - expected_p) <= 1e-05 + ), f"The approx. p-value does not match: {expected_p} vs {p_value}"