Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metadata isn't validating sdtypes in a column relationship (public SDV only) #1889

Merged
merged 6 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies = [
'copulas>=0.9.0',
'ctgan>=0.9.0',
'deepecho>=0.5',
'rdt>=1.10.0',
'rdt @ git+https://github.com/sdv-dev/RDT@main',
'sdmetrics>=0.13.0',
]

Expand Down
14 changes: 9 additions & 5 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import datetime

import pandas as pd
from rdt.transformers._validators import AddressValidator, GPSValidator
from rdt.transformers.pii.anonymization import SDTYPE_ANONYMIZERS, is_faker_function

from sdv._utils import (
Expand All @@ -17,7 +18,6 @@
from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata_upgrader import convert_metadata
from sdv.metadata.utils import read_json, validate_file_does_not_exist
from sdv.metadata.validation import validate_address_sdtypes, validate_gps_sdtypes
from sdv.metadata.visualization import (
create_columns_node, create_summarized_columns_node, visualize_graph)

Expand Down Expand Up @@ -104,8 +104,8 @@ class SingleTableMetadata:
set(_REFERENCE_TO_SDTYPE.items()) - set(_SDTYPES_WITHOUT_SUBSTRINGS.items()))

_COLUMN_RELATIONSHIP_TYPES = {
'address': validate_address_sdtypes,
'gps': validate_gps_sdtypes,
'address': AddressValidator.validate,
'gps': GPSValidator.validate,
}

METADATA_SPEC_VERSION = 'SINGLE_TABLE_V1'
Expand Down Expand Up @@ -822,7 +822,7 @@ def _validate_column_relationship(self, relationship):

except ImportError:
warnings.warn(
f"The metadata contains a column relationship of type '{relationship_type}'. "
f"The metadata contains a column relationship of type '{relationship_type}' "
f'which requires the {relationship_type} add-on. '
'This relationship will be ignored. For higher quality data in this'
' relationship, please inquire about the SDV Enterprise tier.'
Expand Down Expand Up @@ -889,6 +889,7 @@ def _validate_all_column_relationships(self, column_relationships):
# Validate each individual relationship
errors = []
self._valid_column_relationships = deepcopy(column_relationships)
invalid_indexes = []
for idx, relationship in enumerate(column_relationships):
try:
self._append_error(
Expand All @@ -897,7 +898,10 @@ def _validate_all_column_relationships(self, column_relationships):
relationship,
)
except ImportError:
self._valid_column_relationships.pop(idx)
invalid_indexes.append(idx)

for idx in reversed(invalid_indexes):
del self._valid_column_relationships[idx]

if errors:
raise InvalidMetadataError(
Expand Down
68 changes: 0 additions & 68 deletions sdv/metadata/validation.py

This file was deleted.

2 changes: 1 addition & 1 deletion tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _get_minimum_versions(dependencies, python_version):
for dependency in dependencies:
if '@' in dependency:
name, url = dependency.split(' @ ')
min_versions[name] = f'{name} @ {url}'
min_versions[name] = f'{url}#egg={name}'
continue

req = Requirement(dependency)
Expand Down
37 changes: 34 additions & 3 deletions tests/integration/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,11 @@ def _validate_sdtypes(cls, columns_to_sdtypes):
"\nInvalid value for 'computer_representation' 'value' for column 'col8'."
"\nInvalid datetime format string '%1-%Y-%m-%d-%' for datetime column 'col9'."
"\nInvalid regex format string '[A-{6}' for id column 'col10'."
"\nColumn relationships have following errors:\nColumns ['col1', 'col2'] have "
"unsupported sdtypes for column relationship type 'address'.\nUnknown column "
"relationship type 'fake_relationship'. Must be one of ['address', 'gps']."
'\nColumn relationships have following errors:\n'
"Column 'col1' has an unsupported sdtype 'id'.\n"
"Column 'col2' has an unsupported sdtype 'numerical'.\n"
'Please provide a column that is compatible with Address data.\n'
"Unknown column relationship type 'fake_relationship'. Must be one of ['address', 'gps']."
)
# Run / Assert
with pytest.raises(InvalidMetadataError, match=err_msg):
Expand Down Expand Up @@ -516,3 +518,32 @@ def test_update_columns_metadata_invalid_kwargs_combination():
'col2': {'pii': True}
}
)


