From 18cd2e53daebf0e4db158470a3b8e0708d71d044 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Mon, 18 Nov 2024 20:58:16 +0100 Subject: [PATCH] Raise `UserWarnings` for Unused Numerical Distributions when using `GaussianCopulaSynthesizer` (#2301) --- sdv/single_table/copulagan.py | 7 ++---- sdv/single_table/copulas.py | 6 ++--- sdv/single_table/utils.py | 12 +++++----- .../integration/single_table/test_copulas.py | 22 +++++++++++++++++++ tests/unit/single_table/test_copulagan.py | 11 +++++----- tests/unit/single_table/test_copulas.py | 15 +++++++------ tests/unit/single_table/test_utils.py | 16 ++++++++++++++ 7 files changed, 62 insertions(+), 27 deletions(-) diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index 1713cef45..d7ce656d2 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -8,8 +8,8 @@ from sdv.single_table.copulas import GaussianCopulaSynthesizer from sdv.single_table.ctgan import CTGANSynthesizer from sdv.single_table.utils import ( - log_numerical_distributions_error, validate_numerical_distributions, + warn_missing_numerical_distributions, ) LOGGER = logging.getLogger(__name__) @@ -204,10 +204,7 @@ def _fit(self, processed_data): processed_data (pandas.DataFrame): Data to be learned. """ - log_numerical_distributions_error( - self.numerical_distributions, processed_data.columns, LOGGER - ) - + warn_missing_numerical_distributions(self.numerical_distributions, processed_data.columns) gaussian_normalizer_config = self._create_gaussian_normalizer_config(processed_data) self._gaussian_normalizer_hyper_transformer = rdt.HyperTransformer() self._gaussian_normalizer_hyper_transformer.set_config(gaussian_normalizer_config) diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index d629a88fe..5dfa4fedc 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -17,9 +17,9 @@ from sdv.single_table.base import BaseSingleTableSynthesizer from sdv.single_table.utils import ( flatten_dict, - log_numerical_distributions_error, unflatten_dict, validate_numerical_distributions, + warn_missing_numerical_distributions, ) LOGGER = logging.getLogger(__name__) @@ -132,9 +132,7 @@ def _fit(self, processed_data): processed_data (pandas.DataFrame): Data to be learned. """ - log_numerical_distributions_error( - self.numerical_distributions, processed_data.columns, LOGGER - ) + warn_missing_numerical_distributions(self.numerical_distributions, processed_data.columns) self._num_rows = self._learn_num_rows(processed_data) numerical_distributions = self._get_numerical_distributions(processed_data) self._model = self._initialize_model(numerical_distributions) diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 527f845ec..c11138c33 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -330,12 +330,12 @@ def validate_numerical_distributions(numerical_distributions, metadata_columns): ) -def log_numerical_distributions_error(numerical_distributions, processed_data_columns, logger): - """Log error when numerical distributions columns don't exist anymore.""" +def warn_missing_numerical_distributions(numerical_distributions, processed_data_columns): + """Raise an `UserWarning` when numerical distribution columns don't exist anymore.""" unseen_columns = numerical_distributions.keys() - set(processed_data_columns) for column in unseen_columns: - logger.info( - f"Requested distribution '{numerical_distributions[column]}' " - f"cannot be applied to column '{column}' because it no longer " - 'exists after preprocessing.' + warnings.warn( + f"Cannot use distribution '{numerical_distributions[column]}' for column " + f"'{column}' because the column is not statistically modeled.", + UserWarning, ) diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py index 96250ca67..b43b139b1 100644 --- a/tests/integration/single_table/test_copulas.py +++ b/tests/integration/single_table/test_copulas.py @@ -500,3 +500,25 @@ def test_support_nullable_pandas_dtypes(): assert (synthetic_data.dtypes == data.dtypes).all() assert (synthetic_data['Float32'] == synthetic_data['Float32'].round(1)).all(skipna=True) assert (synthetic_data['Float64'] == synthetic_data['Float64'].round(3)).all(skipna=True) + + +def test_user_warning_for_unused_numerical_distribution(): + """Ensure that a `UserWarning` is raised when a numerical distribution is not applied. + + This test verifies that the synthesizer warns the user if a specified numerical + distribution is not used because the corresponding column does not exist or is not + modeled after preprocessing. + """ + # Setup + data, metadata = download_demo('single_table', 'fake_hotel_guests') + synthesizer = GaussianCopulaSynthesizer( + metadata, numerical_distributions={'credit_card_number': 'beta'} + ) + + # Run and Assert + message = ( + "Cannot use distribution 'beta' for column 'credit_card_number' because the column is not " + 'statistically modeled.' + ) + with pytest.warns(UserWarning, match=message): + synthesizer.fit(data) diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index 354ca841c..7bc22058b 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -263,10 +263,10 @@ def test__create_gaussian_normalizer_config(self, mock_rdt): assert config == expected_config assert mock_rdt.transformers.GaussianNormalizer.call_args_list == expected_calls - @patch('sdv.single_table.copulagan.LOGGER') + @patch('sdv.single_table.utils.warnings') @patch('sdv.single_table.copulagan.CTGANSynthesizer._fit') @patch('sdv.single_table.copulagan.rdt') - def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger): + def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_warnings): """Test a message is logged. A message should be logged if the columns passed in ``numerical_distributions`` @@ -284,10 +284,11 @@ def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger): instance._fit(processed_data) # Assert - mock_logger.info.assert_called_once_with( - "Requested distribution 'gamma' cannot be applied to column 'col' " - 'because it no longer exists after preprocessing.' + warning_message = ( + "Cannot use distribution 'gamma' for column 'col' because the column is not " + 'statistically modeled.' ) + mock_warnings.warn.assert_called_once_with(warning_message, UserWarning) @patch('sdv.single_table.copulagan.CTGANSynthesizer._fit') @patch('sdv.single_table.copulagan.rdt') diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index e9a52903b..c6ca4dd24 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -159,11 +159,11 @@ def test_get_parameters(self): 'default_distribution': 'beta', } - @patch('sdv.single_table.copulas.LOGGER') - def test__fit_logging(self, mock_logger): - """Test a message is logged. + @patch('sdv.single_table.utils.warnings') + def test__fit_warning_numerical_distributions(self, mock_warnings): + """Test that a warning is shown when fitting numerical distributions on a dropped column. - A message should be logged if the columns passed in ``numerical_distributions`` + A warning message should be printed if the columns passed in ``numerical_distributions`` were renamed/dropped during preprocessing. """ # Setup @@ -180,10 +180,11 @@ def test__fit_logging(self, mock_logger): instance._fit(processed_data) # Assert - mock_logger.info.assert_called_once_with( - "Requested distribution 'gamma' cannot be applied to column 'col' " - 'because it no longer exists after preprocessing.' + warning_message = ( + "Cannot use distribution 'gamma' for column 'col' because the column is not " + 'statistically modeled.' ) + mock_warnings.warn.assert_called_once_with(warning_message, UserWarning) @patch('sdv.single_table.copulas.warnings') @patch('sdv.single_table.copulas.multivariate') diff --git a/tests/unit/single_table/test_utils.py b/tests/unit/single_table/test_utils.py index b612c9c7b..dbf84cdaa 100644 --- a/tests/unit/single_table/test_utils.py +++ b/tests/unit/single_table/test_utils.py @@ -14,6 +14,7 @@ handle_sampling_error, unflatten_dict, validate_file_path, + warn_missing_numerical_distributions, ) @@ -328,3 +329,18 @@ def test_validate_file_path(mock_open): assert output_path in result assert none_result is None mock_open.assert_called_once_with(result, 'w+') + + +def test_warn_missing_numerical_distributions(): + """Test the warn_missing_numerical_distributions function.""" + # Setup + numerical_distributions = {'age': 'beta', 'height': 'uniform'} + processed_data_columns = ['height', 'weight'] + + # Run and Assert + message = ( + "Cannot use distribution 'beta' for column 'age' because the column is not " + 'statistically modeled.' + ) + with pytest.warns(UserWarning, match=message): + warn_missing_numerical_distributions(numerical_distributions, processed_data_columns)