Skip to content

Commit

Permalink
allow to specify sqlalchemy types directly in ColumnType Constraints (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Bela Stoyan authored May 17, 2023
1 parent 697697c commit a23c329
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Changelog
**New features**

- Implement :meth:`datajudge.WithinRequirement.add_value_distribution_constraint`.
- Extended :meth:`datajudge.WithinRequirement.add_column_type_constraint` to support column type specification using string format, backend-specific SQLAlchemy types, and SQLAlchemy's generic types.


1.6.0 - 2022.04.12
Expand Down
37 changes: 28 additions & 9 deletions src/datajudge/constraints/column.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import List, Tuple
from typing import List, Optional, Tuple, Union

import sqlalchemy as sa

Expand Down Expand Up @@ -69,29 +69,48 @@ def compare(


class ColumnType(Constraint):
"""
A class used to represent a ColumnType constraint.
This class enables flexible specification of column types either in string format or using SQLAlchemy's type hierarchy.
It checks whether a column's type matches the specified type, allowing for checks against backend-specific types,
SQLAlchemy's generic types, or string representations of backend-specific types.
When using SQLAlchemy's generic types, the comparison is done using `isinstance`, which means that the actual type can also be a subclass of the target type.
For more information, see https://docs.sqlalchemy.org/en/20/core/type_basics.html
"""

def __init__(
self,
ref: DataReference,
*,
ref2: DataReference = None,
column_type: str = None,
name: str = None,
ref2: Optional[DataReference] = None,
column_type: Optional[Union[str, sa.types.TypeEngine]] = None,
name: Optional[str] = None,
):
if column_type:
column_type = column_type.lower()
super().__init__(ref, ref2=ref2, ref_value=column_type, name=name)
self.column_type = column_type

def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> Tuple[str, OptionalSelections]:
) -> Tuple[sa.types.TypeEngine, OptionalSelections]:
result, selections = db_access.get_column_type(engine, ref)
return result.lower(), selections
return result, selections

def compare(self, column_type_factual, column_type_target) -> Tuple[bool, str]:
assertion_message = (
f"{self.ref.get_string()} is {column_type_factual} "
f"instead of {column_type_target}."
)
result = column_type_factual.startswith(column_type_target)

if isinstance(column_type_target, sa.types.TypeEngine):
result = isinstance(column_type_factual, type(column_type_target))
else:
column_type = str(column_type_factual).lower()
# Integer columns loaded from snowflake database may be referred to as decimal with
# 0 scale. More here:
# https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#decimal-numeric
if column_type == "decimal(38, 0)":
column_type = "integer"
result = column_type.startswith(column_type_target.lower())
return result, assertion_message
12 changes: 1 addition & 11 deletions src/datajudge/db_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,17 +908,7 @@ def get_column_names(engine, ref):

def get_column_type(engine, ref):
table = ref.get_selection(engine).alias()
if is_snowflake(engine):
column_type = [str(column.type) for column in table.columns][0]
# Integer columns loaded from snowflake database may be referred to as decimal with
# 0 scale. More here:
# https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#decimal-numeric
if column_type == "DECIMAL(38, 0)":
column_type = "integer"
return column_type, None
column_type = [
str(column.type).split(" ", maxsplit=1)[0] for column in table.columns
][0]
column_type = next(iter(table.columns)).type
return column_type, None


Expand Down
55 changes: 43 additions & 12 deletions src/datajudge/requirements.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from abc import ABC
from collections.abc import MutableSequence
from typing import Callable, Collection, Dict, List, Optional, Sequence, Tuple, TypeVar
from typing import (
Callable,
Collection,
Dict,
List,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

import sqlalchemy as sa

Expand Down Expand Up @@ -166,8 +176,30 @@ def add_uniqueness_constraint(
)

def add_column_type_constraint(
self, column: str, column_type: str, name: str = None
self,
column: str,
column_type: Union[str, sa.types.TypeEngine],
name: str = None,
):
"""
Check if a column type matches the expected column_type.
The column_type can be provided as a string (backend-specific type name), a backend-specific SQLAlchemy type, or one of SQLAlchemy's generic types.
If SQLAlchemy's generic types are used, the check is performed using `isinstance`, which means that the actual type can also be a subclass of the target type.
For more information on SQLAlchemy's generic types, see https://docs.sqlalchemy.org/en/20/core/type_basics.html
Parameters
----------
column : str
The name of the column to which the constraint will be applied.
column_type : Union[str, sa.types.TypeEngine]
The expected type of the column. This can be a string, a backend-specific SQLAlchemy type, or a generic SQLAlchemy type.
name : Optional[str]
An optional name for the constraint. If not provided, a name will be generated automatically.
"""
ref = DataReference(self.data_source, [column])
self._constraints.append(
column_constraints.ColumnType(ref, column_type=column_type, name=name)
Expand Down Expand Up @@ -517,16 +549,15 @@ def add_date_min_constraint(
column: str,
min_value: str,
use_lower_bound_reference: bool = True,
column_type: str = "date",
column_type: Union[str, sa.types.TypeEngine] = "date",
condition: Condition = None,
name: str = None,
):
"""Ensure all dates are superior to min_value.
Use string format: min_value="'20121230'".
For valid ``column_type`` values, see`` get_format_from_column_type`` in
constraints/base.py.
For more information on ``column_type`` values, see ``add_column_type_constraint``.
If ``use_lower_bound_reference``, the min of the first table has to be
greater or equal to ``min_value``.
Expand All @@ -549,16 +580,15 @@ def add_date_max_constraint(
column: str,
max_value: str,
use_upper_bound_reference: bool = True,
column_type: str = "date",
column_type: Union[str, sa.types.TypeEngine] = "date",
condition: Condition = None,
name: str = None,
):
"""Ensure all dates are inferior to max_value.
Use string format: max_value="'20121230'".
For valid ``column_type`` values, see ``get_format_from_column_type`` in
constraints/base.py..
For more information on ``column_type`` values, see ``add_column_type_constraint``.
If ``use_upper_bound_reference``, the max of the first table has to be
smaller or equal to ``max_value``.
Expand Down Expand Up @@ -1427,7 +1457,7 @@ def add_date_min_constraint(
column1: str,
column2: str,
use_lower_bound_reference: bool = True,
column_type: str = "date",
column_type: Union[str, sa.types.TypeEngine] = "date",
condition1: Condition = None,
condition2: Condition = None,
name: str = None,
Expand All @@ -1436,7 +1466,7 @@ def add_date_min_constraint(
The used columns of both tables need to be of the same type.
For valid column_type values, see get_format_from_column_type in constraints/base.py..
For more information on ``column_type`` values, see ``add_column_type_constraint``.
If ``use_lower_bound_reference``, the min of the first table has to be
greater or equal to the min of the second table.
Expand All @@ -1460,7 +1490,7 @@ def add_date_max_constraint(
column1: str,
column2: str,
use_upper_bound_reference: bool = True,
column_type: str = "date",
column_type: Union[str, sa.types.TypeEngine] = "date",
condition1: Condition = None,
condition2: Condition = None,
name: str = None,
Expand All @@ -1469,7 +1499,7 @@ def add_date_max_constraint(
The used columns of both tables need to be of the same type.
For valid column_type values, see get_format_from_column_type in constraints/base.py.
For more information on ``column_type`` values, see ``add_column_type_constraint``.
If ``use_upper_bound_reference``, the max of the first table has to be
smaller or equal to the max of the second table.
Expand Down Expand Up @@ -1529,6 +1559,7 @@ def add_column_superset_constraint(self, name: str = None):
)

def add_column_type_constraint(self, column1: str, column2: str, name: str = None):
"Check that the columns have the same type."
ref1 = DataReference(self.data_source, [column1])
ref2 = DataReference(self.data_source2, [column2])
self._constraints.append(
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools

import pytest
import sqlalchemy as sa

import datajudge.requirements as requirements
from datajudge.db_access import (
Expand Down Expand Up @@ -1942,11 +1943,13 @@ def test_max_null_fraction_between(engine, unique_table1, data):
(identity, "col_varchar", "VARCHAR"),
(identity, "col_int", "INTEGER"),
(negation, "col_varchar", "INTEGER"),
(identity, "col_varchar", sa.types.String()),
(negation, "col_varchar", sa.types.Numeric()),
],
)
def test_column_type_within(engine, mix_table1, data):
(operation, col_name, type_name) = data
if is_impala(engine):
if is_impala(engine) and type(type_name) == str:
type_name = {"VARCHAR": "string", "INTEGER": "int"}[type_name]
req = requirements.WithinRequirement.from_table(*mix_table1)
req.add_column_type_constraint(col_name, type_name)
Expand Down

0 comments on commit a23c329

Please sign in to comment.