Skip to content

Commit

Permalink
Add warning when unable to turn off rounding scheme for a column (#2279)
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Nov 12, 2024
1 parent 0fe0123 commit 4d56fe7
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 5 deletions.
25 changes: 22 additions & 3 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def _validate_transformers(self, column_name_to_transformer):
f"Transformer for column '{column}' has already been fit on data."
)

def _warn_for_update_transformers(self, column_name_to_transformer):
"""Raise warnings for update_transformers.
def _warn_quality_and_performance(self, column_name_to_transformer):
"""Raise warning if the quality/performance may be impacted.
Args:
column_name_to_transformer (dict):
Expand All @@ -259,6 +259,24 @@ def _warn_for_update_transformers(self, column_name_to_transformer):
'might impact the quality of your synthetic data.'
)

def _warn_unable_to_enforce_rounding(self, column_name_to_transformer):
    """Raise a warning if a transformer disables rounding while it is being enforced.

    Nothing is checked unless ``self.enforce_rounding`` is truthy. Otherwise, every
    transformer that exposes a ``learn_rounding_scheme`` attribute set to a falsy
    value is collected, and a single warning listing all of those columns is raised.
    Transformers without that attribute are ignored.

    Args:
        column_name_to_transformer (dict):
            Dict mapping column names to transformers to be used for that column.
    """
    if not self.enforce_rounding:
        return

    # A missing attribute defaults to ``True`` so that only transformers which
    # explicitly switch the rounding scheme off are reported.
    invalid_columns = [
        column
        for column, transformer in column_name_to_transformer.items()
        if not getattr(transformer, 'learn_rounding_scheme', True)
    ]
    if not invalid_columns:
        return

    # The call is kept to a single positional argument: callers' tests assert on
    # the exact ``warnings.warn`` invocation.
    warnings.warn(
        f'Unable to turn off rounding scheme for column(s) {invalid_columns}, '
        'because the overall synthesizer is enforcing rounding. We '
        "recommend setting the synthesizer's 'enforce_rounding' "
        'parameter to False.'
    )

def update_transformers(self, column_name_to_transformer):
"""Update any of the transformers assigned to each of the column names.
Expand All @@ -267,7 +285,8 @@ def update_transformers(self, column_name_to_transformer):
Dict mapping column names to transformers to be used for that column.
"""
self._validate_transformers(column_name_to_transformer)
self._warn_for_update_transformers(column_name_to_transformer)
self._warn_quality_and_performance(column_name_to_transformer)
self._warn_unable_to_enforce_rounding(column_name_to_transformer)
self._data_processor.update_transformers(column_name_to_transformer)
if self._fitted:
msg = 'For this change to take effect, please refit the synthesizer using `fit`.'
Expand Down
4 changes: 2 additions & 2 deletions sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def _fit_model(self, processed_data):
warnings.filterwarnings('ignore', module='scipy')
self._model.fit(processed_data)

def _warn_for_update_transformers(self, column_name_to_transformer):
"""Raise warnings for update_transformers.
def _warn_quality_and_performance(self, column_name_to_transformer):
"""Raise warning if the quality/performance may be impacted.
Args:
column_name_to_transformer (dict):
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,22 @@ def test_fit_int_primary_key_regex_includes_zero(synthesizer_class, regex):
)
with pytest.raises(SynthesizerInputError, match=message):
instance.fit(data)


@patch('sdv.single_table.base.warnings')
def test_update_transformers(warning_mock):
    """Test that ``update_transformers`` emits the rounding warning for ``amenities_fee``."""
    # Setup
    data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.auto_assign_transformers(data)
    formatter_without_rounding = FloatFormatter(learn_rounding_scheme=False)
    expected_message = (
        "Unable to turn off rounding scheme for column(s) ['amenities_fee'], because the overall "
        "synthesizer is enforcing rounding. We recommend setting the synthesizer's "
        "'enforce_rounding' parameter to False."
    )

    # Run
    synthesizer.update_transformers({'amenities_fee': formatter_without_rounding})

    # Assert
    # The ``warnings`` module used by ``sdv.single_table.base`` is patched, so the
    # check is on the mock's single recorded ``warn`` call.
    warning_mock.warn.assert_called_once_with(expected_message)
25 changes: 25 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,31 @@ def test_update_transformers(self):
assert isinstance(field_transformers['col1'], GaussianNormalizer)
assert isinstance(field_transformers['col2'], GaussianNormalizer)

def test_update_transformers_warns_rounding(self):
    """Test that columns switching off the rounding scheme are listed in a warning."""
    # Setup
    learn_rounding_by_column = (('col1', False), ('col2', True), ('col3', False))
    column_name_to_transformer = {
        column: GaussianNormalizer(learn_rounding_scheme=learn_rounding)
        for column, learn_rounding in learn_rounding_by_column
    }
    instance = BaseSingleTableSynthesizer(Metadata())
    # Stub out the collaborators so only the rounding warning path is exercised.
    instance._validate_transformers = MagicMock()
    instance._warn_quality_and_performance = MagicMock()
    instance._data_processor = MagicMock()
    instance.enforce_rounding = True
    instance._fitted = False
    expected_warning = re.escape(
        "Unable to turn off rounding scheme for column(s) ['col1', 'col3'], "
        'because the overall synthesizer is enforcing rounding. We recommend '
        "setting the synthesizer's 'enforce_rounding' parameter to False."
    )

    # Run and Assert
    with pytest.warns(UserWarning, match=expected_warning):
        instance.update_transformers(column_name_to_transformer)

@patch('sdv.single_table.base.DataProcessor')
def test__set_random_state(self, mock_data_processor):
"""Test that ``_model.set_random_state`` is being called with the input value.
Expand Down

0 comments on commit 4d56fe7

Please sign in to comment.