Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use metadata over singletablemetadata #2144

Closed
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
MissingConstraintColumnError,
)
from sdv.errors import ConstraintsNotMetError
from sdv.metadata.metadata import Metadata

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -147,12 +148,13 @@ def _validate_inputs(cls, **kwargs):

@classmethod
def _validate_metadata_columns(cls, metadata, **kwargs):
Metadata._convert_to_unified_metadata(metadata)
if 'column_name' in kwargs:
column_names = [kwargs.get('column_name')]
else:
column_names = kwargs.get('column_names')

missing_columns = set(column_names) - set(metadata.columns) - {None}
missing_columns = set(column_names) - set(metadata.get_columns()) - {None}
if missing_columns:
article = 'An' if cls.__name__ == 'Inequality' else 'A'
raise ConstraintMetadataError(
Expand All @@ -169,8 +171,8 @@ def _validate_metadata(cls, metadata, **kwargs):
"""Validate the metadata against the constraint.

Args:
metadata (sdv.metadata.SingleTableMetadata):
Single table metadata instance.
metadata (sdv.metadata.Metadata):
Metadata instance with a single table.
**kwargs (dict):
Any required kwargs for the constraint.

Expand All @@ -179,6 +181,7 @@ def _validate_metadata(cls, metadata, **kwargs):
All the errors from validating the metadata.
"""
errors = []
metadata = Metadata._convert_to_unified_metadata(metadata)
try:
cls._validate_inputs(**kwargs)
except AggregateConstraintsError as agg_error:
Expand Down
83 changes: 52 additions & 31 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
revert_nans_columns,
sigmoid,
)
from sdv.metadata.metadata import Metadata

