diff --git a/pyproject.toml b/pyproject.toml index b7ebb6f7e..6324b5b2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ 'copulas>=0.9.0', 'ctgan>=0.9.0', 'deepecho>=0.5', - 'rdt>=1.10.0', + 'rdt @ git+https://github.com/sdv-dev/RDT@main', 'sdmetrics>=0.13.0', ] diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 2d16205ee..5707a9b46 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -8,6 +8,7 @@ from datetime import datetime import pandas as pd +from rdt.transformers._validators import AddressValidator, GPSValidator from rdt.transformers.pii.anonymization import SDTYPE_ANONYMIZERS, is_faker_function from sdv._utils import ( @@ -17,7 +18,6 @@ from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.utils import read_json, validate_file_does_not_exist -from sdv.metadata.validation import validate_address_sdtypes, validate_gps_sdtypes from sdv.metadata.visualization import ( create_columns_node, create_summarized_columns_node, visualize_graph) @@ -104,8 +104,8 @@ class SingleTableMetadata: set(_REFERENCE_TO_SDTYPE.items()) - set(_SDTYPES_WITHOUT_SUBSTRINGS.items())) _COLUMN_RELATIONSHIP_TYPES = { - 'address': validate_address_sdtypes, - 'gps': validate_gps_sdtypes, + 'address': AddressValidator.validate, + 'gps': GPSValidator.validate, } METADATA_SPEC_VERSION = 'SINGLE_TABLE_V1' @@ -822,7 +822,7 @@ def _validate_column_relationship(self, relationship): except ImportError: warnings.warn( - f"The metadata contains a column relationship of type '{relationship_type}'. " + f"The metadata contains a column relationship of type '{relationship_type}' " f'which requires the {relationship_type} add-on. ' 'This relationship will be ignored. For higher quality data in this' ' relationship, please inquire about the SDV Enterprise tier.' @@ -889,6 +889,7 @@ def _validate_all_column_relationships(self, column_relationships): # Validate each individual relationship errors = [] self._valid_column_relationships = deepcopy(column_relationships) + invalid_indexes = [] for idx, relationship in enumerate(column_relationships): try: self._append_error( @@ -897,7 +898,10 @@ def _validate_all_column_relationships(self, column_relationships): relationship, ) except ImportError: - self._valid_column_relationships.pop(idx) + invalid_indexes.append(idx) + + for idx in reversed(invalid_indexes): + del self._valid_column_relationships[idx] if errors: raise InvalidMetadataError( diff --git a/sdv/metadata/validation.py b/sdv/metadata/validation.py deleted file mode 100644 index be71de839..000000000 --- a/sdv/metadata/validation.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Column relationship validation functions.""" -import rdt -from rdt.errors import TransformerInputError - -from sdv.metadata.errors import InvalidMetadataError - - -def _check_import_address_transformers(): - """Check that the address transformers can be imported.""" - error_message = ( - 'You must have SDV Enterprise with the address add-on to use the address features' - ) - if not hasattr(rdt.transformers, 'address'): - raise ImportError(error_message) - - has_randomlocationgenerator = hasattr(rdt.transformers.address, 'RandomLocationGenerator') - has_regionalanonymizer = hasattr(rdt.transformers.address, 'RegionalAnonymizer') - if not has_randomlocationgenerator or not has_regionalanonymizer: - raise ImportError(error_message) - - -def validate_address_sdtypes(columns_to_sdtypes): - """Validate sdtypes for address column relationship. - - Args: - columns_to_sdtypes (dict): - Dictionary mapping column names to sdtypes. - - Raises: - ``InvalidMetadataError`` if column sdtypes are invalid for the relationship. - """ - _check_import_address_transformers() - try: - rdt.transformers.address.RandomLocationGenerator._validate_sdtypes(columns_to_sdtypes) - except TransformerInputError as error: - raise InvalidMetadataError(str(error)) - - -def _check_import_gps_transformers(): - """Check that the gps transformers can be imported.""" - error_message = ( - 'You must have SDV Enterprise with the gps add-on to use the gps features' - ) - if not hasattr(rdt.transformers, 'gps'): - raise ImportError(error_message) - - has_randomlocationgenerator = hasattr(rdt.transformers.gps, 'RandomLocationGenerator') - has_metroareaanonymizer = hasattr(rdt.transformers.gps, 'MetroAreaAnonymizer') - has_gpsnoiser = hasattr(rdt.transformers.gps, 'GPSNoiser') - if not has_randomlocationgenerator or not has_metroareaanonymizer or not has_gpsnoiser: - raise ImportError(error_message) - - -def validate_gps_sdtypes(columns_to_sdtypes): - """Validate sdtypes for gps column relationship. - - Args: - columns_to_sdtypes (dict): - Dictionary mapping column names to sdtypes. - - Raises: - ``InvalidMetadataError`` if column sdtypes are invalid for the relationship. - """ - _check_import_gps_transformers() - try: - rdt.transformers.gps.RandomLocationGenerator._validate_sdtypes(columns_to_sdtypes) - except TransformerInputError as error: - raise InvalidMetadataError(str(error)) diff --git a/tasks.py b/tasks.py index 4e224887e..5d300165b 100644 --- a/tasks.py +++ b/tasks.py @@ -44,7 +44,7 @@ def _get_minimum_versions(dependencies, python_version): for dependency in dependencies: if '@' in dependency: name, url = dependency.split(' @ ') - min_versions[name] = f'{name} @ {url}' + min_versions[name] = f'{url}#egg={name}' continue req = Requirement(dependency) diff --git a/tests/integration/metadata/test_single_table.py b/tests/integration/metadata/test_single_table.py index 8b1477052..4aaaa68ea 100644 --- a/tests/integration/metadata/test_single_table.py +++ b/tests/integration/metadata/test_single_table.py @@ -177,9 +177,11 @@ def _validate_sdtypes(cls, columns_to_sdtypes): "\nInvalid value for 'computer_representation' 'value' for column 'col8'." "\nInvalid datetime format string '%1-%Y-%m-%d-%' for datetime column 'col9'." "\nInvalid regex format string '[A-{6}' for id column 'col10'." - "\nColumn relationships have following errors:\nColumns ['col1', 'col2'] have " - "unsupported sdtypes for column relationship type 'address'.\nUnknown column " - "relationship type 'fake_relationship'. Must be one of ['address', 'gps']." + '\nColumn relationships have following errors:\n' + "Column 'col1' has an unsupported sdtype 'id'.\n" + "Column 'col2' has an unsupported sdtype 'numerical'.\n" + 'Please provide a column that is compatible with Address data.\n' + "Unknown column relationship type 'fake_relationship'. Must be one of ['address', 'gps']." ) # Run / Assert with pytest.raises(InvalidMetadataError, match=err_msg): @@ -516,3 +518,32 @@ def test_update_columns_metadata_invalid_kwargs_combination(): 'col2': {'pii': True} } ) + + +def test_column_relationship_validation(): + """Test that column relationships are validated correctly.""" + # Setup + metadata = SingleTableMetadata.load_from_dict({ + 'columns': { + 'user_city': {'sdtype': 'city'}, + 'user_zip': {'sdtype': 'postcode'}, + 'user_value': {'sdtype': 'unknown'} + }, + 'column_relationships': [ + { + 'type': 'address', + 'column_names': ['user_city', 'user_zip', 'user_value'] + } + ] + }) + + expected_message = re.escape( + 'The following errors were found in the metadata:\n\n' + 'Column relationships have following errors:\n' + "Column 'user_value' has an unsupported sdtype 'unknown'.\n" + 'Please provide a column that is compatible with Address data.' + ) + + # Run and Assert + with pytest.raises(InvalidMetadataError, match=expected_message): + metadata.validate() diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 4ea365a09..6f063d909 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1511,7 +1511,7 @@ def test_metadata_updated_warning(method, kwargs): 'id': {'sdtype': 'id'}, 'date': {'sdtype': 'datetime'}, 'city': {'sdtype': 'city'}, - 'country': {'sdtype': 'country'} + 'country': {'sdtype': 'country_code'} }, }, 'arrival': { diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index c7ca19998..f6af7fc76 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -680,8 +680,8 @@ def test_metadata_updated_warning(method, kwargs): 'col 1': {'sdtype': 'id'}, 'col 2': {'sdtype': 'id'}, 'col 3': {'sdtype': 'categorical'}, - 'col 4': {'sdtype': 'city'}, - 'col 5': {'sdtype': 'country'}, + 'city': {'sdtype': 'city'}, + 'country': {'sdtype': 'country_code'}, } }) expected_message = re.escape( diff --git a/tests/test_tasks.py b/tests/test_tasks.py index c78986cf3..35cff9bd9 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -15,7 +15,7 @@ def test_get_minimum_versions(): "pandas>=1.2.0,<2;python_version<'3.10'", "pandas>=1.3.0,<2;python_version>='3.10'", 'humanfriendly>=8.2,<11', - 'pandas @ git+https://github.com/pandas-dev/pandas.git@master#egg=pandas' + 'pandas @ git+https://github.com/pandas-dev/pandas.git@master' ] # Run @@ -25,12 +25,12 @@ def test_get_minimum_versions(): # Assert expected_versions_39 = [ 'numpy==1.20.0', - 'pandas @ git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', + 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', 'humanfriendly==8.2', ] expected_versions_310 = [ 'numpy==1.23.3', - 'pandas @ git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', + 'git+https://github.com/pandas-dev/pandas.git@master#egg=pandas', 'humanfriendly==8.2', ] diff --git a/tests/unit/data_processing/test_data_processor.py b/tests/unit/data_processing/test_data_processor.py index 673b48b2a..3ef0af498 100644 --- a/tests/unit/data_processing/test_data_processor.py +++ b/tests/unit/data_processing/test_data_processor.py @@ -67,6 +67,7 @@ def test__detect_multi_column_transformers_address(self, transformers_mock): ] }) metadata.validate() + metadata._valid_column_relationships = metadata.column_relationships dp = DataProcessor(SingleTableMetadata()) dp.metadata = metadata dp._locales = ['en_US', 'en_GB'] @@ -99,6 +100,7 @@ def test__detect_multi_column_transformers_gps(self, transformers_mock): ] }) metadata.validate() + metadata._valid_column_relationships = metadata.column_relationships dp = DataProcessor(SingleTableMetadata()) dp.metadata = metadata dp._locales = ['en_US', 'en_GB'] @@ -140,6 +142,7 @@ def test__detect_multi_column_transformers_gps_address(self, transformers_mock): ] }) metadata.validate() + metadata._valid_column_relationships = metadata.column_relationships dp = DataProcessor(SingleTableMetadata()) dp.metadata = metadata dp._locales = ['en_US', 'en_GB'] diff --git a/tests/unit/metadata/test_validation.py b/tests/unit/metadata/test_validation.py deleted file mode 100644 index 07bb58ffe..000000000 --- a/tests/unit/metadata/test_validation.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Test Single Table Metadata.""" -import re -from unittest.mock import Mock, patch - -import pytest -from rdt.errors import TransformerInputError - -from sdv.metadata.errors import InvalidMetadataError -from sdv.metadata.validation import ( - _check_import_address_transformers, _check_import_gps_transformers, validate_address_sdtypes, - validate_gps_sdtypes) - - -def test__check_import_address_transformers_without_address_module(): - """Test ``_check_import_address_transformers`` when address module doesn't exist.""" - # Run and Assert - expected_message = ( - 'You must have SDV Enterprise with the address add-on to use the address features' - ) - with pytest.raises(ImportError, match=expected_message): - _check_import_address_transformers() - - -@patch('rdt.transformers') -def test__check_import_address_transformers_without_premium_features(mock_transformers): - """Test ``_check_import_address_transformers`` when the user doesn't have the transformers.""" - # Setup - mock_address = Mock() - del mock_address.RandomLocationGenerator - del mock_address.RegionalAnonymizer - mock_transformers.address = mock_address - - # Run and Assert - expected_message = ( - 'You must have SDV Enterprise with the address add-on to use the address features' - ) - with pytest.raises(ImportError, match=expected_message): - _check_import_address_transformers() - - -@patch('sdv.metadata.validation._check_import_address_transformers') -@patch('rdt.transformers') -def test_validate_address_sdtypes(mock_transformers, mock_check_import): - """Test address sdtype validation.""" - # Setup - columns_to_sdtypes = { - 'col_1': {'sdtype': 'id'}, - 'col_2': {'sdtype': 'numerical'}, - 'col_3': {'sdtype': 'state'} - } - mock_validate_sdtypes = mock_transformers.address.RandomLocationGenerator._validate_sdtypes - - # Run - validate_address_sdtypes(columns_to_sdtypes) - - # Asserts - mock_check_import.assert_called_once() - mock_validate_sdtypes.assert_called_once_with(columns_to_sdtypes) - - -@patch('sdv.metadata.validation._check_import_address_transformers') -@patch('rdt.transformers') -def test_validate_address_sdtypes_error(mock_transformers, mock_check_import): - """Test address sdtype validation.""" - # Setup - columns_to_sdtypes = { - 'col_1': {'sdtype': 'id'}, - 'col_2': {'sdtype': 'numerical'}, - 'col_3': {'sdtype': 'state'} - } - mock_validate_sdtypes = mock_transformers.address.RandomLocationGenerator._validate_sdtypes - mock_validate_sdtypes.side_effect = TransformerInputError('Error') - - # Run and Assert - expected_message = re.escape('Error') - with pytest.raises(InvalidMetadataError, match=expected_message): - validate_address_sdtypes(columns_to_sdtypes) - - -def test__check_import_gps_transformers_without_gps_module(): - """Test ``_check_import_gps_transformers`` when gps module doesn't exist.""" - # Run and Assert - expected_message = ( - 'You must have SDV Enterprise with the gps add-on to use the gps features' - ) - with pytest.raises(ImportError, match=expected_message): - _check_import_gps_transformers() - - -@patch('rdt.transformers') -def test__check_import_gps_transformers_without_premium_features(mock_transformers): - """Test ``_check_import_gps_transformers`` when the user doesn't have the transformers.""" - # Setup - mock_gps = Mock() - del mock_gps.RandomLocationGenerator - del mock_gps.MetroAreaAnonymizer - del mock_gps.GPSNoiser - mock_transformers.gps = mock_gps - - # Run and Assert - expected_message = ( - 'You must have SDV Enterprise with the gps add-on to use the gps features' - ) - with pytest.raises(ImportError, match=expected_message): - _check_import_gps_transformers() - - -@patch('sdv.metadata.validation._check_import_gps_transformers') -@patch('rdt.transformers') -def test_validate_gps_sdtypes(mock_transformers, mock_check_import): - """Test gps sdtype validation.""" - # Setup - columns_to_sdtypes = { - 'col_1': {'sdtype': 'id'}, - 'col_2': {'sdtype': 'numerical'}, - 'col_3': {'sdtype': 'city'} - } - mock_validate_sdtypes = mock_transformers.gps.RandomLocationGenerator._validate_sdtypes - - # Run - validate_gps_sdtypes(columns_to_sdtypes) - - # Asserts - mock_check_import.assert_called_once() - mock_validate_sdtypes.assert_called_once_with(columns_to_sdtypes) - - -@patch('sdv.metadata.validation._check_import_gps_transformers') -@patch('rdt.transformers') -def test_validate_gps_sdtypes_error(mock_transformers, mock_check_import): - """Test gps sdtype validation.""" - # Setup - columns_to_sdtypes = { - 'col_1': {'sdtype': 'id'}, - 'col_2': {'sdtype': 'numerical'}, - 'col_3': {'sdtype': 'city'} - } - mock_validate_sdtypes = mock_transformers.gps.RandomLocationGenerator._validate_sdtypes - mock_validate_sdtypes.side_effect = TransformerInputError('Error') - - # Run - expected_message = re.escape('Error') - with pytest.raises(InvalidMetadataError, match=expected_message): - validate_gps_sdtypes(columns_to_sdtypes) - - # Asserts - mock_check_import.assert_called_once() diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index fd3384562..b70729969 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -132,7 +132,7 @@ def test__init__column_relationship_warning(self): metadata.add_column_relationship('nesreca', 'gps', ['lat', 'lon']) expected_warning = ( - "The metadata contains a column relationship of type 'gps'. " + "The metadata contains a column relationship of type 'gps' " 'which requires the gps add-on. This relationship will be ignored. For higher' ' quality data in this relationship, please inquire about the SDV Enterprise tier.' )