Skip to content

Commit

Permalink
extend interval constraints to work with integers (#146)
Browse files Browse the repository at this point in the history
  • Loading branch information
Bela Stoyan authored Jun 1, 2023
1 parent a23c329 commit 6ddaf82
Show file tree
Hide file tree
Showing 8 changed files with 664 additions and 115 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Changelog

- Implement :meth:`datajudge.WithinRequirement.add_value_distribution_constraint`.
- Extended :meth:`datajudge.WithinRequirement.add_column_type_constraint` to support column type specification using string format, backend-specific SQLAlchemy types, and SQLAlchemy's generic types.

- Implement :meth:`datajudge.WithinRequirement.add_numeric_no_gap_constraint`, :meth:`datajudge.WithinRequirement.add_numeric_no_overlap_constraint`,

1.6.0 - 2022.04.12
------------------
Expand Down
94 changes: 6 additions & 88 deletions src/datajudge/constraints/date.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import abc
import datetime as dt
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union

import sqlalchemy as sa

from .. import db_access
from ..db_access import DataReference
from .base import Constraint, OptionalSelections, TestResult
from .interval import NoGapConstraint, NoOverlapConstraint

INPUT_DATE_FORMAT = "'%Y-%m-%d'"

Expand Down Expand Up @@ -157,78 +157,9 @@ def compare(
return result, assertion_text


class DateIntervals(Constraint, abc.ABC):
_DIMENSIONS = 0

def __init__(
self,
ref: DataReference,
key_columns: Optional[List[str]],
start_columns: List[str],
end_columns: List[str],
end_included: bool,
max_relative_n_violations: float,
name: str = None,
):
super().__init__(ref, ref_value=object(), name=name)
self.key_columns = key_columns
self.start_columns = start_columns
self.end_columns = end_columns
self.end_included = end_included
self.max_relative_n_violations = max_relative_n_violations
self._validate_dimensions()

@abc.abstractmethod
def select(self, engine: sa.engine.Engine, ref: DataReference):
pass

def _validate_dimensions(self):
if (length := len(self.start_columns)) != self._DIMENSIONS:
raise ValueError(
f"Expected {self._DIMENSIONS} start_column(s), got {length}."
)
if (length := len(self.end_columns)) != self._DIMENSIONS:
raise ValueError(
f"Expected {self._DIMENSIONS} end_column(s), got {length}."
)

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[Tuple[int, int], OptionalSelections]:
keys_ref = DataReference(
data_source=self.ref.data_source,
columns=self.key_columns,
condition=self.ref.condition,
)
n_distinct_key_values, n_keys_selections = db_access.get_unique_count(
engine, keys_ref
)

sample_selection, n_violations_selection = self.select(engine, ref)
with engine.connect() as connection:
self.sample = connection.execute(sample_selection).first()
n_violation_keys = connection.execute(n_violations_selection).scalar()

selections = [*n_keys_selections, sample_selection, n_violations_selection]
return (n_violation_keys, n_distinct_key_values), selections


class DateNoOverlap(DateIntervals):
class DateNoOverlap(NoOverlapConstraint):
_DIMENSIONS = 1

def select(self, engine: sa.engine.Engine, ref: DataReference):
sample_selection, n_violations_selection = db_access.get_date_overlaps_nd(
engine,
ref,
self.key_columns,
start_columns=self.start_columns,
end_columns=self.end_columns,
end_included=self.end_included,
)
# TODO: Once get_unique_count also only returns a selection without
# executing it, one would want to list this selection here as well.
return sample_selection, n_violations_selection

def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
n_violation_keys, n_distinct_key_values = factual
if n_distinct_key_values == 0:
Expand All @@ -244,22 +175,9 @@ def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
return result, assertion_text


class DateNoOverlap2d(DateIntervals):
class DateNoOverlap2d(NoOverlapConstraint):
_DIMENSIONS = 2

def select(self, engine: sa.engine.Engine, ref: DataReference):
sample_selection, n_violations_selection = db_access.get_date_overlaps_nd(
engine,
ref,
self.key_columns,
start_columns=self.start_columns,
end_columns=self.end_columns,
end_included=self.end_included,
)
# TODO: Once get_unique_count also only returns a selection without
# executing it, one would want to list this selection here as well.
return sample_selection, n_violations_selection

def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
n_violation_keys, n_distinct_key_values = factual
if n_distinct_key_values == 0:
Expand All @@ -276,7 +194,7 @@ def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
return result, assertion_text


class DateNoGap(DateIntervals):
class DateNoGap(NoGapConstraint):
_DIMENSIONS = 1

def select(self, engine: sa.engine.Engine, ref: DataReference):
Expand All @@ -286,7 +204,7 @@ def select(self, engine: sa.engine.Engine, ref: DataReference):
self.key_columns,
self.start_columns[0],
self.end_columns[0],
self.end_included,
self.legitimate_gap_size,
)
# TODO: Once get_unique_count also only returns a selection without
# executing it, one would want to list this selection here as well.
Expand Down
131 changes: 131 additions & 0 deletions src/datajudge/constraints/interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import abc
from typing import Any, List, Optional, Tuple

import sqlalchemy as sa

from .. import db_access
from ..db_access import DataReference
from .base import Constraint, OptionalSelections


class IntervalConstraint(Constraint):
_DIMENSIONS = 0

def __init__(
self,
ref: DataReference,
key_columns: Optional[List[str]],
start_columns: List[str],
end_columns: List[str],
max_relative_n_violations: float,
name: str = None,
):
super().__init__(ref, ref_value=object(), name=name)
self.key_columns = key_columns
self.start_columns = start_columns
self.end_columns = end_columns
self.max_relative_n_violations = max_relative_n_violations
self._validate_dimensions()

@abc.abstractmethod
def select(self, engine: sa.engine.Engine, ref: DataReference):
pass

def _validate_dimensions(self):
if (length := len(self.start_columns)) != self._DIMENSIONS:
raise ValueError(
f"Expected {self._DIMENSIONS} start_column(s), got {length}."
)
if (length := len(self.end_columns)) != self._DIMENSIONS:
raise ValueError(
f"Expected {self._DIMENSIONS} end_column(s), got {length}."
)

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[Tuple[int, int], OptionalSelections]:
keys_ref = DataReference(
data_source=self.ref.data_source,
columns=self.key_columns,
condition=self.ref.condition,
)
n_distinct_key_values, n_keys_selections = db_access.get_unique_count(
engine, keys_ref
)

sample_selection, n_violations_selection = self.select(engine, ref)
with engine.connect() as connection:
self.sample = connection.execute(sample_selection).first()
n_violation_keys = connection.execute(n_violations_selection).scalar()

selections = [*n_keys_selections, sample_selection, n_violations_selection]
return (n_violation_keys, n_distinct_key_values), selections


class NoOverlapConstraint(IntervalConstraint):
def __init__(
self,
ref: DataReference,
key_columns: Optional[List[str]],
start_columns: List[str],
end_columns: List[str],
max_relative_n_violations: float,
end_included: bool,
name: Optional[str] = None,
):
self.end_included = end_included
super().__init__(
ref,
key_columns,
start_columns,
end_columns,
max_relative_n_violations,
name=name,
)

def select(self, engine: sa.engine.Engine, ref: DataReference):
sample_selection, n_violations_selection = db_access.get_interval_overlaps_nd(
engine,
ref,
self.key_columns,
start_columns=self.start_columns,
end_columns=self.end_columns,
end_included=self.end_included,
)
# TODO: Once get_unique_count also only returns a selection without
# executing it, one would want to list this selection here as well.
return sample_selection, n_violations_selection

@abc.abstractmethod
def compare(self, engine: sa.engine.Engine, ref: DataReference):
pass


class NoGapConstraint(IntervalConstraint):
def __init__(
self,
ref: DataReference,
key_columns: Optional[List[str]],
start_columns: List[str],
end_columns: List[str],
max_relative_n_violations: float,
legitimate_gap_size: float,
name: Optional[str] = None,
):
self.legitimate_gap_size = legitimate_gap_size
super().__init__(
ref,
key_columns,
start_columns,
end_columns,
max_relative_n_violations,
name=name,
)

@abc.abstractmethod
def select(self, engine: sa.engine.Engine, ref: DataReference):
pass

@abc.abstractmethod
def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
pass
52 changes: 51 additions & 1 deletion src/datajudge/constraints/numeric.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Optional, Tuple
from typing import Any, Optional, Tuple

import sqlalchemy as sa

from .. import db_access
from ..db_access import DataReference
from .base import Constraint, OptionalSelections, TestResult
from .interval import NoGapConstraint, NoOverlapConstraint


class NumericMin(Constraint):
Expand Down Expand Up @@ -236,3 +237,52 @@ def compare(
)
return False, assertion_message
return True, None


class NumericNoGap(NoGapConstraint):
_DIMENSIONS = 1

def select(self, engine: sa.engine.Engine, ref: DataReference):
sample_selection, n_violations_selection = db_access.get_numeric_gaps(
engine,
ref,
self.key_columns,
self.start_columns[0],
self.end_columns[0],
self.legitimate_gap_size,
)
# TODO: Once get_unique_count also only returns a selection without
# executing it, one would want to list this selection here as well.
return sample_selection, n_violations_selection

def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
n_violation_keys, n_distinct_key_values = factual
if n_distinct_key_values == 0:
return TestResult.success()
violation_fraction = n_violation_keys / n_distinct_key_values
assertion_text = (
f"{self.ref.get_string()} has a ratio of {violation_fraction} > "
f"{self.max_relative_n_violations} keys in columns {self.key_columns} "
f"with a gap in the range in {self.start_columns[0]} and {self.end_columns[0]}."
f"E.g. for: {self.sample}."
)
result = violation_fraction <= self.max_relative_n_violations
return result, assertion_text


class NumericNoOverlap(NoOverlapConstraint):
_DIMENSIONS = 1

def compare(self, factual: Tuple[int, int], target: Any) -> Tuple[bool, str]:
n_violation_keys, n_distinct_key_values = factual
if n_distinct_key_values == 0:
return TestResult.success()
violation_fraction = n_violation_keys / n_distinct_key_values
assertion_text = (
f"{self.ref.get_string()} has a ratio of {violation_fraction} > "
f"{self.max_relative_n_violations} keys in columns {self.key_columns} "
f"with overlapping ranges in {self.start_columns[0]} and {self.end_columns[0]}."
f"E.g. for: {self.sample}."
)
result = violation_fraction <= self.max_relative_n_violations
return result, assertion_text
Loading

0 comments on commit 6ddaf82

Please sign in to comment.