diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 9bda9644e..a8d33a203 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -537,9 +537,9 @@ def update_transformers(self, column_name_to_transformer): "'RegexGenerator' instead." ) - warnings.filterwarnings('ignore', module='rdt') - self._hyper_transformer.update_transformers(column_name_to_transformer) - warnings.resetwarnings() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', module='rdt.hyper_transformer') + self._hyper_transformer.update_transformers(column_name_to_transformer) def _fit_hyper_transformer(self, data): """Create and return a new ``rdt.HyperTransformer`` instance. diff --git a/tests/unit/data_processing/test_data_processor.py b/tests/unit/data_processing/test_data_processor.py index 862c977a7..e0670c2c5 100644 --- a/tests/unit/data_processing/test_data_processor.py +++ b/tests/unit/data_processing/test_data_processor.py @@ -1,4 +1,5 @@ import re +import warnings from unittest.mock import Mock, call, patch import numpy as np @@ -7,7 +8,8 @@ from rdt.errors import ConfigNotSetError from rdt.errors import NotFittedError as RDTNotFittedError from rdt.transformers import ( - AnonymizedFaker, FloatFormatter, IDGenerator, UniformEncoder, UnixTimestampEncoder) + AnonymizedFaker, FloatFormatter, GaussianNormalizer, IDGenerator, UniformEncoder, + UnixTimestampEncoder) from sdv.constraints.errors import MissingConstraintColumnError from sdv.constraints.tabular import Positive, ScalarRange @@ -1199,6 +1201,18 @@ def test_update_transformers_not_fitted(self): with pytest.raises(NotFittedError, match=error_msg): dp.update_transformers({'column': None}) + def test_update_transformers_ignores_rdt_refit_warning(self): + """Test silencing hypertransformer refit warning (replaced by SDV warning elsewhere)""" + metadata = SingleTableMetadata() + metadata.add_column('col1', sdtype='numerical') + metadata.add_column('col2', sdtype='numerical') + + dp = DataProcessor(metadata) + dp.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) + with warnings.catch_warnings(): + warnings.simplefilter('error') + dp.update_transformers({'col1': GaussianNormalizer()}) + def test_update_transformers_for_key(self): """Test when ``transformer`` is not ``AnonymizedFaker`` or ``RegexGenerator`` for keys.""" # Setup