From 6ddaf8249054d02250ca92efc36c3550d7a425dd Mon Sep 17 00:00:00 2001 From: Bela Stoyan Date: Thu, 1 Jun 2023 09:23:23 +0200 Subject: [PATCH] extend interval constraints to work with integers (#146) --- CHANGELOG.rst | 2 +- src/datajudge/constraints/date.py | 94 +------------ src/datajudge/constraints/interval.py | 131 +++++++++++++++++ src/datajudge/constraints/numeric.py | 52 ++++++- src/datajudge/db_access.py | 117 ++++++++++++---- src/datajudge/requirements.py | 108 ++++++++++++++ tests/integration/conftest.py | 194 ++++++++++++++++++++++++++ tests/integration/test_integration.py | 81 +++++++++++ 8 files changed, 664 insertions(+), 115 deletions(-) create mode 100644 src/datajudge/constraints/interval.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f85f0068..fe7e0bce 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,7 +14,7 @@ Changelog - Implement :meth:`datajudge.WithinRequirement.add_value_distribution_constraint`. - Extended :meth:`datajudge.WithinRequirement.add_column_type_constraint` to support column type specification using string format, backend-specific SQLAlchemy types, and SQLAlchemy's generic types. - +- Implement :meth:`datajudge.WithinRequirement.add_numeric_no_gap_constraint`, :meth:`datajudge.WithinRequirement.add_numeric_no_overlap_constraint`, 1.6.0 - 2022.04.12 ------------------ diff --git a/src/datajudge/constraints/date.py b/src/datajudge/constraints/date.py index 43140f42..2f1e7141 100644 --- a/src/datajudge/constraints/date.py +++ b/src/datajudge/constraints/date.py @@ -1,12 +1,12 @@ -import abc import datetime as dt -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import sqlalchemy as sa from .. import db_access from ..db_access import DataReference from .base import Constraint, OptionalSelections, TestResult +from .interval import NoGapConstraint, NoOverlapConstraint INPUT_DATE_FORMAT = "'%Y-%m-%d'" @@ -157,78 +157,9 @@ def compare( return result, assertion_text -class DateIntervals(Constraint, abc.ABC): - _DIMENSIONS = 0 - - def __init__( - self, - ref: DataReference, - key_columns: Optional[List[str]], - start_columns: List[str], - end_columns: List[str], - end_included: bool, - max_relative_n_violations: float, - name: str = None, - ): - super().__init__(ref, ref_value=object(), name=name) - self.key_columns = key_columns - self.start_columns = start_columns - self.end_columns = end_columns - self.end_included = end_included - self.max_relative_n_violations = max_relative_n_violations - self._validate_dimensions() - - @abc.abstractmethod - def select(self, engine: sa.engine.Engine, ref: DataReference): - pass - - def _validate_dimensions(self): - if (length := len(self.start_columns)) != self._DIMENSIONS: - raise ValueError( - f"Expected {self._DIMENSIONS} start_column(s), got {length}." - ) - if (length := len(self.end_columns)) != self._DIMENSIONS: - raise ValueError( - f"Expected {self._DIMENSIONS} end_column(s), got {length}." - ) - - def retrieve( - self, engine: sa.engine.Engine, ref: DataReference - ) -> Tuple[Tuple[int, int], OptionalSelections]: - keys_ref = DataReference( - data_source=self.ref.data_source, - columns=self.key_columns, - condition=self.ref.condition, - ) - n_distinct_key_values, n_keys_selections = db_access.get_unique_count( - engine, keys_ref - ) - - sample_selection, n_violations_selection = self.select(engine, ref) - with engine.connect() as connection: - self.sample = connection.execute(sample_selection).first() - n_violation_keys = connection.execute(n_violations_selection).scalar() - - selections = [*n_keys_selections, sample_selection, n_violations_selection] - return (n_violation_keys, n_distinct_key_values), selections - - -class DateNoOverlap(DateIntervals): +class DateNoOverlap(NoOverlapConstraint): _DIMENSIONS = 1 - def select(self, engine: sa.engine.Engine, ref: DataReference): - sample_selection, n_violations_selection = db_access.get_date_overlaps_nd( - engine, - ref, - self.key_columns, - start_columns=self.start_columns, - end_columns=self.end_columns, - end_included=self.end_included, - ) - # TODO: Once get_unique_count also only returns a selection without - # executing it, one would want to list this selection here as well. - return sample_selection, n_violations_selection - def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: n_violation_keys, n_distinct_key_values = factual if n_distinct_key_values == 0: @@ -244,22 +175,9 @@ def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: return result, assertion_text -class DateNoOverlap2d(DateIntervals): +class DateNoOverlap2d(NoOverlapConstraint): _DIMENSIONS = 2 - def select(self, engine: sa.engine.Engine, ref: DataReference): - sample_selection, n_violations_selection = db_access.get_date_overlaps_nd( - engine, - ref, - self.key_columns, - start_columns=self.start_columns, - end_columns=self.end_columns, - end_included=self.end_included, - ) - # TODO: Once get_unique_count also only returns a selection without - # executing it, one would want to list this selection here as well. - return sample_selection, n_violations_selection - def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: n_violation_keys, n_distinct_key_values = factual if n_distinct_key_values == 0: @@ -276,7 +194,7 @@ def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: return result, assertion_text -class DateNoGap(DateIntervals): +class DateNoGap(NoGapConstraint): _DIMENSIONS = 1 def select(self, engine: sa.engine.Engine, ref: DataReference): @@ -286,7 +204,7 @@ def select(self, engine: sa.engine.Engine, ref: DataReference): self.key_columns, self.start_columns[0], self.end_columns[0], - self.end_included, + self.legitimate_gap_size, ) # TODO: Once get_unique_count also only returns a selection without # executing it, one would want to list this selection here as well. diff --git a/src/datajudge/constraints/interval.py b/src/datajudge/constraints/interval.py new file mode 100644 index 00000000..8e4b9c46 --- /dev/null +++ b/src/datajudge/constraints/interval.py @@ -0,0 +1,131 @@ +import abc +from typing import Any, List, Optional, Tuple + +import sqlalchemy as sa + +from .. import db_access +from ..db_access import DataReference +from .base import Constraint, OptionalSelections + + +class IntervalConstraint(Constraint): + _DIMENSIONS = 0 + + def __init__( + self, + ref: DataReference, + key_columns: Optional[List[str]], + start_columns: List[str], + end_columns: List[str], + max_relative_n_violations: float, + name: str = None, + ): + super().__init__(ref, ref_value=object(), name=name) + self.key_columns = key_columns + self.start_columns = start_columns + self.end_columns = end_columns + self.max_relative_n_violations = max_relative_n_violations + self._validate_dimensions() + + @abc.abstractmethod + def select(self, engine: sa.engine.Engine, ref: DataReference): + pass + + def _validate_dimensions(self): + if (length := len(self.start_columns)) != self._DIMENSIONS: + raise ValueError( + f"Expected {self._DIMENSIONS} start_column(s), got {length}." + ) + if (length := len(self.end_columns)) != self._DIMENSIONS: + raise ValueError( + f"Expected {self._DIMENSIONS} end_column(s), got {length}." + ) + + def retrieve( + self, engine: sa.engine.Engine, ref: DataReference + ) -> Tuple[Tuple[int, int], OptionalSelections]: + keys_ref = DataReference( + data_source=self.ref.data_source, + columns=self.key_columns, + condition=self.ref.condition, + ) + n_distinct_key_values, n_keys_selections = db_access.get_unique_count( + engine, keys_ref + ) + + sample_selection, n_violations_selection = self.select(engine, ref) + with engine.connect() as connection: + self.sample = connection.execute(sample_selection).first() + n_violation_keys = connection.execute(n_violations_selection).scalar() + + selections = [*n_keys_selections, sample_selection, n_violations_selection] + return (n_violation_keys, n_distinct_key_values), selections + + +class NoOverlapConstraint(IntervalConstraint): + def __init__( + self, + ref: DataReference, + key_columns: Optional[List[str]], + start_columns: List[str], + end_columns: List[str], + max_relative_n_violations: float, + end_included: bool, + name: Optional[str] = None, + ): + self.end_included = end_included + super().__init__( + ref, + key_columns, + start_columns, + end_columns, + max_relative_n_violations, + name=name, + ) + + def select(self, engine: sa.engine.Engine, ref: DataReference): + sample_selection, n_violations_selection = db_access.get_interval_overlaps_nd( + engine, + ref, + self.key_columns, + start_columns=self.start_columns, + end_columns=self.end_columns, + end_included=self.end_included, + ) + # TODO: Once get_unique_count also only returns a selection without + # executing it, one would want to list this selection here as well. + return sample_selection, n_violations_selection + + @abc.abstractmethod + def compare(self, engine: sa.engine.Engine, ref: DataReference): + pass + + +class NoGapConstraint(IntervalConstraint): + def __init__( + self, + ref: DataReference, + key_columns: Optional[List[str]], + start_columns: List[str], + end_columns: List[str], + max_relative_n_violations: float, + legitimate_gap_size: float, + name: Optional[str] = None, + ): + self.legitimate_gap_size = legitimate_gap_size + super().__init__( + ref, + key_columns, + start_columns, + end_columns, + max_relative_n_violations, + name=name, + ) + + @abc.abstractmethod + def select(self, engine: sa.engine.Engine, ref: DataReference): + pass + + @abc.abstractmethod + def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: + pass diff --git a/src/datajudge/constraints/numeric.py b/src/datajudge/constraints/numeric.py index 9b372058..d1987a57 100644 --- a/src/datajudge/constraints/numeric.py +++ b/src/datajudge/constraints/numeric.py @@ -1,10 +1,11 @@ -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import sqlalchemy as sa from .. import db_access from ..db_access import DataReference from .base import Constraint, OptionalSelections, TestResult +from .interval import NoGapConstraint, NoOverlapConstraint class NumericMin(Constraint): @@ -236,3 +237,52 @@ def compare( ) return False, assertion_message return True, None + + +class NumericNoGap(NoGapConstraint): + _DIMENSIONS = 1 + + def select(self, engine: sa.engine.Engine, ref: DataReference): + sample_selection, n_violations_selection = db_access.get_numeric_gaps( + engine, + ref, + self.key_columns, + self.start_columns[0], + self.end_columns[0], + self.legitimate_gap_size, + ) + # TODO: Once get_unique_count also only returns a selection without + # executing it, one would want to list this selection here as well. + return sample_selection, n_violations_selection + + def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: + n_violation_keys, n_distinct_key_values = factual + if n_distinct_key_values == 0: + return TestResult.success() + violation_fraction = n_violation_keys / n_distinct_key_values + assertion_text = ( + f"{self.ref.get_string()} has a ratio of {violation_fraction} > " + f"{self.max_relative_n_violations} keys in columns {self.key_columns} " + f"with a gap in the range in {self.start_columns[0]} and {self.end_columns[0]}." + f"E.g. for: {self.sample}." + ) + result = violation_fraction <= self.max_relative_n_violations + return result, assertion_text + + +class NumericNoOverlap(NoOverlapConstraint): + _DIMENSIONS = 1 + + def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]: + n_violation_keys, n_distinct_key_values = factual + if n_distinct_key_values == 0: + return TestResult.success() + violation_fraction = n_violation_keys / n_distinct_key_values + assertion_text = ( + f"{self.ref.get_string()} has a ratio of {violation_fraction} > " + f"{self.max_relative_n_violations} keys in columns {self.key_columns} " + f"with overlapping ranges in {self.start_columns[0]} and {self.end_columns[0]}." + f"E.g. for: {self.sample}." + ) + result = violation_fraction <= self.max_relative_n_violations + return result, assertion_text diff --git a/src/datajudge/db_access.py b/src/datajudge/db_access.py index c07006ec..6c2cb92c 100644 --- a/src/datajudge/db_access.py +++ b/src/datajudge/db_access.py @@ -459,7 +459,7 @@ def get_date_growth_rate(engine, ref, ref2, date_column, date_column2): return date_span / date_span2 - 1, [*selections, *selections2] -def get_date_overlaps_nd( +def get_interval_overlaps_nd( engine: sa.engine.Engine, ref: DataReference, key_columns: list[str] | None, @@ -485,7 +485,7 @@ def get_date_overlaps_nd( key_conditions = ( [table1.c[key_column] == table2.c[key_column] for key_column in key_columns] if key_columns - else [True] + else [sa.literal(True)] ) table_key_columns = get_table_columns(table1, key_columns) if key_columns else [] @@ -570,13 +570,16 @@ def _not_in_interval_condition( ) -def get_date_gaps( +def _get_interval_gaps( engine: sa.engine.Engine, ref: DataReference, key_columns: list[str] | None, start_column: str, end_column: str, - end_included: bool, + legitimate_gap_size: float, + make_gap_condition: Callable[ + [sa.Engine, sa.Subquery, sa.Subquery, str, str, float], sa.ColumnElement[bool] + ], ): if is_snowflake(engine): if key_columns: @@ -635,8 +638,46 @@ def get_date_gaps( .subquery() ) - legitimate_gap_size = 1 if end_included else 0 + gap_condition = make_gap_condition( + engine, start_table, end_table, start_column, end_column, legitimate_gap_size + ) + + join_condition = sa.and_( + *[ + start_table.c[key_column] == end_table.c[key_column] + for key_column in key_columns + ], + start_table.c["start_rank"] == end_table.c["end_rank"] + 1, + gap_condition, + ) + + violation_selection = sa.select( + *get_table_columns(start_table, key_columns), + start_table.c[start_column], + end_table.c[end_column], + ).select_from(start_table.join(end_table, join_condition)) + + violation_subquery = violation_selection.subquery() + + keys = get_table_columns(violation_subquery, key_columns) + + grouped_violation_subquery = sa.select(*keys).group_by(*keys).subquery() + + n_violations_selection = sa.select(sa.func.count()).select_from( + grouped_violation_subquery + ) + + return violation_selection, n_violations_selection + +def _date_gap_condition( + engine: sa.engine.Engine, + start_table: sa.Subquery, + end_table: sa.Subquery, + start_column: str, + end_column: str, + legitimate_gap_size: float, +) -> sa.ColumnElement[bool]: if is_mssql(engine) or is_snowflake(engine): gap_condition = ( sa.func.datediff( @@ -686,34 +727,60 @@ def get_date_gaps( ) else: raise NotImplementedError(f"Date gaps not yet implemented for {engine.name}.") + return gap_condition - join_condition = sa.and_( - *[ - start_table.c[key_column] == end_table.c[key_column] - for key_column in key_columns - ], - start_table.c["start_rank"] == end_table.c["end_rank"] + 1, - gap_condition, - ) - violation_selection = sa.select( - *get_table_columns(start_table, key_columns), - start_table.c[start_column], - end_table.c[end_column], - ).select_from(start_table.join(end_table, join_condition)) +def get_date_gaps( + engine: sa.engine.Engine, + ref: DataReference, + key_columns: list[str] | None, + start_column: str, + end_column: str, + legitimate_gap_size: float, +): + return _get_interval_gaps( + engine, + ref, + key_columns, + start_column, + end_column, + legitimate_gap_size, + _date_gap_condition, + ) - violation_subquery = violation_selection.subquery() - keys = get_table_columns(violation_subquery, key_columns) +def _numeric_gap_condition( + _engine: sa.engine.Engine, + start_table: sa.Subquery, + end_table: sa.Subquery, + start_column: str, + end_column: str, + legitimate_gap_size: float, +) -> sa.ColumnElement[bool]: + gap_condition = ( + start_table.c[start_column] - end_table.c[end_column] + ) > legitimate_gap_size + return gap_condition - grouped_violation_subquery = sa.select(*keys).group_by(*keys).subquery() - n_violations_selection = sa.select(sa.func.count()).select_from( - grouped_violation_subquery +def get_numeric_gaps( + engine: sa.engine.Engine, + ref: DataReference, + key_columns: list[str] | None, + start_column: str, + end_column: str, + legitimate_gap_size: float = 0, +): + return _get_interval_gaps( + engine, + ref, + key_columns, + start_column, + end_column, + legitimate_gap_size, + _numeric_gap_condition, ) - return violation_selection, n_violations_selection - def get_row_count(engine, ref, row_limit: int = None): """Return the number of rows for a `DataReference`. diff --git a/src/datajudge/requirements.py b/src/datajudge/requirements.py index a242ede5..349afe08 100644 --- a/src/datajudge/requirements.py +++ b/src/datajudge/requirements.py @@ -801,7 +801,115 @@ def add_date_no_gap_constraint( start_columns=[start_column], end_columns=[end_column], max_relative_n_violations=max_relative_n_violations, + legitimate_gap_size=1 if end_included else 0, + name=name, + ) + ) + + def add_numeric_no_gap_constraint( + self, + start_column: str, + end_column: str, + key_columns: Optional[List[str]] = None, + legitimate_gap_size: float = 0, + max_relative_n_violations: float = 0, + condition: Condition = None, + name: str = None, + ): + """ + Express that numeric interval rows have no gaps larger than some max value in-between them. + The table under inspection must consist of at least one but up to many key columns, + identifying an entity. Additionally, a ``start_column`` and an ``end_column``, + indicating interval start and end values, should be provided. + + Neither of those columns should contain ``NULL`` values. Also, it should hold that + for a given row, the value of ``end_column`` is strictly greater than the value of + ``start_column``. + + ``legitimate_gap_size`` is the maximum tollerated gap size between two intervals. + + A 'key' is a fixed set of values in ``key_columns`` and represents an entity of + interest. A priori, a key is not a primary key, i.e., a key can have and often has + several rows. Thereby, a key will often come with several intervals. + + If`` key_columns`` is ``None`` or ``[]``, all columns of the table will be + considered as composing the key. + + In order to express a tolerance for some violations of this gap property, use the + ``max_relative_n_violations`` parameter. The latter expresses for what fraction + of all key_values, at least one gap may exist. + + For illustrative examples of this constraint, please refer to its test cases. + """ + relevant_columns = ( + ([start_column, end_column] + key_columns) if key_columns else [] + ) + ref = DataReference(self.data_source, relevant_columns, condition) + self._constraints.append( + numeric_constraints.NumericNoGap( + ref, + key_columns=key_columns, + start_columns=[start_column], + end_columns=[end_column], + legitimate_gap_size=legitimate_gap_size, + max_relative_n_violations=max_relative_n_violations, + name=name, + ) + ) + + def add_numeric_no_overlap_constraint( + self, + start_column: str, + end_column: str, + key_columns: Optional[List[str]] = None, + end_included: bool = True, + max_relative_n_violations: float = 0, + condition: Condition = None, + name: str = None, + ): + """Constraint expressing that several numeric interval rows may not overlap. + + The ``DataSource`` under inspection must consist of at least one but up + to many ``key_columns``, identifying an entity, a ``start_column`` and an + ``end_column``. + + For a given row in this ``DataSource``, ``start_column`` and ``end_column`` indicate a + numeric interval. Neither of those columns should contain NULL values. Also, it + should hold that for a given row, the value of ``end_column`` is strictly greater + than the value of ``start_column``. + + Note that the value of ``start_column`` is expected to be included in each interval. + By default, the value of ``end_column`` is expected to be included as well - + this can however be changed by setting ``end_included`` to ``False``. + + A 'key' is a fixed set of values in ``key_columns`` and represents an entity of + interest. A priori, a key is not a primary key, i.e., a key can have and often + has several rows. Thereby, a key will often come with several intervals. + + Often, you might want the intervals for a given key not to overlap. + + If ``key_columns`` is ``None`` or ``[]``, all columns of the table will be considered + as composing the key. + + In order to express a tolerance for some violations of this non-overlapping + property, use the ``max_relative_n_violations`` parameter. The latter expresses for + what fraction of all key values, at least one overlap may exist. + + For illustrative examples of this constraint, please refer to its test cases. + """ + + relevant_columns = [start_column, end_column] + ( + key_columns if key_columns else [] + ) + ref = DataReference(self.data_source, relevant_columns, condition) + self._constraints.append( + numeric_constraints.NumericNoOverlap( + ref, + key_columns=key_columns, + start_columns=[start_column], + end_columns=[end_column], end_included=end_included, + max_relative_n_violations=max_relative_n_violations, name=name, ) ) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 83392a3a..c7b31ff5 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -383,6 +383,75 @@ def date_table_overlap_2d(engine, metadata): return TEST_DB_NAME, SCHEMA, table_name +@pytest.fixture(scope="module") +def integer_table_overlap(engine, metadata): + table_name = "integer_table_overlap" + columns = [ + sa.Column("id1", sa.Integer()), + sa.Column("range_start", sa.Integer()), + sa.Column("range_end", sa.Integer()), + ] + data = [] + # Trivial case: single entry. + data += [ + { + "id1": 1, + "range_start": 1, + "range_end": 10, + } + ] + # 'Normal case': Multiple entries without overlap. + data += [ + { + "id1": 2, + "range_start": i * 2, + "range_end": i * 2 + 1, + } + for i in range(1, 5) + ] + # Multiple entries with non-singleton overlap. + data += [ + { + "id1": 3, + "range_start": 1, + "range_end": 10, + }, + { + "id1": 3, + "range_start": 7, + "range_end": 15, + }, + ] + # Multiple entries with singleton overlap. + data += [ + { + "id1": 4, + "range_start": 1, + "range_end": 10, + }, + { + "id1": 4, + "range_start": 10, + "range_end": 15, + }, + ] + # Multiple entries with subset relation. + data += [ + { + "id1": 5, + "range_start": 1, + "range_end": 10, + }, + { + "id1": 5, + "range_start": 4, + "range_end": 8, + }, + ] + _handle_table(engine, metadata, table_name, columns, data) + return TEST_DB_NAME, SCHEMA, table_name + + @pytest.fixture(scope="module") def date_table_gap(engine, metadata): table_name = "date_table_gap" @@ -452,6 +521,131 @@ def date_table_gap(engine, metadata): return TEST_DB_NAME, SCHEMA, table_name +@pytest.fixture(scope="module") +def integer_table_gap(engine, metadata): + table_name = "integer_table_gap" + columns = [ + sa.Column("id1", sa.Integer()), + sa.Column("range_start", sa.Integer()), + sa.Column("range_end", sa.Integer()), + ] + data = [] + # Single entry should not be considered a gap. + data += [ + { + "id1": 1, + "range_start": 1, + "range_end": 3, + } + ] + # Multiple entries without gap. + data += [ + { + "id1": 2, + "range_start": 3 + i * 2, + "range_end": 5 + i * 2, + } + for i in range(1, 5) + ] + # Multiple entries with overlap. + data += [ + { + "id1": 3, + "range_start": 1, + "range_end": 10, + }, + { + "id1": 3, + "range_start": 3, + "range_end": 7, + }, + ] + # Multiple entries with gap. + data += [ + { + "id1": 4, + "range_start": 1, + "range_end": 5, + }, + { + "id1": 4, + "range_start": 7, + "range_end": 10, + }, + ] + _handle_table(engine, metadata, table_name, columns, data) + return TEST_DB_NAME, SCHEMA, table_name + + +@pytest.fixture(scope="module") +def float_table_gap(engine, metadata): + table_name = "float_table_gap" + columns = [ + sa.Column("id1", sa.Integer()), + sa.Column("range_start", sa.Float()), + sa.Column("range_end", sa.Float()), + ] + data = [] + # Single entry should not be considered a gap. + data += [ + { + "id1": 1, + "range_start": 1, + "range_end": 3, + } + ] + # Multiple entries without gap. + data += [ + { + "id1": 2, + "range_start": 3 + i * 2, + "range_end": 5 + i * 2, + } + for i in range(1, 5) + ] + # Multiple entries with overlap. + data += [ + { + "id1": 3, + "range_start": 1, + "range_end": 10, + }, + { + "id1": 3, + "range_start": 3, + "range_end": 7, + }, + ] + # Multiple entries with gap. + data += [ + { + "id1": 4, + "range_start": 1, + "range_end": 5, + }, + { + "id1": 4, + "range_start": 8, + "range_end": 10, + }, + ] + # Multiple entries with tolerated gap. + data += [ + { + "id1": 5, + "range_start": 1, + "range_end": 5, + }, + { + "id1": 5, + "range_start": 5.5, + "range_end": 10, + }, + ] + _handle_table(engine, metadata, table_name, columns, data) + return TEST_DB_NAME, SCHEMA, table_name + + @pytest.fixture(scope="module") def date_table_keys(engine, metadata): table_name = "date_table_keys" diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index eda55edb..03229d65 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -1169,6 +1169,36 @@ def test_date_no_overlap_within_varying_key_columns( assert operation(test_result.outcome), test_result.failure_message +@pytest.mark.parametrize( + "data", + [ + (identity, 0, Condition(raw_string="id1 = 1")), + (identity, 0, Condition(raw_string="id1 = 2")), + (negation, 0, Condition(raw_string="id1 = 3")), + (identity, 1, Condition(raw_string="id1 = 3")), + (negation, 0, Condition(raw_string="id1 = 4")), + (identity, 1, Condition(raw_string="id1 = 4")), + (negation, 0, Condition(raw_string="id1 = 5")), + (identity, 1, Condition(raw_string="id1 = 5")), + ], +) +@pytest.mark.parametrize("key_columns", [["id1"], [], None]) +def test_integer_no_overlap_within_varying_key_columns( + engine, integer_table_overlap, data, key_columns +): + operation, max_relative_n_violations, condition = data + req = requirements.WithinRequirement.from_table(*integer_table_overlap) + req.add_numeric_no_overlap_constraint( + key_columns=key_columns, + start_column="range_start", + end_column="range_end", + max_relative_n_violations=max_relative_n_violations, + condition=condition, + ) + test_result = req[0].test(engine) + assert operation(test_result.outcome), test_result.failure_message + + @pytest.mark.parametrize( "data", [ @@ -1376,6 +1406,57 @@ def test_date_no_gap_within_fixed_key_columns(engine, date_table_gap, data): assert operation(test_result.outcome), test_result.failure_message +@pytest.mark.parametrize( + "data", + [ + (identity, 0, Condition(raw_string="id1 = 1")), + (identity, 0, Condition(raw_string="id1 = 2")), + (identity, 0, Condition(raw_string="id1 = 3")), + (negation, 0, Condition(raw_string="id1 = 4")), + (identity, 0, None), + ], +) +def test_integer_no_gap_within_fixed_key_columns(engine, integer_table_gap, data): + operation, max_relative_n_violations, condition = data + req = requirements.WithinRequirement.from_table(*integer_table_gap) + req.add_numeric_no_gap_constraint( + key_columns=["id1"], + start_column="range_start", + end_column="range_end", + max_relative_n_violations=max_relative_n_violations, + legitimate_gap_size=0, + condition=condition, + ) + test_result = req[0].test(engine) + assert operation(test_result.outcome), test_result.failure_message + + +@pytest.mark.parametrize( + "data", + [ + (identity, 0, Condition(raw_string="id1 = 1")), + (identity, 0, Condition(raw_string="id1 = 2")), + (identity, 0, Condition(raw_string="id1 = 3")), + (negation, 0, Condition(raw_string="id1 = 4")), + (negation, 0, Condition(raw_string="id1 = 5")), + (identity, 0.6, Condition(raw_string="id1 = 5")), + ], +) +def test_float_no_gap_within_fixed_key_columns(engine, float_table_gap, data): + operation, legitimate_gap_size, condition = data + req = requirements.WithinRequirement.from_table(*float_table_gap) + req.add_numeric_no_gap_constraint( + key_columns=["id1"], + start_column="range_start", + end_column="range_end", + legitimate_gap_size=legitimate_gap_size, + max_relative_n_violations=0, + condition=condition, + ) + test_result = req[0].test(engine) + assert operation(test_result.outcome), test_result.failure_message + + @pytest.mark.parametrize( "data", [