Skip to content

Commit

Permalink
Address feedback
Browse files: browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Jul 22, 2024
1 parent c8caca8 commit 6010629
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 45 deletions.
29 changes: 2 additions & 27 deletions sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,6 @@ def load_from_dict(cls, metadata_dict, single_table_name=None):
instance._set_metadata_dict(metadata_dict, single_table_name)
return instance

@classmethod
def load_from_single_table_metadata(cls, single_metadata_table, table_name=None):
    """Return a unified ``Metadata`` object from a legacy ``SingleTableMetadata`` object.

    Args:
        single_metadata_table (SingleTableMetadata):
            ``SingleTableMetadata`` object to be converted to a ``Metadata`` object.
        table_name (string):
            The name of the table that will be stored in the ``Metadata`` object.
            When ``None``, the name is chosen by ``_set_metadata_dict``.

    Returns:
        Instance of ``Metadata``.

    Raises:
        InvalidMetadataError:
            If ``single_metadata_table`` is not a ``SingleTableMetadata`` instance.
    """
    if not isinstance(single_metadata_table, SingleTableMetadata):
        raise InvalidMetadataError('Cannot convert given legacy metadata')

    instance = cls()
    # Forward ``table_name`` so the caller's choice is honoured; it used to be
    # accepted by the signature but dropped before reaching ``_set_metadata_dict``.
    instance._set_metadata_dict(single_metadata_table.to_dict(), table_name)
    return instance

def _set_metadata_dict(self, metadata, single_table_name=None):
"""Set a ``metadata`` dictionary to the current instance.
Expand All @@ -89,18 +70,12 @@ def _set_metadata_dict(self, metadata, single_table_name=None):
else:
if single_table_name is None:
single_table_name = 'default_table_name'

self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata)

def _convert_to_single_table(self):
    """Return the single table stored in ``self.tables``.

    Returns:
        The only value held in ``self.tables``, or a new empty
        ``SingleTableMetadata`` when ``self.tables`` is empty.

    Raises:
        InvalidMetadataError:
            If ``self.tables`` holds more than one table.
    """
    if len(self.tables) > 1:
        raise InvalidMetadataError(
            'Metadata contains more than one table, use a MultiTableSynthesizer instead.'
        )

    # Build the empty fallback only when it is needed, rather than as an
    # eagerly evaluated default on every call.
    if not self.tables:
        return SingleTableMetadata()

    return next(iter(self.tables.values()))
4 changes: 3 additions & 1 deletion sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SynthesizerInputError,
)
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.metadata.metadata import Metadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer

SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
Expand Down Expand Up @@ -71,8 +72,9 @@ def _initialize_models(self):
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
metadata = Metadata.load_from_dict(table_metadata.to_dict())
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata, locales=self.locales, **synthesizer_parameters
metadata=metadata, locales=self.locales, **synthesizer_parameters
)
self._table_synthesizers[table_name]._data_processor.table_name = table_name

Expand Down
17 changes: 11 additions & 6 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
COND_IDX = str(uuid.uuid4())
FIXED_RNG_SEED = 73251

# Warning text shown to callers that still pass the legacy
# ``SingleTableMetadata`` class instead of the unified ``Metadata`` class.
DEPRECATION_MSG = (
    "The 'SingleTableMetadata' is deprecated. "
    "Please use the new 'Metadata' class for synthesizers."
)


class BaseSynthesizer:
"""Base class for all ``Synthesizers``.
Expand Down Expand Up @@ -100,15 +105,15 @@ def _check_metadata_updated(self):
def __init__(
self, metadata, enforce_min_max_values=True, enforce_rounding=True, locales=['en_US']
):
single_metadata = metadata
if isinstance(single_metadata, Metadata):
single_metadata = single_metadata._convert_to_single_table()
self.metadata = metadata
if isinstance(metadata, Metadata):
self.metadata = metadata._convert_to_single_table()
self.real_metadata = metadata
elif isinstance(single_metadata, SingleTableMetadata):
self.real_metadata = Metadata.load_from_single_table_metadata(metadata)
elif isinstance(metadata, SingleTableMetadata):
self.real_metadata = Metadata.load_from_dict(metadata.to_dict())
warnings.warn(DEPRECATION_MSG, FutureWarning)

self._validate_inputs(enforce_min_max_values, enforce_rounding)
self.metadata = single_metadata
self.metadata.validate()
self._check_metadata_updated()
self.enforce_min_max_values = enforce_min_max_values
Expand Down
5 changes: 5 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,11 @@ def test_metadata_updated_no_warning(self, tmp_path):

