From d27e3f09f86754e479c44c5bc07b9353b42c5777 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Thu, 18 Apr 2024 18:25:19 +0200 Subject: [PATCH] Fix unit tests and move to logging module --- sdv/_utils.py | 2 + sdv/logging/__init__.py | 0 sdv/logging/sdv_logger_config.yml | 15 +++++ sdv/logging/utils.py | 91 ++++++++++++++++++++++++++++ sdv/multi_table/base.py | 43 +++++++------ sdv/sdv_logger.yml | 15 ----- sdv/single_table/base.py | 24 ++++---- tests/unit/multi_table/test_base.py | 40 ++++++++---- tests/unit/multi_table/test_hma.py | 8 +-- tests/unit/single_table/test_base.py | 18 +++--- 10 files changed, 182 insertions(+), 74 deletions(-) create mode 100644 sdv/logging/__init__.py create mode 100644 sdv/logging/sdv_logger_config.yml create mode 100644 sdv/logging/utils.py delete mode 100644 sdv/sdv_logger.yml diff --git a/sdv/_utils.py b/sdv/_utils.py index 577600b8d..df9bcc2f4 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -1,4 +1,5 @@ """Miscellaneous utility functions.""" +import contextlib import operator import uuid import warnings @@ -8,6 +9,7 @@ from pathlib import Path import pandas as pd +import yaml from pandas.core.tools.datetimes import _guess_datetime_format_for_array from sdv import version diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sdv/logging/sdv_logger_config.yml b/sdv/logging/sdv_logger_config.yml new file mode 100644 index 000000000..c983ee36e --- /dev/null +++ b/sdv/logging/sdv_logger_config.yml @@ -0,0 +1,15 @@ +log_registry: 'local' +version: 1 +loggers: + SingleTableSynthesizer: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log + MultiTableSynthesizer: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py new file mode 100644 index 000000000..756ea31ff --- /dev/null +++ b/sdv/logging/utils.py @@ 
-0,0 +1,91 @@ +import contextlib +import logging +import logging.config +from functools import lru_cache +from pathlib import Path + +import yaml + + +def get_logger_config(): + """Return a dictionary with the logging configuration.""" + logging_path = Path(__file__).parent + with open(logging_path / 'sdv_logger_config.yml', 'r') as f: + logger_conf = yaml.safe_load(f) + + # Logfile to be in this same directory + for logger in logger_conf.get('loggers', {}).values(): + handler = logger.get('handlers', {}) + if handler.get('filename') == 'sdv_logs.log': + handler['filename'] = logging_path / handler['filename'] + + return logger_conf + + +@contextlib.contextmanager +def disable_single_table_logger(): + """Temporarily disables logging for the single table synthesizers. + + This context manager temporarily removes all handlers associated with + the ``SingleTableSynthesizer`` logger, disabling logging for that module + within the current context. After the context exits, the + removed handlers are restored to the logger. + """ + # Logging without ``SingleTableSynthesizer`` + single_table_logger = logging.getLogger('SingleTableSynthesizer') + handlers = single_table_logger.handlers[:] + for handler in handlers: + single_table_logger.removeHandler(handler) + + try: + yield + finally: + for handler in handlers: + single_table_logger.addHandler(handler) + + +@lru_cache() +def get_logger(logger_name): + """Get a logger instance with the specified name and configuration. + + This function retrieves or creates a logger instance with the specified name + and applies configuration settings based on the logger's name and the logging + configuration. + + Args: + logger_name (str): + The name of the logger to retrieve or create. + + Returns: + logging.Logger: + A logger instance configured according to the logging configuration + and the specific settings for the given logger name. 
+ """ + logger_conf = get_logger_config() + logger = logging.getLogger(logger_name) + if logger_name in logger_conf.get('loggers'): + formatter = None + config = logger_conf.get('loggers').get(logger_name) + log_level = getattr(logging, config.get('level', 'INFO')) + if config.get('format'): + formatter = logging.Formatter(config.get('format')) + + logger.setLevel(log_level) + logger.propagate = config.get('propagate', False) + handler = config.get('handlers') + handlers = handler.get('class') + handlers = [handlers] if isinstance(handlers, str) else handlers + for handler_class in handlers: + if handler_class == 'logging.FileHandler': + logfile = handler.get('filename') + file_handler = logging.FileHandler(logfile) + file_handler.setLevel(log_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + elif handler_class in ('consoleHandler', 'StreamingHandler'): + ch = logging.StreamHandler() + ch.setLevel(log_level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return logger diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index d3d630e5d..da34cb82a 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -19,14 +19,10 @@ _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError +from sdv.logging.utils import get_logger, disable_single_table_logger from sdv.single_table.copulas import GaussianCopulaSynthesizer -logger_config_file = Path(__file__).parent.parent -with open(logger_config_file / 'sdv_logger.yml', 'r') as f: - logger_conf = yaml.safe_load(f) - -logging.config.dictConfig(logger_conf) -LOGGER = logging.getLogger('BaseMultiTableSynthesizer') +SYNTHESIZER_LOGGER = get_logger('MultiTableSynthesizer') class BaseMultiTableSynthesizer: @@ -66,13 +62,14 @@ def _set_temp_numpy_seed(self): np.random.set_state(initial_state) def _initialize_models(self): - for 
table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = self._table_parameters.get(table_name, {}) - self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, - locales=self.locales, - **synthesizer_parameters - ) + with disable_single_table_logger(): + for table_name, table_metadata in self.metadata.tables.items(): + synthesizer_parameters = self._table_parameters.get(table_name, {}) + self._table_synthesizers[table_name] = self._synthesizer( + metadata=table_metadata, + locales=self.locales, + **synthesizer_parameters + ) def _get_pbar_args(self, **kwargs): """Return a dictionary with the updated keyword args for a progress bar.""" @@ -123,7 +120,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nInstance:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -396,7 +393,7 @@ def fit_processed_data(self, processed_data): total_rows += len(table) total_columns += len(table.columns) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nFit processed data\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -413,8 +410,10 @@ def fit_processed_data(self, processed_data): self._synthesizer_id, ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - augmented_data = self._augment_tables(processed_data) - self._model_tables(augmented_data) + with disable_single_table_logger(): + augmented_data = self._augment_tables(processed_data) + self._model_tables(augmented_data) + self._fitted = True self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d') self._fitted_sdv_version = getattr(version, 'public', None) @@ -434,7 +433,7 @@ def fit(self, data): total_rows += len(table) total_columns += len(table.columns) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nFit\n' ' Timestamp: %s\n' ' 
Synthesizer class name: %s\n' @@ -482,7 +481,7 @@ def sample(self, scale=1.0): raise SynthesizerInputError( f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.") - with self._set_temp_numpy_seed(): + with self._set_temp_numpy_seed(), disable_single_table_logger(): sampled_data = self._sample(scale=scale) total_rows = 0 @@ -491,7 +490,7 @@ def sample(self, scale=1.0): total_rows += len(table) total_columns += len(table.columns) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -674,7 +673,7 @@ def save(self, filepath): with open(filepath, 'wb') as output: cloudpickle.dump(self, output) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nSave:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -704,7 +703,7 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nLoad\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' diff --git a/sdv/sdv_logger.yml b/sdv/sdv_logger.yml deleted file mode 100644 index 20be0b1df..000000000 --- a/sdv/sdv_logger.yml +++ /dev/null @@ -1,15 +0,0 @@ -log_registry: 'local' -version: 1 -handlers: - file: - class: logging.FileHandler - filename: sdv_logs.log -loggers: - BaseSingleTableSynthesizer: - level: INFO - handlers: [file] - propagate: no - BaseMultiTableSynthesizer: - level: INFO - handlers: [file] - propagate: no diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 2d98fd62a..d1576654c 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -4,7 +4,6 @@ import functools import inspect import logging -import logging.config import math import operator import os @@ -18,7 +17,6 @@ import numpy as np import pandas as pd import tqdm -import yaml from copulas.multivariate import GaussianMultivariate from sdv import version @@ -27,14 +25,12 @@ from sdv.constraints.errors import 
AggregateConstraintsError from sdv.data_processing.data_processor import DataProcessor from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError +from sdv.logging.utils import get_logger from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path -logger_config_file = Path(__file__).parent.parent -with open(logger_config_file / 'sdv_logger.yml', 'r') as f: - logger_conf = yaml.safe_load(f) -logging.config.dictConfig(logger_conf) -LOGGER = logging.getLogger('BaseSingleTableSynthesizer') +LOGGER = logging.getLogger(__name__) +SYNTHESIZER_LOGGER = get_logger('SingleTableSynthesizer') COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 @@ -116,7 +112,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nInstance:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -407,7 +403,7 @@ def fit_processed_data(self, processed_data): processed_data (pandas.DataFrame): The transformed data used to fit the model to. """ - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nFit processed data\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -438,7 +434,7 @@ def fit(self, data): data (pandas.DataFrame): The raw data (before any transformations) to fit the model to. """ - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nFit\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -468,7 +464,7 @@ def save(self, filepath): filepath (str): Path where the synthesizer instance will be serialized. 
""" - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nSave:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -500,7 +496,7 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nLoad\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -884,7 +880,7 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file output_file_path, show_progress_bar=show_progress_bar ) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -900,6 +896,8 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file self._synthesizer_id, ) + return sampled_data + def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, progress_bar=None, output_file_path=None): """Sample rows with conditions. diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 99330715e..1e717cb68 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -829,14 +829,16 @@ def test_fit_processed_data(self): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None ) - data = Mock() - data.copy.return_value = data + processed_data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } # Run - BaseMultiTableSynthesizer.fit_processed_data(instance, data) + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert - instance._augment_tables.assert_called_once_with(data) + instance._augment_tables.assert_called_once_with(processed_data) instance._model_tables.assert_called_once_with(instance._augment_tables.return_value) assert instance._fitted @@ -847,10 +849,13 @@ def test_fit_processed_data_empty_table(self): _fitted_sdv_version=None, 
_fitted_sdv_enterprise_version=None ) - data = pd.DataFrame() + processed_data = { + 'table1': pd.DataFrame(), + 'table2': pd.DataFrame() + } # Run - BaseMultiTableSynthesizer.fit_processed_data(instance, data) + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert instance._fit.assert_not_called() @@ -866,7 +871,10 @@ def test_fit_processed_data_raises_version_error(self): _fitted_sdv_enterprise_version=None ) instance.metadata = Mock() - data = Mock() + processed_data = { + 'table1': pd.DataFrame(), + 'table2': pd.DataFrame() + } # Run and Assert error_msg = ( @@ -875,7 +883,7 @@ def test_fit_processed_data_raises_version_error(self): 'Please create a new synthesizer.' ) with pytest.raises(VersionError, match=error_msg): - BaseMultiTableSynthesizer.fit_processed_data(instance, data) + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert instance.preprocess.assert_not_called() @@ -891,7 +899,10 @@ def test_fit(self, mock_validate_foreign_keys_not_null): _fitted_sdv_enterprise_version=None ) instance.metadata = Mock() - data = Mock() + data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } # Run BaseMultiTableSynthesizer.fit(instance, data) @@ -910,7 +921,10 @@ def test_fit_raises_version_error(self): _fitted_sdv_enterprise_version=None ) instance.metadata = Mock() - data = Mock() + data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } # Run and Assert error_msg = ( @@ -991,7 +1005,11 @@ def test_sample(self): # Setup metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) - instance._sample = Mock() + data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 
'Johanna', 'Doe']}) + } + instance._sample = Mock(return_value=data) # Run instance.sample(scale=1.5) diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 3dab339ba..c40e7b080 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -674,7 +674,7 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): ] }) synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) # Run estimation estimated_num_columns = synthesizer._estimate_num_columns(metadata) @@ -823,7 +823,7 @@ def test__estimate_num_columns_to_be_modeled_different_distributions(self): table_name='child_uniform', table_parameters={'default_distribution': 'uniform'} ) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) distributions = synthesizer._get_distributions() # Run estimation @@ -953,7 +953,7 @@ def test__estimate_num_columns_to_be_modeled(self): ] }) synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) # Run estimation estimated_num_columns = synthesizer._estimate_num_columns(metadata) @@ -1068,7 +1068,7 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): ] }) synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) # Run estimation estimated_num_columns = synthesizer._estimate_num_columns(metadata) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 509cee456..b4c106b80 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -348,8 +348,7 @@ def test_fit_processed_data(self): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, ) - processed_data = Mock() - processed_data.empty = False + processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) # Run 
BaseSingleTableSynthesizer.fit_processed_data(instance, processed_data) @@ -368,7 +367,7 @@ def test_fit_processed_data_raises_version_error(self): _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, ) - processed_data = Mock() + processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True instance._fitted = True @@ -392,17 +391,17 @@ def test_fit(self): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, ) - processed_data = Mock() + data = pd.DataFrame({'column_a': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']}) instance._random_state_set = True instance._fitted = True # Run - BaseSingleTableSynthesizer.fit(instance, processed_data) + BaseSingleTableSynthesizer.fit(instance, data) # Assert assert instance._random_state_set is False instance._data_processor.reset_sampling.assert_called_once_with() - instance._preprocess.assert_called_once_with(processed_data) + instance._preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance._preprocess.return_value) instance._check_metadata_updated.assert_called_once() @@ -417,7 +416,7 @@ def test_fit_raises_version_error(self): _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, ) - processed_data = Mock() + data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True instance._fitted = True @@ -428,7 +427,7 @@ def test_fit_raises_version_error(self): 'create a new synthesizer.' 
) with pytest.raises(VersionError, match=error_msg): - BaseSingleTableSynthesizer.fit(instance, processed_data) + BaseSingleTableSynthesizer.fit(instance, data) def test__validate_constraints(self): """Test that ``_validate_constraints`` calls ``fit`` and returns any errors.""" @@ -1371,6 +1370,7 @@ def test_sample(self): output_file_path = 'temp.csv' instance = Mock() instance.get_metadata.return_value._constraints = False + instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]}) # Run result = BaseSingleTableSynthesizer.sample( @@ -1389,7 +1389,7 @@ def test_sample(self): 'temp.csv', show_progress_bar=True ) - assert result == instance._sample_with_progress_bar.return_value + pd.testing.assert_frame_equal(result, pd.DataFrame({'col': [1, 2, 3]})) def test__validate_conditions_unseen_columns(self): """Test that conditions are within the ``data_processor`` fields."""