def test_column_relationship_validation():
    """Validate an 'address' relationship that includes an 'unknown' sdtype column.

    The metadata declares three columns and groups all of them in a single
    'address' column relationship. ``validate`` is expected to raise an
    ``InvalidMetadataError`` whose message reports the ``user_value`` column
    (sdtype 'unknown') as unsupported.
    """
    # Setup
    columns = {
        'user_city': {'sdtype': 'city'},
        'user_zip': {'sdtype': 'postcode'},
        'user_value': {'sdtype': 'unknown'}
    }
    address_relationship = {
        'type': 'address',
        'column_names': ['user_city', 'user_zip', 'user_value']
    }
    metadata = SingleTableMetadata.load_from_dict({
        'columns': columns,
        'column_relationships': [address_relationship]
    })

    # The message is escaped because it contains regex metacharacters
    # and is used as the ``match`` pattern below.
    raw_message = (
        'The following errors were found in the metadata:\n\n'
        'Column relationships have following errors:\n'
        "Column 'user_value' has an unsupported sdtype 'unknown'.\n"
        'Please provide a column that is compatible with Address data.'
    )
    expected_message = re.escape(raw_message)

    # Run and Assert
    with pytest.raises(InvalidMetadataError, match=expected_message):
        metadata.validate()
2 changes: 1 addition & 1 deletion tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1511,7 +1511,7 @@ def test_metadata_updated_warning(method, kwargs):
'id': {'sdtype': 'id'},
'date': {'sdtype': 'datetime'},
'city': {'sdtype': 'city'},
'country': {'sdtype': 'country'}
'country': {'sdtype': 'country_code'}
},
},
'arrival': {
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,8 @@ def test_metadata_updated_warning(method, kwargs):
'col 1': {'sdtype': 'id'},
'col 2': {'sdtype': 'id'},
'col 3': {'sdtype': 'categorical'},
'col 4': {'sdtype': 'city'},
'col 5': {'sdtype': 'country'},
'city': {'sdtype': 'city'},
'country': {'sdtype': 'country_code'},
}
})
expected_message = re.escape(
Expand Down
6 changes: 3 additions & 3 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_get_minimum_versions():
"pandas>=1.2.0,<2;python_version<'3.10'",
"pandas>=1.3.0,<2;python_version>='3.10'",
'humanfriendly>=8.2,<11',
'pandas @ git+https://github.com/pandas-dev/pandas.git@master#egg=pandas'
'pandas @ git+https://github.com/pandas-dev/pandas.git@master'
]

# Run
Expand All @@ -25,12 +25,12 @@ def test_get_minimum_versions():
# Assert
expected_versions_39 = [
'numpy==1.20.0',
'pandas @ git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
R-Palazzo marked this conversation as resolved.
Show resolved Hide resolved
'humanfriendly==8.2',
]
expected_versions_310 = [
'numpy==1.23.3',
'pandas @ git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas',
'humanfriendly==8.2',
]

Expand Down
3 changes: 3 additions & 0 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test__detect_multi_column_transformers_address(self, transformers_mock):
]
})
metadata.validate()
metadata._valid_column_relationships = metadata.column_relationships
dp = DataProcessor(SingleTableMetadata())
dp.metadata = metadata
dp._locales = ['en_US', 'en_GB']
Expand Down Expand Up @@ -99,6 +100,7 @@ def test__detect_multi_column_transformers_gps(self, transformers_mock):
]
})
metadata.validate()
metadata._valid_column_relationships = metadata.column_relationships
dp = DataProcessor(SingleTableMetadata())
dp.metadata = metadata
dp._locales = ['en_US', 'en_GB']
Expand Down Expand Up @@ -140,6 +142,7 @@ def test__detect_multi_column_transformers_gps_address(self, transformers_mock):
]
})
metadata.validate()
metadata._valid_column_relationships = metadata.column_relationships
dp = DataProcessor(SingleTableMetadata())
dp.metadata = metadata
dp._locales = ['en_US', 'en_GB']
Expand Down
Loading
Loading