From c06a273a4150c10a290dbac5574f33034f056e14 Mon Sep 17 00:00:00 2001 From: rwedge Date: Tue, 3 Oct 2023 17:04:04 -0400 Subject: [PATCH 1/3] use catch_warnings to filter temporarily --- sdv/data_processing/data_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 9bda9644e..deb48ea1b 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -537,9 +537,9 @@ def update_transformers(self, column_name_to_transformer): "'RegexGenerator' instead." ) - warnings.filterwarnings('ignore', module='rdt') - self._hyper_transformer.update_transformers(column_name_to_transformer) - warnings.resetwarnings() + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message=rdt.HyperTransformer._REFIT_MESSAGE, module='rdt') + self._hyper_transformer.update_transformers(column_name_to_transformer) def _fit_hyper_transformer(self, data): """Create and return a new ``rdt.HyperTransformer`` instance. From 3d6831675e7d3631f08e560a3c878cb69b80ccf9 Mon Sep 17 00:00:00 2001 From: rwedge Date: Tue, 3 Oct 2023 18:59:18 -0400 Subject: [PATCH 2/3] add test and lint --- sdv/data_processing/data_processor.py | 3 ++- .../unit/data_processing/test_data_processor.py | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index deb48ea1b..3efc02a20 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -538,7 +538,8 @@ def update_transformers(self, column_name_to_transformer): ) with warnings.catch_warnings(): - warnings.filterwarnings('ignore', message=rdt.HyperTransformer._REFIT_MESSAGE, module='rdt') + msg = rdt.HyperTransformer._REFIT_MESSAGE + warnings.filterwarnings('ignore', message=msg, module='rdt') self._hyper_transformer.update_transformers(column_name_to_transformer) def _fit_hyper_transformer(self, data): diff --git a/tests/unit/data_processing/test_data_processor.py b/tests/unit/data_processing/test_data_processor.py index 862c977a7..e0670c2c5 100644 --- a/tests/unit/data_processing/test_data_processor.py +++ b/tests/unit/data_processing/test_data_processor.py @@ -1,4 +1,5 @@ import re +import warnings from unittest.mock import Mock, call, patch import numpy as np @@ -7,7 +8,8 @@ from rdt.errors import ConfigNotSetError from rdt.errors import NotFittedError as RDTNotFittedError from rdt.transformers import ( - AnonymizedFaker, FloatFormatter, IDGenerator, UniformEncoder, UnixTimestampEncoder) + AnonymizedFaker, FloatFormatter, GaussianNormalizer, IDGenerator, UniformEncoder, + UnixTimestampEncoder) from sdv.constraints.errors import MissingConstraintColumnError from sdv.constraints.tabular import Positive, ScalarRange @@ -1199,6 +1201,18 @@ def test_update_transformers_not_fitted(self): with pytest.raises(NotFittedError, match=error_msg): dp.update_transformers({'column': None}) + def test_update_transformers_ignores_rdt_refit_warning(self): + """Test silencing hypertransformer refit warning (replaced by SDV warning elsewhere)""" + metadata = SingleTableMetadata() + metadata.add_column('col1', sdtype='numerical') + metadata.add_column('col2', sdtype='numerical') + + dp = DataProcessor(metadata) + dp.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) + with warnings.catch_warnings(): + warnings.simplefilter('error') + dp.update_transformers({'col1': GaussianNormalizer()}) + def test_update_transformers_for_key(self): """Test when ``transformer`` is not ``AnonymizedFaker`` or ``RegexGenerator`` for keys.""" # Setup From e63e07c83e04a009008dc5e8de6cc6b67d64585d Mon Sep 17 00:00:00 2001 From: rwedge Date: Thu, 5 Oct 2023 10:36:08 -0400 Subject: [PATCH 3/3] ignore rdt.hyper_transformer instead of specific rdt message --- sdv/data_processing/data_processor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 3efc02a20..a8d33a203 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -538,8 +538,7 @@ def update_transformers(self, column_name_to_transformer): ) with warnings.catch_warnings(): - msg = rdt.HyperTransformer._REFIT_MESSAGE - warnings.filterwarnings('ignore', message=msg, module='rdt') + warnings.filterwarnings('ignore', module='rdt.hyper_transformer') self._hyper_transformer.update_transformers(column_name_to_transformer) def _fit_hyper_transformer(self, data):