# Run 1
with warnings.catch_warnings(record=True) as captured_warnings:
warnings.filterwarnings(
'ignore',
message=".*The 'SingleTableMetadata' is deprecated.*",
category=DeprecationWarning,
)
warnings.simplefilter('always')
instance = HMASynthesizer(metadata)
instance.fit(data)
Expand Down
23 changes: 16 additions & 7 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sdv.datasets.demo import download_demo
from sdv.errors import SamplingError, SynthesizerInputError, VersionError
from sdv.metadata import SingleTableMetadata
from sdv.metadata.metadata import Metadata
from sdv.sampling import Condition
from sdv.single_table import (
CopulaGANSynthesizer,
Expand Down Expand Up @@ -587,7 +588,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path):
initialization, but is saved to a file before fitting.
"""
# Setup
metadata_from_dict = SingleTableMetadata().load_from_dict({
metadata_from_dict = Metadata().load_from_dict({
'columns': {
'col 1': {'sdtype': 'numerical'},
'col 2': {'sdtype': 'numerical'},
Expand All @@ -610,8 +611,8 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path):
assert len(captured_warnings) == 0

# Run 2
metadata_detect = SingleTableMetadata()
metadata_detect.detect_from_dataframe(data)
metadata_detect = Metadata()
metadata_detect.detect_from_dataframes({'mock_table': data})
file_name = tmp_path / 'singletable.json'
metadata_detect.save_to_json(file_name)
with warnings.catch_warnings(record=True) as captured_warnings:
Expand All @@ -624,7 +625,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path):

# Run 3
instance = BaseSingleTableSynthesizer(metadata_detect)
metadata_detect.update_column('col 1', sdtype='categorical')
metadata_detect.update_column('mock_table', 'col 1', sdtype='categorical')
file_name = tmp_path / 'singletable_2.json'
metadata_detect.save_to_json(file_name)
with warnings.catch_warnings(record=True) as captured_warnings:
Expand All @@ -650,18 +651,26 @@ def test_metadata_updated_warning_detect(mock__fit):
})
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
expected_message = re.escape(
expected_user_message = (
"We strongly recommend saving the metadata using 'save_to_json' for replicability"
' in future SDV versions.'
)
expected_deprecation_message = (
"The 'SingleTableMetadata' is deprecated. "
"Please use the new 'Metadata' class for synthesizers."
)

# Run
with pytest.warns(UserWarning, match=expected_message) as record:
with warnings.catch_warnings(record=True) as record:
instance = BaseSingleTableSynthesizer(metadata)
instance.fit(data)

# Assert
assert len(record) == 1
assert len(record) == 2
assert record[0].category is FutureWarning
assert str(record[0].message) == expected_deprecation_message
assert record[1].category is UserWarning
assert str(record[1].message) == expected_user_message


parametrization = [
Expand Down
11 changes: 8 additions & 3 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
SynthesizerInputError,
VersionError,
)
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.metadata.single_table import SingleTableMetadata
from sdv.multi_table.base import BaseMultiTableSynthesizer
Expand Down Expand Up @@ -53,14 +54,18 @@ def test__initialize_models(self):
}
instance._synthesizer.assert_has_calls([
call(
metadata=instance.metadata.tables['nesreca'],
metadata=ANY,
default_distribution='gamma',
locales=locales,
),
call(metadata=instance.metadata.tables['oseba'], locales=locales),
call(metadata=instance.metadata.tables['upravna_enota'], locales=locales),
call(metadata=ANY, locales=locales),
call(metadata=ANY, locales=locales),
])

for call_args in instance._synthesizer.call_args_list:
metadata_arg = call_args[1].get('metadata', None)
assert isinstance(metadata_arg, Metadata)

def test__get_pbar_args(self):
"""Test that ``_get_pbar_args`` returns a dictionary with disable opposite to verbose."""
# Setup
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_get_metadata(self):
result = instance.get_metadata()

# Assert
assert result.to_dict() == Metadata.load_from_single_table_metadata(metadata).to_dict()
assert result._convert_to_single_table().to_dict() == metadata.to_dict()
assert isinstance(result, Metadata)

def test_validate_context_columns_unique_per_sequence_key(self):
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,22 @@ def test___init__(
'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
})

def test__init__with_old_metadata_future_warning(self):
    """Test that a ``FutureWarning`` is raised when given a ``SingleTableMetadata``."""
    # Setup
    columns = {'a': {'sdtype': 'categorical'}}
    legacy_metadata = SingleTableMetadata.load_from_dict({'columns': columns})
    expected_pattern = re.escape(
        "The 'SingleTableMetadata' is deprecated. Please use the new "
        "'Metadata' class for synthesizers."
    )

    # Run and Assert
    with pytest.warns(FutureWarning, match=expected_pattern):
        BaseSingleTableSynthesizer(legacy_metadata)

def test___init__with_unified_metadata(self):
"""Test initialization with unified metadata."""
# Setup
Expand Down

0 comments on commit 6010629

Please sign in to comment.