INEQUALITY_TO_OPERATION = {
'>': np.greater,
Expand Down Expand Up @@ -251,10 +252,11 @@ class FixedCombinations(Constraint):

@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
metadata = Metadata._convert_to_unified_metadata(metadata)
invalid_columns = []
column_names = kwargs.get('column_names')
for column in column_names:
if metadata.columns[column]['sdtype'] not in ['boolean', 'categorical']:
if metadata.get_columns()[column]['sdtype'] not in ['boolean', 'categorical']:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be backwards compatible. SingleTableMetadata does not have get_columns().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made it backwards compatible by making sure all metadata is converted, and by also converting the synthesizer's metadata on load.

Also added an integration test for backward compatibility.

invalid_columns.append(column)

if invalid_columns:
Expand Down Expand Up @@ -390,15 +392,17 @@ def _validate_init_inputs(low_column_name, high_column_name, strict_boundaries):

@classmethod
def _validate_metadata_columns(cls, metadata, **kwargs):
    """Validate the constraint's high and low columns against the metadata.

    The metadata is first passed through
    ``Metadata._convert_to_unified_metadata``. The high and low column
    names are then packed into ``kwargs['column_names']`` and the actual
    check is delegated to the parent class implementation.

    Args:
        metadata:
            Metadata object describing the table; it is converted before use.
        **kwargs (dict):
            Constraint kwargs. ``high_column_name`` and ``low_column_name``
            are read; a missing key is forwarded as ``None``.
    """
    unified_metadata = Metadata._convert_to_unified_metadata(metadata)
    high_column = kwargs.get('high_column_name')
    low_column = kwargs.get('low_column_name')
    # Order is kept as [high, low] to match the original list construction.
    kwargs['column_names'] = [high_column, low_column]
    super()._validate_metadata_columns(unified_metadata, **kwargs)

@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
metadata = Metadata._convert_to_unified_metadata(metadata)
high = kwargs.get('high_column_name')
low = kwargs.get('low_column_name')
high_sdtype = metadata.columns.get(high, {}).get('sdtype')
low_sdtype = metadata.columns.get(low, {}).get('sdtype')
high_sdtype = metadata.get_columns().get(high, {}).get('sdtype')
low_sdtype = metadata.get_columns().get(low, {}).get('sdtype')
both_datetime = high_sdtype == low_sdtype == 'datetime'
both_numerical = high_sdtype == low_sdtype == 'numerical'
if not (both_datetime or both_numerical) and not (high is None or low is None):
Expand Down Expand Up @@ -426,8 +430,10 @@ def _get_data(self, table_data):
return low, high

def _get_is_datetime(self):
is_low_datetime = self.metadata.columns[self._low_column_name]['sdtype'] == 'datetime'
is_high_datetime = self.metadata.columns[self._high_column_name]['sdtype'] == 'datetime'
is_low_datetime = self.metadata.get_columns()[self._low_column_name]['sdtype'] == 'datetime'
is_high_datetime = (
self.metadata.get_columns()[self._high_column_name]['sdtype'] == 'datetime'
)
is_datetime = is_low_datetime and is_high_datetime

if not is_datetime and any([is_low_datetime, is_high_datetime]):
Expand All @@ -451,10 +457,10 @@ def _fit(self, table_data):
self._dtype = table_data[self._high_column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._low_datetime_format = self.metadata.columns[self._low_column_name].get(
self._low_datetime_format = self.metadata.get_columns()[self._low_column_name].get(
'datetime_format'
)
self._high_datetime_format = self.metadata.columns[self._high_column_name].get(
self._high_datetime_format = self.metadata.get_columns()[self._high_column_name].get(
'datetime_format'
)

Expand Down Expand Up @@ -597,15 +603,16 @@ def _validate_inputs(cls, **kwargs):

@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
metadata = Metadata._convert_to_unified_metadata(metadata)
column_name = kwargs.get('column_name')
sdtype = metadata.columns.get(column_name, {}).get('sdtype')
sdtype = metadata.get_columns().get(column_name, {}).get('sdtype')
value = kwargs.get('value')
if sdtype == 'numerical':
if not isinstance(value, (int, float)):
raise ConstraintMetadataError("'value' must be an int or float.")

elif sdtype == 'datetime':
datetime_format = metadata.columns.get(column_name).get('datetime_format')
datetime_format = metadata.get_columns().get(column_name).get('datetime_format')
matches_format = matches_datetime_format(value, datetime_format)
if not matches_format:
raise ConstraintMetadataError(
Expand Down Expand Up @@ -647,7 +654,7 @@ def __init__(self, column_name, relation, value):
self._operator = INEQUALITY_TO_OPERATION[relation]

def _get_is_datetime(self):
is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime'
is_column_datetime = self.metadata.get_columns()[self._column_name]['sdtype'] == 'datetime'
is_value_datetime = _is_datetime_type(self._value)
is_datetime = is_column_datetime and is_value_datetime

Expand All @@ -671,7 +678,9 @@ def _fit(self, table_data):
self._dtype = table_data[self._column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._datetime_format = self.metadata.columns[self._column_name].get('datetime_format')
self._datetime_format = self.metadata.get_columns()[self._column_name].get(
'datetime_format'
)

def is_valid(self, table_data):
"""Say whether ``high`` is greater than ``low`` in each row.
Expand Down Expand Up @@ -768,8 +777,9 @@ class Positive(ScalarInequality):

@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
metadata = Metadata._convert_to_unified_metadata(metadata)
column_name = kwargs.get('column_name')
sdtype = metadata.columns.get(column_name, {}).get('sdtype')
sdtype = metadata.get_columns().get(column_name, {}).get('sdtype')
if sdtype != 'numerical':
raise ConstraintMetadataError(
f'A Positive constraint is being applied to an invalid column '
Expand Down Expand Up @@ -797,8 +807,9 @@ class Negative(ScalarInequality):

@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
metadata = Metadata._convert_to_unified_metadata(metadata)
column_name = kwargs.get('column_name')
sdtype = metadata.columns.get(column_name, {}).get('sdtype')
sdtype = metadata.get_columns().get(column_name, {}).get('sdtype')
if sdtype != 'numerical':
raise ConstraintMetadataError(
f'A Negative constraint is being applied to an invalid column '
Expand Down Expand Up @@ -835,6 +846,7 @@ class Range(Constraint):

@classmethod
def _validate_metadata_columns(cls, metadata, **kwargs):
metadata = Metadata._convert_to_unified_metadata(metadata)
high = kwargs.get('high_column_name')
low = kwargs.get('low_column_name')
middle = kwargs.get('middle_column_name')
Expand All @@ -843,12 +855,13 @@ def _validate_metadata_columns(cls, metadata, **kwargs):

@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
metadata = Metadata._convert_to_unified_metadata(metadata)
high = kwargs.get('high_column_name')
low = kwargs.get('low_column_name')
middle = kwargs.get('middle_column_name')
high_sdtype = metadata.columns.get(high, {}).get('sdtype')
low_sdtype = metadata.columns.get(low, {}).get('sdtype')
middle_sdtype = metadata.columns.get(middle, {}).get('sdtype')
high_sdtype = metadata.get_columns().get(high, {}).get('sdtype')
low_sdtype = metadata.get_columns().get(low, {}).get('sdtype')
middle_sdtype = metadata.get_columns().get(middle, {}).get('sdtype')
all_datetime = high_sdtype == low_sdtype == middle_sdtype == 'datetime'
all_numerical = high_sdtype == low_sdtype == middle_sdtype == 'numerical'
if not (all_datetime or all_numerical) and not (
Expand Down Expand Up @@ -876,9 +889,13 @@ def __init__(
self._operator = operator.lt if strict_boundaries else operator.le

def _get_is_datetime(self):
is_low_datetime = self.metadata.columns[self.low_column_name]['sdtype'] == 'datetime'
is_middle_datetime = self.metadata.columns[self.middle_column_name]['sdtype'] == 'datetime'
is_high_datetime = self.metadata.columns[self.high_column_name]['sdtype'] == 'datetime'
is_low_datetime = self.metadata.get_columns()[self.low_column_name]['sdtype'] == 'datetime'
is_middle_datetime = (
self.metadata.get_columns()[self.middle_column_name]['sdtype'] == 'datetime'
)
is_high_datetime = (
self.metadata.get_columns()[self.high_column_name]['sdtype'] == 'datetime'
)
is_datetime = is_low_datetime and is_high_datetime and is_middle_datetime

if not is_datetime and any([is_low_datetime, is_middle_datetime, is_high_datetime]):
Expand All @@ -896,13 +913,13 @@ def _fit(self, table_data):
self._dtype = table_data[self.middle_column_name].dtypes
self._is_datetime = self._get_is_datetime()
if self._is_datetime:
self._low_datetime_format = self.metadata.columns[self.low_column_name].get(
self._low_datetime_format = self.metadata.get_columns()[self.low_column_name].get(
'datetime_format'
)
self._middle_datetime_format = self.metadata.columns[self.middle_column_name].get(
self._middle_datetime_format = self.metadata.get_columns()[self.middle_column_name].get(
'datetime_format'
)
self._high_datetime_format = self.metadata.columns[self.high_column_name].get(
self._high_datetime_format = self.metadata.get_columns()[self.high_column_name].get(
'datetime_format'
)

Expand Down Expand Up @@ -1079,13 +1096,14 @@ def _validate_init_inputs(low_value, high_value):

@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
metadata = Metadata._convert_to_unified_metadata(metadata)
column_name = kwargs.get('column_name')
if column_name not in metadata.columns:
if column_name not in metadata.get_columns():
raise ConstraintMetadataError(
f'A ScalarRange constraint is being applied to invalid column names '
f'({column_name}). The columns must exist in the table.'
)
sdtype = metadata.columns.get(column_name).get('sdtype')
sdtype = metadata.get_columns().get(column_name).get('sdtype')
high_value = kwargs.get('high_value')
low_value = kwargs.get('low_value')
if sdtype == 'numerical':
Expand All @@ -1095,7 +1113,7 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs):
)

elif sdtype == 'datetime':
datetime_format = metadata.columns.get(column_name, {}).get('datetime_format')
datetime_format = metadata.get_columns().get(column_name, {}).get('datetime_format')
high_matches_format = matches_datetime_format(high_value, datetime_format)
low_matches_format = matches_datetime_format(low_value, datetime_format)
if not (low_matches_format and high_matches_format):
Expand Down Expand Up @@ -1131,7 +1149,7 @@ def _get_diff_column_name(self, table_data):
return token.join(components)

def _get_is_datetime(self):
is_column_datetime = self.metadata.columns[self._column_name]['sdtype'] == 'datetime'
is_column_datetime = self.metadata.get_columns()[self._column_name]['sdtype'] == 'datetime'
is_low_datetime = _is_datetime_type(self._low_value)
is_high_datetime = _is_datetime_type(self._high_value)
is_datetime = is_low_datetime and is_high_datetime and is_column_datetime
Expand All @@ -1152,7 +1170,9 @@ def _fit(self, table_data):
self._is_datetime = self._get_is_datetime()
self._transformed_column = self._get_diff_column_name(table_data)
if self._is_datetime:
self._datetime_format = self.metadata.columns[self._column_name].get('datetime_format')
self._datetime_format = self.metadata.get_columns()[self._column_name].get(
'datetime_format'
)
self._low_value = cast_to_datetime64(
self._low_value, datetime_format=self._datetime_format
)
Expand Down Expand Up @@ -1432,14 +1452,15 @@ def __init__(self, column_names):

@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
metadata = Metadata._convert_to_unified_metadata(metadata)
column_names = kwargs.get('column_names')
keys = set()
if isinstance(metadata.primary_key, tuple):
keys.update(metadata.primary_key)
if isinstance(metadata.get_primary_key(), tuple):
keys.update(metadata.get_primary_key())
else:
keys.add(metadata.primary_key)
keys.add(metadata.get_primary_key())

for key in metadata.alternate_keys:
for key in metadata.get_alternate_keys():
if isinstance(key, tuple):
keys.update(key)
else:
Expand Down
Loading
Loading