diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index e6f106dae..d8a9b463c 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -51,25 +51,6 @@ def load_from_dict(cls, metadata_dict, single_table_name=None): instance._set_metadata_dict(metadata_dict, single_table_name) return instance - @classmethod - def load_from_single_table_metadata(cls, single_metadata_table, table_name=None): - """Return a unified Metadata object from a legacy SingleTableMetadata object. - - Args: - single_metadata_table (SingleTableMetadata): - ``SingleTableMetadata`` object to be converted to a ``Metadata`` object. - table_name (string): - The name of the table that will be stored in the ``Metadata object. - - Returns: - Instance of ``Metadata``. - """ - if not isinstance(single_metadata_table, SingleTableMetadata): - raise InvalidMetadataError('Cannot convert given legacy metadata') - instance = cls() - instance._set_metadata_dict(single_metadata_table.to_dict()) - return instance - def _set_metadata_dict(self, metadata, single_table_name=None): """Set a ``metadata`` dictionary to the current instance. @@ -89,18 +70,12 @@ def _set_metadata_dict(self, metadata, single_table_name=None): else: if single_table_name is None: single_table_name = 'default_table_name' - self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata) def _convert_to_single_table(self): - is_multi_table = len(self.tables) > 1 - if is_multi_table: + if len(self.tables) > 1: raise InvalidMetadataError( 'Metadata contains more than one table, use a MultiTableSynthesizer instead.' ) - if len(self.tables) == 0: - return SingleTableMetadata() - - single_table_metadata = next(iter(self.tables.values())) - return single_table_metadata + return next(iter(self.tables.values()), SingleTableMetadata()) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 07d527a75..9e094f43b 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -26,6 +26,7 @@ SynthesizerInputError, ) from sdv.logging import disable_single_table_logger, get_sdv_logger +from sdv.metadata.metadata import Metadata from sdv.single_table.copulas import GaussianCopulaSynthesizer SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer') @@ -71,8 +72,9 @@ def _initialize_models(self): with disable_single_table_logger(): for table_name, table_metadata in self.metadata.tables.items(): synthesizer_parameters = self._table_parameters.get(table_name, {}) + metadata = Metadata.load_from_dict(table_metadata.to_dict()) self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, locales=self.locales, **synthesizer_parameters + metadata=metadata, locales=self.locales, **synthesizer_parameters ) self._table_synthesizers[table_name]._data_processor.table_name = table_name diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 951b07627..ddceca005 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -44,6 +44,11 @@ COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 +DEPRECATION_MSG = ( + "The 'SingleTableMetadata' is deprecated. Please use the new " + "'Metadata' class for synthesizers." +) + class BaseSynthesizer: """Base class for all ``Synthesizers``. @@ -100,15 +105,15 @@ def _check_metadata_updated(self): def __init__( self, metadata, enforce_min_max_values=True, enforce_rounding=True, locales=['en_US'] ): - single_metadata = metadata - if isinstance(single_metadata, Metadata): - single_metadata = single_metadata._convert_to_single_table() + self.metadata = metadata + if isinstance(metadata, Metadata): + self.metadata = metadata._convert_to_single_table() self.real_metadata = metadata - elif isinstance(single_metadata, SingleTableMetadata): - self.real_metadata = Metadata.load_from_single_table_metadata(metadata) + elif isinstance(metadata, SingleTableMetadata): + self.real_metadata = Metadata.load_from_dict(metadata.to_dict()) + warnings.warn(DEPRECATION_MSG, FutureWarning) self._validate_inputs(enforce_min_max_values, enforce_rounding) - self.metadata = single_metadata self.metadata.validate() self._check_metadata_updated() self.enforce_min_max_values = enforce_min_max_values diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index ef7f0b464..ec9b081be 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1245,6 +1245,11 @@ def test_metadata_updated_no_warning(self, tmp_path): # Run 1 with warnings.catch_warnings(record=True) as captured_warnings: + warnings.filterwarnings( + 'ignore', + message=".*The 'SingleTableMetadata' is deprecated.*", + category=DeprecationWarning, + ) warnings.simplefilter('always') instance = HMASynthesizer(metadata) instance.fit(data) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index dc9c328d5..92470976e 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -13,6 +13,7 @@ from sdv.datasets.demo import download_demo from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.sampling import Condition from sdv.single_table import ( CopulaGANSynthesizer, @@ -587,7 +588,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path): initialization, but is saved to a file before fitting. """ # Setup - metadata_from_dict = SingleTableMetadata().load_from_dict({ + metadata_from_dict = Metadata().load_from_dict({ 'columns': { 'col 1': {'sdtype': 'numerical'}, 'col 2': {'sdtype': 'numerical'}, @@ -610,8 +611,8 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path): assert len(captured_warnings) == 0 # Run 2 - metadata_detect = SingleTableMetadata() - metadata_detect.detect_from_dataframe(data) + metadata_detect = Metadata() + metadata_detect.detect_from_dataframes({'mock_table': data}) file_name = tmp_path / 'singletable.json' metadata_detect.save_to_json(file_name) with warnings.catch_warnings(record=True) as captured_warnings: @@ -624,7 +625,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path): # Run 3 instance = BaseSingleTableSynthesizer(metadata_detect) - metadata_detect.update_column('col 1', sdtype='categorical') + metadata_detect.update_column('mock_table', 'col 1', sdtype='categorical') file_name = tmp_path / 'singletable_2.json' metadata_detect.save_to_json(file_name) with warnings.catch_warnings(record=True) as captured_warnings: @@ -650,18 +651,26 @@ def test_metadata_updated_warning_detect(mock__fit): }) metadata = SingleTableMetadata() metadata.detect_from_dataframe(data) - expected_message = re.escape( + expected_user_message = ( "We strongly recommend saving the metadata using 'save_to_json' for replicability" ' in future SDV versions.' ) + expected_deprecation_message = ( + "The 'SingleTableMetadata' is deprecated. " + "Please use the new 'Metadata' class for synthesizers." + ) # Run - with pytest.warns(UserWarning, match=expected_message) as record: + with warnings.catch_warnings(record=True) as record: instance = BaseSingleTableSynthesizer(metadata) instance.fit(data) # Assert - assert len(record) == 1 + assert len(record) == 2 + assert record[0].category is FutureWarning + assert str(record[0].message) == expected_deprecation_message + assert record[1].category is UserWarning + assert str(record[1].message) == expected_user_message parametrization = [ diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 3c493606f..229a0162d 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -18,6 +18,7 @@ SynthesizerInputError, VersionError, ) +from sdv.metadata.metadata import Metadata from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata from sdv.multi_table.base import BaseMultiTableSynthesizer @@ -53,14 +54,18 @@ def test__initialize_models(self): } instance._synthesizer.assert_has_calls([ call( - metadata=instance.metadata.tables['nesreca'], + metadata=ANY, default_distribution='gamma', locales=locales, ), - call(metadata=instance.metadata.tables['oseba'], locales=locales), - call(metadata=instance.metadata.tables['upravna_enota'], locales=locales), + call(metadata=ANY, locales=locales), + call(metadata=ANY, locales=locales), ]) + for call_args in instance._synthesizer.call_args_list: + metadata_arg = call_args[1].get('metadata', None) + assert isinstance(metadata_arg, Metadata) + def test__get_pbar_args(self): """Test that ``_get_pbar_args`` returns a dictionary with disable opposite to verbose.""" # Setup diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 2e98a5816..cfd4f187d 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -238,7 +238,7 @@ def test_get_metadata(self): result = instance.get_metadata() # Assert - assert result.to_dict() == Metadata.load_from_single_table_metadata(metadata).to_dict() + assert result._convert_to_single_table().to_dict() == metadata.to_dict() assert isinstance(result, Metadata) def test_validate_context_columns_unique_per_sequence_key(self): diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index ffee92b31..4707877b5 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -120,6 +120,22 @@ def test___init__( 'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', }) + def test__init__with_old_metadata_future_warning(self): + """Test that future warning is thrown when using `SingleTableMetadata`""" + # Setup + metadata = SingleTableMetadata.load_from_dict({ + 'columns': { + 'a': {'sdtype': 'categorical'}, + } + }) + warn_msg = re.escape( + "The 'SingleTableMetadata' is deprecated. Please use the new " + "'Metadata' class for synthesizers." + ) + # Run and Assert + with pytest.warns(FutureWarning, match=warn_msg): + BaseSingleTableSynthesizer(metadata) + def test___init__with_unified_metadata(self): """Test initialization with unified metadata.""" # Setup