From 75319075f7a662c816722963e518c3053a9641d2 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Tue, 9 Apr 2024 18:13:12 +0200 Subject: [PATCH 01/19] Generate unique id for synthesizers --- tests/unit/test__utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/test__utils.py b/tests/unit/test__utils.py index d340a991c..d56d69144 100644 --- a/tests/unit/test__utils.py +++ b/tests/unit/test__utils.py @@ -659,4 +659,8 @@ def test_generate_synthesizer_id(mock_version, mock_uuid): result = generate_synthesizer_id(synthesizer) # Assert +<<<<<<< HEAD assert result == 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' +======= + assert result == 'BaseSingleTableSynthesizer_1.0.0_990d1231a5f5' +>>>>>>> 36f3cb15 (Generate unique id for synthesizers) From 00ac0adc29e5df4e941123065ada075778960c31 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 17 Apr 2024 19:06:45 +0200 Subject: [PATCH 02/19] Finalize loggers --- sdv/metadata/multi_table.py | 16 ++++++ sdv/metadata/single_table.py | 10 ++++ sdv/multi_table/base.py | 104 +++++++++++++++++++++++++++++++++++ sdv/sdv_logger.yml | 15 +++++ sdv/single_table/base.py | 86 ++++++++++++++++++++++++++++- 5 files changed, 229 insertions(+), 2 deletions(-) create mode 100644 sdv/sdv_logger.yml diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 194dc7ba7..d5ceb49d3 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -1,5 +1,6 @@ """Multi Table Metadata.""" +import datetime import json import logging import warnings @@ -1040,6 +1041,21 @@ def save_to_json(self, filepath): """ validate_file_does_not_exist(filepath) metadata = self.to_dict() + total_columns = 0 + for table in self.tables: + total_columns += len(table.columns) + + LOGGER.info( + '\nMetadata Save:\n' + ' Timestamp: %s\n' + ' Statistics about the metadata:\n' + ' Total number of tables: %s', + ' Total number of columns: %s' + ' Total number of relationships: %s', + datetime.datetime.now(), + total_columns, + len(self.relationships) + ) with open(filepath, 'w', encoding='utf-8') as metadata_file: json.dump(metadata, metadata_file, indent=4) diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index ce81ad3a0..2c7952b9d 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -1206,6 +1206,16 @@ def save_to_json(self, filepath): validate_file_does_not_exist(filepath) metadata = self.to_dict() metadata['METADATA_SPEC_VERSION'] = self.METADATA_SPEC_VERSION + LOGGER.info( + '\nMetadata Save:\n' + ' Timestamp: %s\n' + ' Statistics about the metadata:\n' + ' Total number of tables: 1', + ' Total number of columns: %s' + ' Total number of relationships: 0', + datetime.now(), + len(self.columns) + ) with open(filepath, 'w', encoding='utf-8') as metadata_file: json.dump(metadata, metadata_file, indent=4) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 751740b3e..d3d630e5d 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -2,13 +2,16 @@ import contextlib import datetime import inspect +import logging import operator import warnings from collections import defaultdict from copy import deepcopy +from pathlib import Path import cloudpickle import numpy as np +import yaml from tqdm import tqdm from sdv import version @@ -18,6 +21,13 @@ from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError from sdv.single_table.copulas import GaussianCopulaSynthesizer +logger_config_file = 
Path(__file__).parent.parent +with open(logger_config_file / 'sdv_logger.yml', 'r') as f: + logger_conf = yaml.safe_load(f) + +logging.config.dictConfig(logger_conf) +LOGGER = logging.getLogger('BaseMultiTableSynthesizer') + class BaseMultiTableSynthesizer: """Base class for multi table synthesizers. @@ -113,6 +123,15 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) + LOGGER.info( + '\nInstance:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + self._synthesizer_id + ) def _get_root_parents(self): """Get the set of root parents in the graph.""" @@ -371,6 +390,28 @@ def fit_processed_data(self, processed_data): processed_data (dict): Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``. """ + total_rows = 0 + total_columns = 0 + for table in processed_data.values(): + total_rows += len(table) + total_columns += len(table.columns) + + LOGGER.info( + '\nFit processed data\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit data:\n' + ' Total number of tables: %s\n' + ' Table number of rows: %s\n' + ' Table number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(processed_data), + total_rows, + total_columns, + self._synthesizer_id, + ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) augmented_data = self._augment_tables(processed_data) self._model_tables(augmented_data) @@ -387,6 +428,28 @@ def fit(self, data): Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format (before any transformations). 
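A minimal, self-contained sketch of the counting-and-logging pattern these fit hooks add — the logger name, the data, and the basicConfig setup are illustrative assumptions, not SDV code:

import logging

import pandas as pd

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('MultiTableSynthesizer')

data = {
    'users': pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']}),
    'orders': pd.DataFrame({'id': [10], 'user_id': [1]}),
}

total_rows = sum(len(table) for table in data.values())
total_columns = sum(len(table.columns) for table in data.values())

# %-style placeholders defer string interpolation until the record is
# actually emitted, which is why the patch passes the values as extra
# arguments instead of building an f-string up front.
logger.info(
    'Total tables: %s, total rows: %s, total columns: %s',
    len(data), total_rows, total_columns,
)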
""" + total_rows = 0 + total_columns = 0 + for table in data.values(): + total_rows += len(table) + total_columns += len(table.columns) + + LOGGER.info( + '\nFit\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit data:\n' + ' Total number of tables: %s\n' + ' Table number of rows: %s\n' + ' Table number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(data), + total_rows, + total_columns, + self._synthesizer_id, + ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) _validate_foreign_keys_not_null(self.metadata, data) self._check_metadata_updated() @@ -422,6 +485,28 @@ def sample(self, scale=1.0): with self._set_temp_numpy_seed(): sampled_data = self._sample(scale=scale) + total_rows = 0 + total_columns = 0 + for table in sampled_data.values(): + total_rows += len(table) + total_columns += len(table.columns) + + LOGGER.info( + '\nSample:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit data:\n' + ' Total number of tables: %s\n' + ' Table number of rows: %s\n' + ' Table number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(sampled_data), + total_rows, + total_columns, + self._synthesizer_id, + ) return sampled_data def get_learned_distributions(self, table_name): @@ -589,6 +674,16 @@ def save(self, filepath): with open(filepath, 'wb') as output: cloudpickle.dump(self, output) + LOGGER.info( + '\nSave:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + self._synthesizer_id, + ) + @classmethod def load(cls, filepath): """Load a multi-table synthesizer from a given path. 
@@ -609,4 +704,13 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) + LOGGER.info( + '\nLoad\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + synthesizer.__class__.__name__, + synthesizer._synthesizer_id, + ) return synthesizer diff --git a/sdv/sdv_logger.yml b/sdv/sdv_logger.yml new file mode 100644 index 000000000..20be0b1df --- /dev/null +++ b/sdv/sdv_logger.yml @@ -0,0 +1,15 @@ +log_registry: 'local' +version: 1 +handlers: + file: + class: logging.FileHandler + filename: sdv_logs.log +loggers: + BaseSingleTableSynthesizer: + level: INFO + handlers: [file] + propagate: no + BaseMultiTableSynthesizer: + level: INFO + handlers: [file] + propagate: no diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 0466dbd42..2d98fd62a 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -4,18 +4,21 @@ import functools import inspect import logging +import logging.config import math import operator import os import uuid import warnings from collections import defaultdict +from pathlib import Path import cloudpickle import copulas import numpy as np import pandas as pd import tqdm +import yaml from copulas.multivariate import GaussianMultivariate from sdv import version @@ -26,7 +29,13 @@ from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path -LOGGER = logging.getLogger(__name__) +logger_config_file = Path(__file__).parent.parent +with open(logger_config_file / 'sdv_logger.yml', 'r') as f: + logger_conf = yaml.safe_load(f) + +logging.config.dictConfig(logger_conf) +LOGGER = logging.getLogger('BaseSingleTableSynthesizer') + COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 TMP_FILE_NAME = '.sample.csv.temp' @@ -107,6 +116,15 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) + LOGGER.info( + '\nInstance:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + self._synthesizer_id + ) def set_address_columns(self, column_names, anonymization_level='full'): """Set the address multi-column transformer.""" @@ -389,6 +407,21 @@ def fit_processed_data(self, processed_data): processed_data (pandas.DataFrame): The transformed data used to fit the model to. """ + LOGGER.info( + '\nFit processed data\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 1\n' + ' Table number of rows: %s\n' + ' Table number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(processed_data.columns), + len(processed_data), + self._synthesizer_id, + ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) if not processed_data.empty: self._fit(processed_data) @@ -405,6 +438,21 @@ def fit(self, data): data (pandas.DataFrame): The raw data (before any transformations) to fit the model to. 
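The new sdv_logger.yml follows the logging.config.dictConfig schema. A stand-alone sketch of the same load path, using inline YAML instead of the packaged file (an assumption made here for brevity):

import logging.config

import yaml

config_text = """
version: 1
handlers:
  file:
    class: logging.FileHandler
    filename: sdv_logs.log
loggers:
  BaseSingleTableSynthesizer:
    level: INFO
    handlers: [file]
    propagate: no
"""

# dictConfig ignores unrecognized top-level keys, so extras such as the
# patch's 'log_registry' marker pass through harmlessly.
logging.config.dictConfig(yaml.safe_load(config_text))
logging.getLogger('BaseSingleTableSynthesizer').info('configured from YAML')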
""" + LOGGER.info( + '\nFit\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 1\n' + ' Table number of rows: %s\n' + ' Table number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(data.columns), + len(data), + self._synthesizer_id, + ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) self._check_metadata_updated() self._fitted = False @@ -420,6 +468,15 @@ def save(self, filepath): filepath (str): Path where the synthesizer instance will be serialized. """ + LOGGER.info( + '\nSave:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + self._synthesizer_id, + ) with open(filepath, 'wb') as output: cloudpickle.dump(self, output) @@ -443,6 +500,15 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) + LOGGER.info( + '\nLoad\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + synthesizer.__class__.__name__, + synthesizer._synthesizer_id, + ) return synthesizer @@ -806,17 +872,33 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file pandas.DataFrame: Sampled data. """ + sample_timestamp = datetime.datetime.now() has_constraints = bool(self._data_processor._constraints) has_batches = batch_size is not None and batch_size != num_rows show_progress_bar = has_constraints or has_batches - return self._sample_with_progress_bar( + sampled_data = self._sample_with_progress_bar( num_rows, max_tries_per_batch, batch_size, output_file_path, show_progress_bar=show_progress_bar ) + LOGGER.info( + '\nSample:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + sample_timestamp, + self.__class__.__name__, + len(sampled_data), + len(sampled_data.columns), + self._synthesizer_id, + ) def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, progress_bar=None, output_file_path=None): From a9010792f33db3a17238943b726f368d7301cf63 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Thu, 18 Apr 2024 18:25:19 +0200 Subject: [PATCH 03/19] Fix unit tests and move to logging module --- sdv/_utils.py | 2 + sdv/logging/__init__.py | 0 sdv/logging/sdv_logger_config.yml | 15 +++++ sdv/logging/utils.py | 91 ++++++++++++++++++++++++++++ sdv/multi_table/base.py | 43 +++++++------ sdv/sdv_logger.yml | 15 ----- sdv/single_table/base.py | 24 ++++---- tests/unit/multi_table/test_base.py | 40 ++++++++---- tests/unit/multi_table/test_hma.py | 8 +-- tests/unit/single_table/test_base.py | 18 +++--- 10 files changed, 182 insertions(+), 74 deletions(-) create mode 100644 sdv/logging/__init__.py create mode 100644 sdv/logging/sdv_logger_config.yml create mode 100644 sdv/logging/utils.py delete mode 100644 sdv/sdv_logger.yml diff --git a/sdv/_utils.py b/sdv/_utils.py index 577600b8d..df9bcc2f4 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -1,4 +1,5 @@ """Miscellaneous utility functions.""" +import contextlib import operator import uuid import warnings @@ -8,6 +9,7 @@ from pathlib import Path import pandas as pd +import yaml from pandas.core.tools.datetimes import _guess_datetime_format_for_array from sdv import 
version diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/sdv/logging/sdv_logger_config.yml b/sdv/logging/sdv_logger_config.yml new file mode 100644 index 000000000..c983ee36e --- /dev/null +++ b/sdv/logging/sdv_logger_config.yml @@ -0,0 +1,15 @@ +log_registry: 'local' +version: 1 +loggers: + SingleTableSynthesizer: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log + MultiTableSynthesizer: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py new file mode 100644 index 000000000..756ea31ff --- /dev/null +++ b/sdv/logging/utils.py @@ -0,0 +1,91 @@ +import contextlib +import logging +import logging.config +from functools import lru_cache +from pathlib import Path + +import yaml + + +def get_logger_config(): + """Return a dictionary with the logging configuration.""" + logging_path = Path(__file__).parent + with open(logging_path / 'sdv_logger_config.yml', 'r') as f: + logger_conf = yaml.safe_load(f) + + # Logfile to be in this same directory + for logger in logger_conf.get('loggers', {}).values(): + handler = logger.get('handlers', {}) + if handler.get('filename') == 'sdv_logs.log': + handler['filename'] = logging_path / handler['filename'] + + return logger_conf + + +@contextlib.contextmanager +def disable_single_table_logger(): + """Temporarily disables logging for the single table synthesizers. + + This context manager temporarily removes all handlers associated with + the ``SingleTableSynthesizer`` logger, disabling logging for that module + within the current context. After the context exits, the + removed handlers are restored to the logger. + """ + # Logging without ``SingleTableSynthesizer`` + single_table_logger = logging.getLogger('SingleTableSynthesizer') + handlers = single_table_logger.handlers + for handler in handlers: + single_table_logger.removeHandler(handler) + + try: + yield + finally: + for handler in handlers: + single_table_logger.addHandler(handler) + + +@lru_cache() +def get_logger(logger_name): + """Get a logger instance with the specified name and configuration. + + This function retrieves or creates a logger instance with the specified name + and applies configuration settings based on the logger's name and the logging + configuration. + + Args: + logger_name (str): + The name of the logger to retrieve or create. + + Returns: + logging.Logger: + A logger instance configured according to the logging configuration + and the specific settings for the given logger name. 
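A note on the caching choice: logging.getLogger already returns the same object for a given name, so the practical effect of lru_cache on this factory is that the handler-attachment work below runs only once per name, preventing duplicate handlers on repeated calls. A reduced sketch of that pattern, with illustrative names:

import logging
from functools import lru_cache


@lru_cache()
def get_cached_logger(logger_name):
    logger = logging.getLogger(logger_name)
    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    return logger


# The second call is served from the cache, so no duplicate handler appears.
assert get_cached_logger('demo') is get_cached_logger('demo')
assert len(logging.getLogger('demo').handlers) == 1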
+ """ + logger_conf = get_logger_config() + logger = logging.getLogger(logger_name) + if logger_name in logger_conf.get('loggers'): + formatter = None + config = logger_conf.get('loggers').get(logger_name) + log_level = getattr(logging, config.get('level', 'INFO')) + if config.get('format'): + formatter = logging.Formatter(config.get('format')) + + logger.setLevel(log_level) + logger.propagate = config.get('propagate', False) + handler = config.get('handlers') + handlers = handler.get('class') + handlers = [handlers] if isinstance(handlers, str) else handlers + for handler_class in handlers: + if handler_class == 'logging.FileHandler': + logfile = handler.get('filename') + file_handler = logging.FileHandler(logfile) + file_handler.setLevel(log_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + elif handler in ('consoleHandler', 'StreamingHandler'): + ch = logging.StreamHandler() + ch.setLevel(log_level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return logger diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index d3d630e5d..da34cb82a 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -19,14 +19,10 @@ _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError +from sdv.logging.utils import get_logger, disable_single_table_logger from sdv.single_table.copulas import GaussianCopulaSynthesizer -logger_config_file = Path(__file__).parent.parent -with open(logger_config_file / 'sdv_logger.yml', 'r') as f: - logger_conf = yaml.safe_load(f) - -logging.config.dictConfig(logger_conf) -LOGGER = logging.getLogger('BaseMultiTableSynthesizer') +SYNTHESIZER_LOGGER = get_logger('MultiTableSynthesizer') class BaseMultiTableSynthesizer: @@ -66,13 +62,14 @@ def _set_temp_numpy_seed(self): np.random.set_state(initial_state) def _initialize_models(self): - for table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = self._table_parameters.get(table_name, {}) - self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, - locales=self.locales, - **synthesizer_parameters - ) + with disable_single_table_logger(): + for table_name, table_metadata in self.metadata.tables.items(): + synthesizer_parameters = self._table_parameters.get(table_name, {}) + self._table_synthesizers[table_name] = self._synthesizer( + metadata=table_metadata, + locales=self.locales, + **synthesizer_parameters + ) def _get_pbar_args(self, **kwargs): """Return a dictionary with the updated keyword args for a progress bar.""" @@ -123,7 +120,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nInstance:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -396,7 +393,7 @@ def fit_processed_data(self, processed_data): total_rows += len(table) total_columns += len(table.columns) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nFit processed data\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -413,8 +410,10 @@ def fit_processed_data(self, processed_data): self._synthesizer_id, ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - augmented_data = self._augment_tables(processed_data) - self._model_tables(augmented_data) + with 
disable_single_table_logger(): + augmented_data = self._augment_tables(processed_data) + self._model_tables(augmented_data) + self._fitted = True self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d') self._fitted_sdv_version = getattr(version, 'public', None) @@ -434,7 +433,7 @@ def fit(self, data): total_rows += len(table) total_columns += len(table.columns) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nFit\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -482,7 +481,7 @@ def sample(self, scale=1.0): raise SynthesizerInputError( f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.") - with self._set_temp_numpy_seed(): + with self._set_temp_numpy_seed(), disable_single_table_logger(): sampled_data = self._sample(scale=scale) total_rows = 0 @@ -491,7 +490,7 @@ def sample(self, scale=1.0): total_rows += len(table) total_columns += len(table.columns) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -674,7 +673,7 @@ def save(self, filepath): with open(filepath, 'wb') as output: cloudpickle.dump(self, output) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nSave:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -704,7 +703,7 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nLoad\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' diff --git a/sdv/sdv_logger.yml b/sdv/sdv_logger.yml deleted file mode 100644 index 20be0b1df..000000000 --- a/sdv/sdv_logger.yml +++ /dev/null @@ -1,15 +0,0 @@ -log_registry: 'local' -version: 1 -handlers: - file: - class: logging.FileHandler - filename: sdv_logs.log -loggers: - BaseSingleTableSynthesizer: - level: INFO - handlers: [file] - propagate: no - BaseMultiTableSynthesizer: - level: INFO - handlers: [file] - propagate: no diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 2d98fd62a..d1576654c 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -4,7 +4,6 @@ import functools import inspect import logging -import logging.config import math import operator import os @@ -18,7 +17,6 @@ import numpy as np import pandas as pd import tqdm -import yaml from copulas.multivariate import GaussianMultivariate from sdv import version @@ -27,14 +25,12 @@ from sdv.constraints.errors import AggregateConstraintsError from sdv.data_processing.data_processor import DataProcessor from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError +from sdv.logging.utils import get_logger from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path -logger_config_file = Path(__file__).parent.parent -with open(logger_config_file / 'sdv_logger.yml', 'r') as f: - logger_conf = yaml.safe_load(f) -logging.config.dictConfig(logger_conf) -LOGGER = logging.getLogger('BaseSingleTableSynthesizer') +LOGGER = logging.getLogger(__name__) +SYNTHESIZER_LOGGER = get_logger('SingleTableSynthesizer') COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 @@ -116,7 +112,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nInstance:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -407,7 +403,7 @@ def fit_processed_data(self, processed_data): 
processed_data (pandas.DataFrame): The transformed data used to fit the model to. """ - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nFit processed data\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -438,7 +434,7 @@ def fit(self, data): data (pandas.DataFrame): The raw data (before any transformations) to fit the model to. """ - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nFit\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -468,7 +464,7 @@ def save(self, filepath): filepath (str): Path where the synthesizer instance will be serialized. """ - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nSave:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -500,7 +496,7 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nLoad\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -884,7 +880,7 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file output_file_path, show_progress_bar=show_progress_bar ) - LOGGER.info( + SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' @@ -900,6 +896,8 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file self._synthesizer_id, ) + return sampled_data + def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size, progress_bar=None, output_file_path=None): """Sample rows with conditions. diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 99330715e..1e717cb68 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -829,14 +829,16 @@ def test_fit_processed_data(self): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None ) - data = Mock() - data.copy.return_value = data + processed_data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } # Run - BaseMultiTableSynthesizer.fit_processed_data(instance, data) + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert - instance._augment_tables.assert_called_once_with(data) + instance._augment_tables.assert_called_once_with(processed_data) instance._model_tables.assert_called_once_with(instance._augment_tables.return_value) assert instance._fitted @@ -847,10 +849,13 @@ def test_fit_processed_data_empty_table(self): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None ) - data = pd.DataFrame() + processed_data = { + 'table1': pd.DataFrame(), + 'table2': pd.DataFrame() + } # Run - BaseMultiTableSynthesizer.fit_processed_data(instance, data) + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert instance._fit.assert_not_called() @@ -866,7 +871,10 @@ def test_fit_processed_data_raises_version_error(self): _fitted_sdv_enterprise_version=None ) instance.metadata = Mock() - data = Mock() + processed_data = { + 'table1': pd.DataFrame(), + 'table2': pd.DataFrame() + } # Run and Assert error_msg = ( @@ -875,7 +883,7 @@ def test_fit_processed_data_raises_version_error(self): 'Please create a new synthesizer.' 
) with pytest.raises(VersionError, match=error_msg): - BaseMultiTableSynthesizer.fit_processed_data(instance, data) + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert instance.preprocess.assert_not_called() @@ -891,7 +899,10 @@ def test_fit(self, mock_validate_foreign_keys_not_null): _fitted_sdv_enterprise_version=None ) instance.metadata = Mock() - data = Mock() + data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } # Run BaseMultiTableSynthesizer.fit(instance, data) @@ -910,7 +921,10 @@ def test_fit_raises_version_error(self): _fitted_sdv_enterprise_version=None ) instance.metadata = Mock() - data = Mock() + data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } # Run and Assert error_msg = ( @@ -991,7 +1005,11 @@ def test_sample(self): # Setup metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) - instance._sample = Mock() + data = { + 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), + 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}) + } + instance._sample = Mock(return_value=data) # Run instance.sample(scale=1.5) diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 3dab339ba..c40e7b080 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -674,7 +674,7 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): ] }) synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) # Run estimation estimated_num_columns = synthesizer._estimate_num_columns(metadata) @@ -823,7 +823,7 @@ def test__estimate_num_columns_to_be_modeled_different_distributions(self): table_name='child_uniform', table_parameters={'default_distribution': 'uniform'} ) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) distributions = synthesizer._get_distributions() # Run estimation @@ -953,7 +953,7 @@ def test__estimate_num_columns_to_be_modeled(self): ] }) synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) # Run estimation estimated_num_columns = synthesizer._estimate_num_columns(metadata) @@ -1068,7 +1068,7 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(self): ] }) synthesizer = HMASynthesizer(metadata) - synthesizer._finalize = Mock() + synthesizer._finalize = Mock(return_value=data) # Run estimation estimated_num_columns = synthesizer._estimate_num_columns(metadata) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 509cee456..b4c106b80 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -348,8 +348,7 @@ def test_fit_processed_data(self): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, ) - processed_data = Mock() - processed_data.empty = False + processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) # Run BaseSingleTableSynthesizer.fit_processed_data(instance, processed_data) @@ -368,7 +367,7 @@ def test_fit_processed_data_raises_version_error(self): _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, ) - processed_data = Mock() + processed_data = pd.DataFrame({'column_a': 
[1, 2, 3]}) instance._random_state_set = True instance._fitted = True @@ -392,17 +391,17 @@ def test_fit(self): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, ) - processed_data = Mock() + data = pd.DataFrame({'column_a': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']}) instance._random_state_set = True instance._fitted = True # Run - BaseSingleTableSynthesizer.fit(instance, processed_data) + BaseSingleTableSynthesizer.fit(instance, data) # Assert assert instance._random_state_set is False instance._data_processor.reset_sampling.assert_called_once_with() - instance._preprocess.assert_called_once_with(processed_data) + instance._preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance._preprocess.return_value) instance._check_metadata_updated.assert_called_once() @@ -417,7 +416,7 @@ def test_fit_raises_version_error(self): _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, ) - processed_data = Mock() + data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True instance._fitted = True @@ -428,7 +427,7 @@ def test_fit_raises_version_error(self): 'create a new synthesizer.' ) with pytest.raises(VersionError, match=error_msg): - BaseSingleTableSynthesizer.fit(instance, processed_data) + BaseSingleTableSynthesizer.fit(instance, data) def test__validate_constraints(self): """Test that ``_validate_constraints`` calls ``fit`` and returns any errors.""" @@ -1371,6 +1370,7 @@ def test_sample(self): output_file_path = 'temp.csv' instance = Mock() instance.get_metadata.return_value._constraints = False + instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]}) # Run result = BaseSingleTableSynthesizer.sample( @@ -1389,7 +1389,7 @@ def test_sample(self): 'temp.csv', show_progress_bar=True ) - assert result == instance._sample_with_progress_bar.return_value + pd.testing.assert_frame_equal(result, pd.DataFrame({'col': [1, 2, 3]})) def test__validate_conditions_unseen_columns(self): """Test that conditions are within the ``data_processor`` fields.""" From 0a02de63970a6a804e467c96efcf2de5c0c5eda5 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Thu, 18 Apr 2024 18:31:13 +0200 Subject: [PATCH 04/19] Finalize rebase --- tests/unit/test__utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/unit/test__utils.py b/tests/unit/test__utils.py index d56d69144..d340a991c 100644 --- a/tests/unit/test__utils.py +++ b/tests/unit/test__utils.py @@ -659,8 +659,4 @@ def test_generate_synthesizer_id(mock_version, mock_uuid): result = generate_synthesizer_id(synthesizer) # Assert -<<<<<<< HEAD assert result == 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' -======= - assert result == 'BaseSingleTableSynthesizer_1.0.0_990d1231a5f5' ->>>>>>> 36f3cb15 (Generate unique id for synthesizers) From dd9cc3d6134bb5f32bb2d4979d0afd5e0cc98012 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Thu, 18 Apr 2024 18:56:28 +0200 Subject: [PATCH 05/19] Fix lint and bump isort --- sdv/_utils.py | 2 -- sdv/logging/__init__.py | 9 +++++++++ sdv/logging/utils.py | 2 ++ sdv/multi_table/base.py | 5 +---- sdv/single_table/base.py | 2 -- tests/integration/datasets/test_local.py | 2 +- tests/integration/single_table/test_constraints.py | 4 ++-- tests/integration/utils/test_poc.py | 4 ++-- tests/unit/test___init__.py | 2 +- 9 files changed, 18 insertions(+), 14 deletions(-) diff --git a/sdv/_utils.py b/sdv/_utils.py index df9bcc2f4..577600b8d 100644 --- a/sdv/_utils.py +++ 
b/sdv/_utils.py @@ -1,5 +1,4 @@ """Miscellaneous utility functions.""" -import contextlib import operator import uuid import warnings @@ -9,7 +8,6 @@ from pathlib import Path import pandas as pd -import yaml from pandas.core.tools.datetimes import _guess_datetime_format_for_array from sdv import version diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py index e69de29bb..6b414f9ca 100644 --- a/sdv/logging/__init__.py +++ b/sdv/logging/__init__.py @@ -0,0 +1,9 @@ +"""Module for configuring loggers within the SDV library.""" + +from sdv.logging.utils import disable_single_table_logger, get_logger, get_logger_config + +__all__ = ( + 'disable_single_table_logger', + 'get_logger', + 'get_logger_config', +) diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index 756ea31ff..a28c653a2 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -1,3 +1,5 @@ +"""Utilities for configuring logging within the SDV library.""" + import contextlib import logging import logging.config diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index da34cb82a..e0e03b0b9 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -2,16 +2,13 @@ import contextlib import datetime import inspect -import logging import operator import warnings from collections import defaultdict from copy import deepcopy -from pathlib import Path import cloudpickle import numpy as np -import yaml from tqdm import tqdm from sdv import version @@ -19,7 +16,7 @@ _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError -from sdv.logging.utils import get_logger, disable_single_table_logger +from sdv.logging import disable_single_table_logger, get_logger from sdv.single_table.copulas import GaussianCopulaSynthesizer SYNTHESIZER_LOGGER = get_logger('MultiTableSynthesizer') diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index d1576654c..a70e119e1 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -10,7 +10,6 @@ import uuid import warnings from collections import defaultdict -from pathlib import Path import cloudpickle import copulas @@ -28,7 +27,6 @@ from sdv.logging.utils import get_logger from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path - LOGGER = logging.getLogger(__name__) SYNTHESIZER_LOGGER = get_logger('SingleTableSynthesizer') diff --git a/tests/integration/datasets/test_local.py b/tests/integration/datasets/test_local.py index 135375372..ee1438ff5 100644 --- a/tests/integration/datasets/test_local.py +++ b/tests/integration/datasets/test_local.py @@ -4,7 +4,7 @@ from sdv.datasets.local import save_csvs -@pytest.fixture +@pytest.fixture() def data(): parent = pd.DataFrame(data={ 'id': [0, 1, 2, 3, 4], diff --git a/tests/integration/single_table/test_constraints.py b/tests/integration/single_table/test_constraints.py index 151fb8d83..9061c67cd 100644 --- a/tests/integration/single_table/test_constraints.py +++ b/tests/integration/single_table/test_constraints.py @@ -30,12 +30,12 @@ def _isinstance_side_effect(*args, **kwargs): ) -@pytest.fixture +@pytest.fixture() def demo_data(): return DEMO_DATA -@pytest.fixture +@pytest.fixture() def demo_metadata(): return DEMO_METADATA diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index e8189b6ea..ac4d4e50d 100644 --- a/tests/integration/utils/test_poc.py +++ 
b/tests/integration/utils/test_poc.py @@ -13,7 +13,7 @@ from sdv.utils.poc import drop_unknown_references, simplify_schema -@pytest.fixture +@pytest.fixture() def metadata(): return MultiTableMetadata.load_from_dict( { @@ -45,7 +45,7 @@ def metadata(): ) -@pytest.fixture +@pytest.fixture() def data(): parent = pd.DataFrame(data={ 'id': [0, 1, 2, 3, 4], diff --git a/tests/unit/test___init__.py b/tests/unit/test___init__.py index e94f3b214..b00144c23 100644 --- a/tests/unit/test___init__.py +++ b/tests/unit/test___init__.py @@ -8,7 +8,7 @@ from sdv import _find_addons -@pytest.fixture +@pytest.fixture() def mock_sdv(): sdv_module = sys.modules['sdv'] sdv_mock = Mock() From 505e5f94cbe1d33efa3935ecf615e687dbbc77da Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Thu, 18 Apr 2024 19:14:22 +0200 Subject: [PATCH 06/19] Fix integration tests --- sdv/logging/sdv_logger_config.yml | 12 ++++++++++++ sdv/metadata/multi_table.py | 6 ++++-- sdv/metadata/single_table.py | 4 +++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sdv/logging/sdv_logger_config.yml b/sdv/logging/sdv_logger_config.yml index c983ee36e..64104495f 100644 --- a/sdv/logging/sdv_logger_config.yml +++ b/sdv/logging/sdv_logger_config.yml @@ -13,3 +13,15 @@ loggers: handlers: class: logging.FileHandler filename: sdv_logs.log + MultiTableMetadata: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log + SingleTableMetadata: + level: INFO + propagate: false + handlers: + class: logging.FileHandler + filename: sdv_logs.log diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index d5ceb49d3..f04d0617c 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -12,6 +12,7 @@ from sdv._utils import _cast_to_iterable, _load_data_from_csv from sdv.errors import InvalidDataError +from sdv.logging import get_logger from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.single_table import SingleTableMetadata @@ -20,6 +21,7 @@ create_columns_node, create_summarized_columns_node, visualize_graph) LOGGER = logging.getLogger(__name__) +MULTITABLEMETADATA_LOGGER = get_logger('MultiTableMetadata') WARNINGS_COLUMN_ORDER = ['Table Name', 'Column Name', 'sdtype', 'datetime_format'] @@ -1042,10 +1044,10 @@ def save_to_json(self, filepath): validate_file_does_not_exist(filepath) metadata = self.to_dict() total_columns = 0 - for table in self.tables: + for table in self.tables.values(): total_columns += len(table.columns) - LOGGER.info( + MULTITABLEMETADATA_LOGGER.info( '\nMetadata Save:\n' ' Timestamp: %s\n' ' Statistics about the metadata:\n' diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 2c7952b9d..ca96c4c7d 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -16,6 +16,7 @@ _cast_to_iterable, _format_invalid_values_string, _get_datetime_format, _is_boolean_type, _is_datetime_type, _is_numerical_type, _load_data_from_csv, _validate_datetime_format) from sdv.errors import InvalidDataError +from sdv.logging import get_logger from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.utils import read_json, validate_file_does_not_exist @@ -23,6 +24,7 @@ create_columns_node, create_summarized_columns_node, visualize_graph) LOGGER = logging.getLogger(__name__) +SINGLETABLEMETADATA_LOGGER = get_logger('SingleTableMetadata') class SingleTableMetadata: 
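This hunk keeps two loggers side by side: the conventional module-level logger for developer diagnostics and a named logger for the user-facing audit trail, so each can carry its own handlers and level. A small illustration of why the split works (names assumed):

import logging

module_logger = logging.getLogger('sdv.metadata.single_table')  # diagnostics
audit_logger = logging.getLogger('SingleTableMetadata')         # audit trail

# Distinct names yield distinct logger objects with independent configuration.
assert module_logger is not audit_logger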
@@ -1206,7 +1208,7 @@ def save_to_json(self, filepath): validate_file_does_not_exist(filepath) metadata = self.to_dict() metadata['METADATA_SPEC_VERSION'] = self.METADATA_SPEC_VERSION - LOGGER.info( + SINGLETABLEMETADATA_LOGGER.info( '\nMetadata Save:\n' ' Timestamp: %s\n' ' Statistics about the metadata:\n' From 2a92c509beb42ad8d455625f2251ebba9716824f Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 19 Apr 2024 12:58:22 +0200 Subject: [PATCH 07/19] Fix tests --- sdv/logging/__init__.py | 6 +++--- sdv/logging/utils.py | 6 +++--- sdv/metadata/multi_table.py | 4 ++-- sdv/metadata/single_table.py | 4 ++-- sdv/multi_table/base.py | 12 ++++++------ sdv/single_table/base.py | 7 ++++--- 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py index 6b414f9ca..436a1a442 100644 --- a/sdv/logging/__init__.py +++ b/sdv/logging/__init__.py @@ -1,9 +1,9 @@ """Module for configuring loggers within the SDV library.""" -from sdv.logging.utils import disable_single_table_logger, get_logger, get_logger_config +from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config __all__ = ( 'disable_single_table_logger', - 'get_logger', - 'get_logger_config', + 'get_sdv_logger', + 'get_sdv_logger_config', ) diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index a28c653a2..09ad89bec 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -9,7 +9,7 @@ import yaml -def get_logger_config(): +def get_sdv_logger_config(): """Return a dictionary with the logging configuration.""" logging_path = Path(__file__).parent with open(logging_path / 'sdv_logger_config.yml', 'r') as f: @@ -47,7 +47,7 @@ def disable_single_table_logger(): @lru_cache() -def get_logger(logger_name): +def get_sdv_logger(logger_name): """Get a logger instance with the specified name and configuration. This function retrieves or creates a logger instance with the specified name @@ -63,7 +63,7 @@ def get_logger(logger_name): A logger instance configured according to the logging configuration and the specific settings for the given logger name. 
""" - logger_conf = get_logger_config() + logger_conf = get_sdv_logger_config() logger = logging.getLogger(logger_name) if logger_name in logger_conf.get('loggers'): formatter = None diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index f04d0617c..7d561a12b 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -12,7 +12,7 @@ from sdv._utils import _cast_to_iterable, _load_data_from_csv from sdv.errors import InvalidDataError -from sdv.logging import get_logger +from sdv.logging import get_sdv_logger from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.single_table import SingleTableMetadata @@ -21,7 +21,7 @@ create_columns_node, create_summarized_columns_node, visualize_graph) LOGGER = logging.getLogger(__name__) -MULTITABLEMETADATA_LOGGER = get_logger('MultiTableMetadata') +MULTITABLEMETADATA_LOGGER = get_sdv_logger('MultiTableMetadata') WARNINGS_COLUMN_ORDER = ['Table Name', 'Column Name', 'sdtype', 'datetime_format'] diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index ca96c4c7d..4714f8887 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -16,7 +16,7 @@ _cast_to_iterable, _format_invalid_values_string, _get_datetime_format, _is_boolean_type, _is_datetime_type, _is_numerical_type, _load_data_from_csv, _validate_datetime_format) from sdv.errors import InvalidDataError -from sdv.logging import get_logger +from sdv.logging import get_sdv_logger from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata_upgrader import convert_metadata from sdv.metadata.utils import read_json, validate_file_does_not_exist @@ -24,7 +24,7 @@ create_columns_node, create_summarized_columns_node, visualize_graph) LOGGER = logging.getLogger(__name__) -SINGLETABLEMETADATA_LOGGER = get_logger('SingleTableMetadata') +SINGLETABLEMETADATA_LOGGER = get_sdv_logger('SingleTableMetadata') class SingleTableMetadata: diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index e0e03b0b9..d5c40a874 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -16,10 +16,10 @@ _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError -from sdv.logging import disable_single_table_logger, get_logger +from sdv.logging import disable_single_table_logger, get_sdv_logger from sdv.single_table.copulas import GaussianCopulaSynthesizer -SYNTHESIZER_LOGGER = get_logger('MultiTableSynthesizer') +SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer') class BaseMultiTableSynthesizer: @@ -667,9 +667,7 @@ def save(self, filepath): filepath (str): Path where the instance will be serialized. 
""" - with open(filepath, 'wb') as output: - cloudpickle.dump(self, output) - + synthesizer_id = getattr(self, '_synthesizer_id', None) SYNTHESIZER_LOGGER.info( '\nSave:\n' ' Timestamp: %s\n' @@ -677,8 +675,10 @@ def save(self, filepath): ' Synthesizer id: %s', datetime.datetime.now(), self.__class__.__name__, - self._synthesizer_id, + synthesizer_id ) + with open(filepath, 'wb') as output: + cloudpickle.dump(self, output) @classmethod def load(cls, filepath): diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index a70e119e1..909089f75 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -24,11 +24,11 @@ from sdv.constraints.errors import AggregateConstraintsError from sdv.data_processing.data_processor import DataProcessor from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError -from sdv.logging.utils import get_logger +from sdv.logging.utils import get_sdv_logger from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path LOGGER = logging.getLogger(__name__) -SYNTHESIZER_LOGGER = get_logger('SingleTableSynthesizer') +SYNTHESIZER_LOGGER = get_sdv_logger('SingleTableSynthesizer') COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 @@ -462,6 +462,7 @@ def save(self, filepath): filepath (str): Path where the synthesizer instance will be serialized. """ + synthesizer_id = getattr(self, '_synthesizer_id', None) SYNTHESIZER_LOGGER.info( '\nSave:\n' ' Timestamp: %s\n' @@ -469,7 +470,7 @@ def save(self, filepath): ' Synthesizer id: %s', datetime.datetime.now(), self.__class__.__name__, - self._synthesizer_id, + synthesizer_id ) with open(filepath, 'wb') as output: cloudpickle.dump(self, output) From b142cf66d84adea8e9daac7b27876820b8543a7c Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 19 Apr 2024 19:10:23 +0200 Subject: [PATCH 08/19] Add unit tests for the logging --- sdv/logging/utils.py | 6 +- sdv/metadata/multi_table.py | 5 +- sdv/metadata/single_table.py | 2 +- sdv/multi_table/base.py | 22 ++--- sdv/single_table/base.py | 20 ++-- tests/unit/logging/__init__.py | 0 tests/unit/logging/test_utils.py | 82 ++++++++++++++++ tests/unit/metadata/test_multi_table.py | 18 +++- tests/unit/metadata/test_single_table.py | 20 +++- tests/unit/multi_table/test_base.py | 109 ++++++++++++++++++---- tests/unit/single_table/test_base.py | 114 +++++++++++++++++++---- tests/utils.py | 17 ++++ 12 files changed, 343 insertions(+), 72 deletions(-) create mode 100644 tests/unit/logging/__init__.py create mode 100644 tests/unit/logging/test_utils.py diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index 09ad89bec..b9669054d 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -36,9 +36,7 @@ def disable_single_table_logger(): # Logging without ``SingleTableSynthesizer`` single_table_logger = logging.getLogger('SingleTableSynthesizer') handlers = single_table_logger.handlers - for handler in handlers: - single_table_logger.removeHandler(handler) - + single_table_logger.handlers = [] try: yield finally: @@ -84,7 +82,7 @@ def get_sdv_logger(logger_name): file_handler.setLevel(log_level) file_handler.setFormatter(formatter) logger.addHandler(file_handler) - elif handler in ('consoleHandler', 'StreamingHandler'): + elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'): ch = logging.StreamHandler() ch.setLevel(log_level) ch.setFormatter(formatter) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 7d561a12b..1eff8efa8 100644 --- 
a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -1051,10 +1051,11 @@ def save_to_json(self, filepath): '\nMetadata Save:\n' ' Timestamp: %s\n' ' Statistics about the metadata:\n' - ' Total number of tables: %s', - ' Total number of columns: %s' + ' Total number of tables: %s\n' + ' Total number of columns: %s\n' ' Total number of relationships: %s', datetime.datetime.now(), + len(self.tables), total_columns, len(self.relationships) ) diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 4714f8887..4f8b1db94 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -1212,7 +1212,7 @@ def save_to_json(self, filepath): '\nMetadata Save:\n' ' Timestamp: %s\n' ' Statistics about the metadata:\n' - ' Total number of tables: 1', + ' Total number of tables: 1' ' Total number of columns: %s' ' Total number of relationships: 0', datetime.now(), diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index d5c40a874..00efe700e 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -391,13 +391,13 @@ def fit_processed_data(self, processed_data): total_columns += len(table.columns) SYNTHESIZER_LOGGER.info( - '\nFit processed data\n' + '\nFit processed data:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' - ' Statistics of the fit data:\n' + ' Statistics of the fit processed data:\n' ' Total number of tables: %s\n' - ' Table number of rows: %s\n' - ' Table number of columns: %s\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' ' Synthesizer id: %s', datetime.datetime.now(), self.__class__.__name__, @@ -431,13 +431,13 @@ def fit(self, data): total_columns += len(table.columns) SYNTHESIZER_LOGGER.info( - '\nFit\n' + '\nFit:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' ' Statistics of the fit data:\n' ' Total number of tables: %s\n' - ' Table number of rows: %s\n' - ' Table number of columns: %s\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' ' Synthesizer id: %s', datetime.datetime.now(), self.__class__.__name__, @@ -491,10 +491,10 @@ def sample(self, scale=1.0): '\nSample:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' - ' Statistics of the fit data:\n' + ' Statistics of the sample size:\n' ' Total number of tables: %s\n' - ' Table number of rows: %s\n' - ' Table number of columns: %s\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' ' Synthesizer id: %s', datetime.datetime.now(), self.__class__.__name__, @@ -701,7 +701,7 @@ def load(cls, filepath): synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) SYNTHESIZER_LOGGER.info( - '\nLoad\n' + '\nLoad:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' ' Synthesizer id: %s', diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 909089f75..5b808eb8c 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -402,18 +402,18 @@ def fit_processed_data(self, processed_data): The transformed data used to fit the model to. 
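The revised disable_single_table_logger rebinds logger.handlers to a fresh empty list rather than mutating it, so the saved reference still points at the original handlers for the finally block to restore. A generic sketch of the same context-manager shape, with an illustrative name:

import contextlib
import logging


@contextlib.contextmanager
def mute_logger(name):
    logger = logging.getLogger(name)
    saved = logger.handlers   # keeps the original list object alive
    logger.handlers = []      # rebind, do not clear in place
    try:
        yield
    finally:
        for handler in saved:
            logger.addHandler(handler)

Used the way _initialize_models and fit_processed_data use disable_single_table_logger(), this keeps the nested single-table synthesizers quiet so the multi-table entry is the only audit record written.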
""" SYNTHESIZER_LOGGER.info( - '\nFit processed data\n' + '\nFit processed data:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' - ' Statistics of the fit data:\n' + ' Statistics of the fit processed data:\n' ' Total number of tables: 1\n' - ' Table number of rows: %s\n' - ' Table number of columns: %s\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' ' Synthesizer id: %s', datetime.datetime.now(), self.__class__.__name__, - len(processed_data.columns), len(processed_data), + len(processed_data.columns), self._synthesizer_id, ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) @@ -433,18 +433,18 @@ def fit(self, data): The raw data (before any transformations) to fit the model to. """ SYNTHESIZER_LOGGER.info( - '\nFit\n' + '\nFit:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' ' Statistics of the fit data:\n' ' Total number of tables: 1\n' - ' Table number of rows: %s\n' - ' Table number of columns: %s\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' ' Synthesizer id: %s', datetime.datetime.now(), self.__class__.__name__, - len(data.columns), len(data), + len(data.columns), self._synthesizer_id, ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) @@ -496,7 +496,7 @@ def load(cls, filepath): synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) SYNTHESIZER_LOGGER.info( - '\nLoad\n' + '\nLoad:\n' ' Timestamp: %s\n' ' Synthesizer class name: %s\n' ' Synthesizer id: %s', diff --git a/tests/unit/logging/__init__.py b/tests/unit/logging/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/logging/test_utils.py b/tests/unit/logging/test_utils.py new file mode 100644 index 000000000..1a7539cdb --- /dev/null +++ b/tests/unit/logging/test_utils.py @@ -0,0 +1,82 @@ +"""Test ``SDV`` logging utilities.""" +import logging +from unittest.mock import Mock, mock_open, patch + +from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config + + +def test_get_sdv_logger_config(): + """Test the ``get_sdv_logger_config``. + + Test that a ``yaml_content`` is being converted to ``dictionary`` and is returned + by the ``get_sdv_logger_config``. 
+ """ + yaml_content = """ + loggers: + test_logger: + level: DEBUG + handlers: + class: logging.StreamHandler + """ + # Run + with patch('builtins.open', mock_open(read_data=yaml_content)): + # Test if the function returns a dictionary + logger_conf = get_sdv_logger_config() + + # Assert + assert isinstance(logger_conf, dict) + assert logger_conf == { + 'loggers': { + 'test_logger': { + 'level': 'DEBUG', + 'handlers': { + 'class': 'logging.StreamHandler' + } + } + } + } + + +@patch('sdv.logging.utils.logging.getLogger') +def test_disable_single_table_logger(mock_getlogger): + # Setup + mock_logger = Mock() + handler = Mock() + mock_logger.handlers = [handler] + mock_logger.removeHandler.side_effect = lambda x: mock_logger.handlers.pop(0) + mock_logger.addHandler.side_effect = lambda x: mock_logger.handlers.append(x) + mock_getlogger.return_value = mock_logger + + # Run + with disable_single_table_logger(): + assert len(mock_logger.handlers) == 0 + + # Assert + assert len(mock_logger.handlers) == 1 + + +@patch('sdv.logging.utils.logging.StreamHandler') +@patch('sdv.logging.utils.logging.getLogger') +@patch('sdv.logging.utils.get_sdv_logger_config') +def test_get_sdv_logger(mock_get_sdv_logger_config, mock_getlogger, mock_streamhandler): + # Setup + mock_logger_conf = { + 'loggers': { + 'test_logger': { + 'level': 'DEBUG', + 'handlers': { + 'class': 'logging.StreamHandler' + } + } + } + } + mock_get_sdv_logger_config.return_value = mock_logger_conf + mock_logger_instance = Mock() + mock_getlogger.return_value = mock_logger_instance + + # Run + get_sdv_logger('test_logger') + + # Assert + mock_logger_instance.setLevel.assert_called_once_with(logging.DEBUG) + mock_logger_instance.addHandler.assert_called_once() diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index b6089bdf0..c672d17dd 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1,6 +1,7 @@ """Test Multi Table Metadata.""" import json +import logging import re from collections import defaultdict from unittest.mock import Mock, call, patch @@ -12,7 +13,7 @@ from sdv.errors import InvalidDataError from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.multi_table import MultiTableMetadata, SingleTableMetadata -from tests.utils import get_multi_table_data, get_multi_table_metadata +from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata class TestMultiTableMetadata: @@ -2827,7 +2828,8 @@ def test_save_to_json_file_exists(self, mock_path): with pytest.raises(ValueError, match=error_msg): instance.save_to_json('filepath.json') - def test_save_to_json(self, tmp_path): + @patch('sdv.metadata.multi_table.datetime') + def test_save_to_json(self, mock_datetime, tmp_path, caplog): """Test the ``save_to_json`` method. 
Test that ``save_to_json`` stores a ``json`` file and dumps the instance dict into @@ -2844,16 +2846,26 @@ def test_save_to_json(self, tmp_path): # Setup instance = MultiTableMetadata() instance._reset_updated_flag = Mock() + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' # Run / Assert file_name = tmp_path / 'multitable.json' - instance.save_to_json(file_name) + with catch_sdv_logs(caplog, logging.INFO, logger='MultiTableMetadata'): + instance.save_to_json(file_name) with open(file_name, 'rb') as multi_table_file: saved_metadata = json.load(multi_table_file) assert saved_metadata == instance.to_dict() instance._reset_updated_flag.assert_called_once() + assert caplog.messages[0] == ( + '\nMetadata Save:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Statistics about the metadata:\n' + ' Total number of tables: 0\n' + ' Total number of columns: 0\n' + ' Total number of relationships: 0' + ) def test__convert_relationships(self): """Test the ``_convert_relationships`` method. diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index 1b41414b0..d51b08ecc 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -1,6 +1,7 @@ """Test Single Table Metadata.""" import json +import logging import re import warnings from datetime import datetime @@ -13,6 +14,7 @@ from sdv.errors import InvalidDataError from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.single_table import SingleTableMetadata +from tests.utils import catch_sdv_logs class TestSingleTableMetadata: @@ -2835,7 +2837,8 @@ def test_save_to_json_file_exists(self, mock_path): with pytest.raises(ValueError, match=error_msg): instance.save_to_json('filepath.json') - def test_save_to_json(self, tmp_path): + @patch('sdv.metadata.single_table.datetime') + def test_save_to_json(self, mock_datetime, tmp_path, caplog): """Test the ``save_to_json`` method. Test that ``save_to_json`` stores a ``json`` file and dumps the instance dict into @@ -2850,12 +2853,23 @@ def test_save_to_json(self, tmp_path): - Creates a json representation of the instance. 
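
        Note: ``get_sdv_logger`` sets ``propagate`` to ``False`` by default, so
        ``caplog`` does not receive the SDV log records on its own. The test
        below is assumed to rely on the ``catch_sdv_logs`` helper from
        ``tests/utils.py`` to attach ``caplog.handler`` to the named logger
        for the duration of the call, roughly:

            with catch_sdv_logs(caplog, logging.INFO, logger='SingleTableMetadata'):
                instance.save_to_json(file_name)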
""" # Setup + mock_datetime.now.return_value = '2024-04-19 16:20:10.037183' instance = SingleTableMetadata() - # Run / Assert + # Run file_name = tmp_path / 'singletable.json' - instance.save_to_json(file_name) + with catch_sdv_logs(caplog, logging.INFO, logger='SingleTableMetadata'): + instance.save_to_json(file_name) + # Assert + assert caplog.messages[0] == ( + '\nMetadata Save:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Statistics about the metadata:\n' + ' Total number of tables: 1' + ' Total number of columns: 0' + ' Total number of relationships: 0' + ) with open(file_name, 'rb') as single_table_file: saved_metadata = json.load(single_table_file) assert saved_metadata == instance.to_dict() diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 1e717cb68..50a5a2441 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -1,3 +1,4 @@ +import logging import re import warnings from collections import defaultdict @@ -17,7 +18,7 @@ from sdv.multi_table.hma import HMASynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer from sdv.single_table.ctgan import CTGANSynthesizer -from tests.utils import get_multi_table_data, get_multi_table_metadata +from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata class TestBaseMultiTableSynthesizer: @@ -100,22 +101,26 @@ def test__print(self, mock_print): # Assert mock_print.assert_called_once_with('Fitting', end='') + @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base.generate_synthesizer_id') @patch('sdv.multi_table.base.BaseMultiTableSynthesizer._check_metadata_updated') - def test___init__(self, mock_check_metadata_updated, mock_generate_synthesizer_id): + def test___init__(self, mock_check_metadata_updated, mock_generate_synthesizer_id, + mock_datetime, caplog): """Test that when creating a new instance this sets the defaults. Test that the metadata object is being stored and also being validated. Afterwards, this calls the ``self._initialize_models`` which creates the initial instances of those. """ # Setup - synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + synthesizer_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' metadata = get_multi_table_metadata() metadata.validate = Mock() # Run - instance = BaseMultiTableSynthesizer(metadata) + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + instance = BaseMultiTableSynthesizer(metadata) # Assert assert instance.metadata == metadata @@ -127,6 +132,11 @@ def test___init__(self, mock_check_metadata_updated, mock_generate_synthesizer_i mock_check_metadata_updated.assert_called_once() mock_generate_synthesizer_id.assert_called_once_with(instance) assert instance._synthesizer_id == synthesizer_id + assert caplog.messages[0] == ( + '\nInstance:\n Timestamp: 2024-04-19 16:20:10.037183\n Synthesizer class name: ' + 'BaseMultiTableSynthesizer\n Synthesizer id: ' + 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test__init__column_relationship_warning(self): """Test that a warning is raised only once when the metadata has column relationships.""" @@ -818,16 +828,19 @@ def test_preprocess_warning(self, mock_warnings): "please refit the model using 'fit' or 'fit_processed_data'." 
) - def test_fit_processed_data(self): + @patch('sdv.multi_table.base.datetime') + def test_fit_processed_data(self, mock_datetime, caplog): """Test that fit processed data calls ``_augment_tables`` and ``_model_tables``. Ensure that the ``fit_processed_data`` augments the tables and then models those using the ``_model_tables`` method. Then sets the state to fitted. """ # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' instance = Mock( _fitted_sdv_version=None, - _fitted_sdv_enterprise_version=None + _fitted_sdv_enterprise_version=None, + _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' ) processed_data = { 'table1': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), @@ -835,12 +848,23 @@ def test_fit_processed_data(self): } # Run - BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + BaseMultiTableSynthesizer.fit_processed_data(instance, processed_data) # Assert instance._augment_tables.assert_called_once_with(processed_data) instance._model_tables.assert_called_once_with(instance._augment_tables.return_value) assert instance._fitted + assert caplog.messages[0] == ( + '\nFit processed data:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 6\n' + ' Total number of columns: 4\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_fit_processed_data_empty_table(self): """Test attributes are properly set when data is empty and that _fit is not called.""" @@ -890,13 +914,16 @@ def test_fit_processed_data_raises_version_error(self): instance.fit_processed_data.assert_not_called() instance._check_metadata_updated.assert_not_called() + @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base._validate_foreign_keys_not_null') - def test_fit(self, mock_validate_foreign_keys_not_null): + def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): """Test that it calls the appropriate methods.""" # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' instance = Mock( _fitted_sdv_version=None, - _fitted_sdv_enterprise_version=None + _fitted_sdv_enterprise_version=None, + _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' ) instance.metadata = Mock() data = { @@ -905,13 +932,24 @@ def test_fit(self, mock_validate_foreign_keys_not_null): } # Run - BaseMultiTableSynthesizer.fit(instance, data) + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + BaseMultiTableSynthesizer.fit(instance, data) # Assert mock_validate_foreign_keys_not_null.assert_called_once_with(instance.metadata, data) instance.preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) instance._check_metadata_updated.assert_called_once() + assert caplog.messages[0] == ( + '\nFit:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 6\n' + ' Total number of columns: 4\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_fit_raises_version_error(self): """Test that fit will raise a ``VersionError`` if the current version is bigger.""" @@ 
-1000,9 +1038,11 @@ def test_sample_validate_input(self): with pytest.raises(SynthesizerInputError, match=msg): instance.sample(scale=scale) - def test_sample(self): + @patch('sdv.multi_table.base.datetime') + def test_sample(self, mock_datetime, caplog): """Test that ``sample`` calls the ``_sample`` with the given arguments.""" # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) data = { @@ -1011,11 +1051,25 @@ def test_sample(self): } instance._sample = Mock(return_value=data) + synth_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + instance._synthesizer_id = synth_id + # Run - instance.sample(scale=1.5) + with catch_sdv_logs(caplog, logging.INFO, logger='MultiTableSynthesizer'): + instance.sample(scale=1.5) # Assert instance._sample.assert_called_once_with(scale=1.5) + assert caplog.messages[0] == ( + '\nSample:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: BaseMultiTableSynthesizer\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 6\n' + ' Total number of columns: 4\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_get_learned_distributions_raises_an_unfitted_error(self): """Test that ``get_learned_distributions`` raises an error when model is not fitted.""" @@ -1404,19 +1458,31 @@ def test_get_info_with_enterprise(self, mock_version): 'fitted_sdv_enterprise_version': '1.1.0' } + @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base.cloudpickle') - def test_save(self, cloudpickle_mock, tmp_path): + def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): """Test that the synthesizer is saved correctly.""" # Setup - synthesizer = Mock() + synthesizer = Mock( + _synthesizer_id='BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' # Run filepath = tmp_path / 'output.pkl' - BaseMultiTableSynthesizer.save(synthesizer, filepath) + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + BaseMultiTableSynthesizer.save(synthesizer, filepath) # Assert cloudpickle_mock.dump.assert_called_once_with(synthesizer, ANY) + assert caplog.messages[0] == ( + '\nSave:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) + @patch('sdv.multi_table.base.datetime') @patch('sdv.multi_table.base.generate_synthesizer_id') @patch('sdv.multi_table.base.check_synthesizer_version') @patch('sdv.multi_table.base.check_sdv_versions_and_warn') @@ -1424,16 +1490,17 @@ def test_save(self, cloudpickle_mock, tmp_path): @patch('builtins.open', new_callable=mock_open) def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_warn, mock_check_synthesizer_version, - mock_generate_synthesizer_id): + mock_generate_synthesizer_id, mock_datetime, caplog): """Test that the ``load`` method loads a stored synthesizer.""" # Setup - synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + synthesizer_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None) cloudpickle_mock.load.return_value = synthesizer_mock # Run - loaded_instance = 
BaseMultiTableSynthesizer.load('synth.pkl') + with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'): + loaded_instance = BaseMultiTableSynthesizer.load('synth.pkl') # Assert mock_file.assert_called_once_with('synth.pkl', 'rb') @@ -1443,3 +1510,9 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_synthesizer_version.assert_called_once_with(synthesizer_mock) assert loaded_instance._synthesizer_id == synthesizer_id mock_generate_synthesizer_id.assert_called_once_with(synthesizer_mock) + assert caplog.messages[0] == ( + '\nLoad:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index b4c106b80..9aae341a9 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1,3 +1,4 @@ +import logging import re from datetime import date, datetime from unittest.mock import ANY, MagicMock, Mock, call, mock_open, patch @@ -17,6 +18,7 @@ from sdv.single_table import ( CopulaGANSynthesizer, CTGANSynthesizer, GaussianCopulaSynthesizer, TVAESynthesizer) from sdv.single_table.base import COND_IDX, BaseSingleTableSynthesizer +from tests.utils import catch_sdv_logs class TestBaseSingleTableSynthesizer: @@ -59,19 +61,22 @@ def test__check_metadata_updated(self): # Assert instance.metadata._updated = False + @patch('sdv.single_table.base.datetime') @patch('sdv.single_table.base.generate_synthesizer_id') @patch('sdv.single_table.base.DataProcessor') @patch('sdv.single_table.base.BaseSingleTableSynthesizer._check_metadata_updated') def test___init__(self, mock_check_metadata_updated, mock_data_processor, - mock_generate_synthesizer_id): + mock_generate_synthesizer_id, mock_datetime, caplog): """Test instantiating with default values.""" # Setup metadata = Mock() synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' # Run - instance = BaseSingleTableSynthesizer(metadata) + with catch_sdv_logs(caplog, logging.INFO, logger='SingleTableSynthesizer'): + instance = BaseSingleTableSynthesizer(metadata) # Assert assert instance.enforce_min_max_values is True @@ -89,6 +94,11 @@ def test___init__(self, mock_check_metadata_updated, mock_data_processor, metadata.validate.assert_called_once_with() mock_check_metadata_updated.assert_called_once() mock_generate_synthesizer_id.assert_called_once_with(instance) + assert caplog.messages[0] == ( + '\nInstance:\n Timestamp: 2024-04-19 16:20:10.037183\n Synthesizer class name: ' + 'BaseSingleTableSynthesizer\n Synthesizer id: ' + 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) @patch('sdv.single_table.base.DataProcessor') def test___init__custom(self, mock_data_processor): @@ -341,20 +351,34 @@ def test__fit(self, mock_data_processor): with pytest.raises(NotImplementedError, match=''): instance._fit(data) - def test_fit_processed_data(self): + @patch('sdv.single_table.base.datetime') + def test_fit_processed_data(self, mock_datetime, caplog): """Test that ``fit_processed_data`` calls the ``_fit``.""" # Setup + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' instance = Mock( _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, + 
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
        )
        processed_data = pd.DataFrame({'column_a': [1, 2, 3]})

        # Run
-        BaseSingleTableSynthesizer.fit_processed_data(instance, processed_data)
+        with catch_sdv_logs(caplog, logging.INFO, 'SingleTableSynthesizer'):
+            BaseSingleTableSynthesizer.fit_processed_data(instance, processed_data)

        # Assert
        instance._fit.assert_called_once_with(processed_data)
+        assert caplog.messages[0] == (
+            '\nFit processed data:\n'
+            ' Timestamp: 2024-04-19 16:20:10.037183\n'
+            ' Synthesizer class name: Mock\n'
+            ' Statistics of the fit processed data:\n'
+            ' Total number of tables: 1\n'
+            ' Total number of rows: 3\n'
+            ' Total number of columns: 1\n'
+            ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
+        )

    def test_fit_processed_data_raises_version_error(self):
        """Test that ``fit`` raises ``VersionError``
@@ -380,23 +404,27 @@ def test_fit_processed_data_raises_version_error(self):
        with pytest.raises(VersionError, match=error_msg):
            BaseSingleTableSynthesizer.fit_processed_data(instance, processed_data)

-    def test_fit(self):
+    @patch('sdv.single_table.base.datetime')
+    def test_fit(self, mock_datetime, caplog):
        """Test that ``fit`` calls ``preprocess`` and the ``fit_processed_data``.

        When fitting, the synthesizer has to ``preprocess`` the data and, with the output of this
        method, call the ``fit_processed_data``.
        """
        # Setup
+        mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
        instance = Mock(
            _fitted_sdv_version=None,
            _fitted_sdv_enterprise_version=None,
+            _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
        )
        data = pd.DataFrame({'column_a': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']})
        instance._random_state_set = True
        instance._fitted = True

        # Run
-        BaseSingleTableSynthesizer.fit(instance, data)
+        with catch_sdv_logs(caplog, logging.INFO, 'SingleTableSynthesizer'):
+            BaseSingleTableSynthesizer.fit(instance, data)

        # Assert
        assert instance._random_state_set is False
@@ -404,6 +432,16 @@ def test_fit(self):
        instance._preprocess.assert_called_once_with(data)
        instance.fit_processed_data.assert_called_once_with(instance._preprocess.return_value)
        instance._check_metadata_updated.assert_called_once()
+        assert caplog.messages[0] == (
+            '\nFit:\n'
+            ' Timestamp: 2024-04-19 16:20:10.037183\n'
+            ' Synthesizer class name: Mock\n'
+            ' Statistics of the fit data:\n'
+            ' Total number of tables: 1\n'
+            ' Total number of rows: 3\n'
+            ' Total number of columns: 2\n'
+            ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
+        )

    def test_fit_raises_version_error(self):
        """Test that ``fit`` raises ``VersionError``
@@ -1361,25 +1399,30 @@ def test__sample_with_progress_bar_removing_temp_file(
        mock_os.remove.assert_called_once_with('.sample.csv.temp')
        mock_os.path.exists.assert_called_once_with('.sample.csv.temp')

-    def test_sample(self):
+    @patch('sdv.single_table.base.datetime')
+    def test_sample(self, mock_datetime, caplog):
        """Test that we use ``_sample_with_progress_bar`` in this method."""
        # Setup
+        mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
        num_rows = 10
        max_tries_per_batch = 50
        batch_size = 5
        output_file_path = 'temp.csv'
-        instance = Mock()
+        instance = Mock(
+            _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
+        )
        instance.get_metadata.return_value._constraints = False
        instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]})

        # Run
-        result =
BaseSingleTableSynthesizer.sample( - instance, - num_rows, - max_tries_per_batch, - batch_size, - output_file_path, - ) + with catch_sdv_logs(caplog, logging.INFO, logger='SingleTableSynthesizer'): + result = BaseSingleTableSynthesizer.sample( + instance, + num_rows, + max_tries_per_batch, + batch_size, + output_file_path, + ) # Assert instance._sample_with_progress_bar.assert_called_once_with( @@ -1390,6 +1433,16 @@ def test_sample(self): show_progress_bar=True ) pd.testing.assert_frame_equal(result, pd.DataFrame({'col': [1, 2, 3]})) + assert caplog.messages[0] == ( + '\nSample:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 1\n' + ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test__validate_conditions_unseen_columns(self): """Test that conditions are within the ``data_processor`` fields.""" @@ -1742,35 +1795,50 @@ def test__validate_known_columns_a_few_nans(self): with pytest.warns(UserWarning, match=warn_msg): synthesizer._validate_known_columns(conditions) + @patch('sdv.single_table.base.datetime') @patch('sdv.single_table.base.cloudpickle') - def test_save(self, cloudpickle_mock, tmp_path): + def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): """Test that the synthesizer is saved correctly.""" # Setup - synthesizer = Mock() + synthesizer = Mock( + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' # Run filepath = tmp_path / 'output.pkl' - BaseSingleTableSynthesizer.save(synthesizer, filepath) + with catch_sdv_logs(caplog, logging.INFO, 'SingleTableSynthesizer'): + BaseSingleTableSynthesizer.save(synthesizer, filepath) # Assert cloudpickle_mock.dump.assert_called_once_with(synthesizer, ANY) + assert caplog.messages[0] == ( + '\nSave:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) + @patch('sdv.single_table.base.datetime') @patch('sdv.single_table.base.generate_synthesizer_id') @patch('sdv.single_table.base.check_synthesizer_version') @patch('sdv.single_table.base.check_sdv_versions_and_warn') @patch('sdv.single_table.base.cloudpickle') @patch('builtins.open', new_callable=mock_open) def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_warn, - mock_check_synthesizer_version, mock_generate_synthesizer_id): + mock_check_synthesizer_version, mock_generate_synthesizer_id, + mock_datetime, caplog): """Test that the ``load`` method loads a stored synthesizer.""" # Setup synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None) + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id cloudpickle_mock.load.return_value = synthesizer_mock # Run - loaded_instance = BaseSingleTableSynthesizer.load('synth.pkl') + with catch_sdv_logs(caplog, logging.INFO, 'SingleTableSynthesizer'): + loaded_instance = BaseSingleTableSynthesizer.load('synth.pkl') # Assert mock_file.assert_called_once_with('synth.pkl', 'rb') @@ -1780,6 +1848,12 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_war assert loaded_instance._synthesizer_id == 
synthesizer_id mock_check_synthesizer_version.assert_called_once_with(synthesizer_mock) mock_generate_synthesizer_id.assert_called_once_with(synthesizer_mock) + assert caplog.messages[0] == ( + '\nLoad:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: Mock\n' + ' Synthesizer id: BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + ) def test_load_custom_constraint_classes(self): """Test that ``load_custom_constraint_classes`` calls the ``DataProcessor``'s method.""" diff --git a/tests/utils.py b/tests/utils.py index a1d819eb7..bf13b9f02 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,9 @@ """Utils for testing.""" +import contextlib + import pandas as pd +from sdv.logging import get_sdv_logger from sdv.metadata.multi_table import MultiTableMetadata @@ -99,3 +102,17 @@ def get_multi_table_data(): } return data + + +@contextlib.contextmanager +def catch_sdv_logs(caplog, level, logger): + """Context manager to capture logs from an SDV logger.""" + logger = get_sdv_logger(logger) + orig_level = logger.level + logger.setLevel(level) + logger.addHandler(caplog.handler) + try: + yield + finally: + logger.setLevel(orig_level) + logger.removeHandler(caplog.handler) From de1c1db2d7d636679a3938831ebca1875f5d5d2f Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Fri, 19 Apr 2024 19:13:36 +0200 Subject: [PATCH 09/19] Bump and fix pytest --- tests/integration/datasets/test_local.py | 2 +- tests/integration/single_table/test_constraints.py | 4 ++-- tests/integration/utils/test_poc.py | 4 ++-- tests/unit/test___init__.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/integration/datasets/test_local.py b/tests/integration/datasets/test_local.py index ee1438ff5..135375372 100644 --- a/tests/integration/datasets/test_local.py +++ b/tests/integration/datasets/test_local.py @@ -4,7 +4,7 @@ from sdv.datasets.local import save_csvs -@pytest.fixture() +@pytest.fixture def data(): parent = pd.DataFrame(data={ 'id': [0, 1, 2, 3, 4], diff --git a/tests/integration/single_table/test_constraints.py b/tests/integration/single_table/test_constraints.py index 9061c67cd..151fb8d83 100644 --- a/tests/integration/single_table/test_constraints.py +++ b/tests/integration/single_table/test_constraints.py @@ -30,12 +30,12 @@ def _isinstance_side_effect(*args, **kwargs): ) -@pytest.fixture() +@pytest.fixture def demo_data(): return DEMO_DATA -@pytest.fixture() +@pytest.fixture def demo_metadata(): return DEMO_METADATA diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index ac4d4e50d..e8189b6ea 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -13,7 +13,7 @@ from sdv.utils.poc import drop_unknown_references, simplify_schema -@pytest.fixture() +@pytest.fixture def metadata(): return MultiTableMetadata.load_from_dict( { @@ -45,7 +45,7 @@ def metadata(): ) -@pytest.fixture() +@pytest.fixture def data(): parent = pd.DataFrame(data={ 'id': [0, 1, 2, 3, 4], diff --git a/tests/unit/test___init__.py b/tests/unit/test___init__.py index b00144c23..e94f3b214 100644 --- a/tests/unit/test___init__.py +++ b/tests/unit/test___init__.py @@ -8,7 +8,7 @@ from sdv import _find_addons -@pytest.fixture() +@pytest.fixture def mock_sdv(): sdv_module = sys.modules['sdv'] sdv_mock = Mock() From 667968a46a79929a493b0487143deb2451361084 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Sun, 21 Apr 2024 19:13:18 +0200 Subject: [PATCH 10/19] Fix unit test --- 
tests/unit/multi_table/test_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 50a5a2441..0cad058b0 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -1494,6 +1494,7 @@ def test_load(self, mock_file, cloudpickle_mock, """Test that the ``load`` method loads a stored synthesizer.""" # Setup synthesizer_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' mock_generate_synthesizer_id.return_value = synthesizer_id synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None) cloudpickle_mock.load.return_value = synthesizer_mock From c8a8a2cb0f3d40a478f593738e71ddc754e9bcb4 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 24 Apr 2024 12:27:31 +0200 Subject: [PATCH 11/19] Add table_name and conditional logging --- sdv/logging/__init__.py | 3 +- sdv/logging/utils.py | 85 +++++------- sdv/multi_table/base.py | 27 ++-- sdv/single_table/base.py | 161 ++++++++++++---------- sdv/single_table/copulagan.py | 4 +- sdv/single_table/copulas.py | 4 +- sdv/single_table/ctgan.py | 9 +- tests/integration/multi_table/test_hma.py | 11 +- tests/unit/logging/test_utils.py | 23 +--- tests/unit/multi_table/test_base.py | 15 +- tests/unit/single_table/test_base.py | 25 ++-- tests/unit/single_table/test_copulagan.py | 1 + tests/unit/single_table/test_copulas.py | 3 +- tests/unit/single_table/test_ctgan.py | 2 + 14 files changed, 194 insertions(+), 179 deletions(-) diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py index 436a1a442..1080bc9c1 100644 --- a/sdv/logging/__init__.py +++ b/sdv/logging/__init__.py @@ -1,9 +1,8 @@ """Module for configuring loggers within the SDV library.""" -from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config +from sdv.logging.utils import get_sdv_logger, get_sdv_logger_config __all__ = ( - 'disable_single_table_logger', 'get_sdv_logger', 'get_sdv_logger_config', ) diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index b9669054d..80abcc2f9 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -1,11 +1,11 @@ """Utilities for configuring logging within the SDV library.""" -import contextlib import logging import logging.config from functools import lru_cache from pathlib import Path +import platformdirs import yaml @@ -16,34 +16,16 @@ def get_sdv_logger_config(): logger_conf = yaml.safe_load(f) # Logfile to be in this same directory + store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) + store_path.mkdir(parents=True, exist_ok=True) for logger in logger_conf.get('loggers', {}).values(): handler = logger.get('handlers', {}) if handler.get('filename') == 'sdv_logs.log': - handler['filename'] = logging_path / handler['filename'] + handler['filename'] = store_path / handler['filename'] return logger_conf -@contextlib.contextmanager -def disable_single_table_logger(): - """Temporarily disables logging for the single table synthesizers. - - This context manager temporarily removes all handlers associated with - the ``SingleTableSynthesizer`` logger, disabling logging for that module - within the current context. After the context exits, the - removed handlers are restored to the logger. 
- """ - # Logging without ``SingleTableSynthesizer`` - single_table_logger = logging.getLogger('SingleTableSynthesizer') - handlers = single_table_logger.handlers - single_table_logger.handlers = [] - try: - yield - finally: - for handler in handlers: - single_table_logger.addHandler(handler) - - @lru_cache() def get_sdv_logger(logger_name): """Get a logger instance with the specified name and configuration. @@ -62,30 +44,35 @@ def get_sdv_logger(logger_name): and the specific settings for the given logger name. """ logger_conf = get_sdv_logger_config() - logger = logging.getLogger(logger_name) - if logger_name in logger_conf.get('loggers'): - formatter = None - config = logger_conf.get('loggers').get(logger_name) - log_level = getattr(logging, config.get('level', 'INFO')) - if config.get('format'): - formatter = logging.Formatter(config.get('format')) - - logger.setLevel(log_level) - logger.propagate = config.get('propagate', False) - handler = config.get('handlers') - handlers = handler.get('class') - handlers = [handlers] if isinstance(handlers, str) else handlers - for handler_class in handlers: - if handler_class == 'logging.FileHandler': - logfile = handler.get('filename') - file_handler = logging.FileHandler(logfile) - file_handler.setLevel(log_level) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'): - ch = logging.StreamHandler() - ch.setLevel(log_level) - ch.setFormatter(formatter) - logger.addHandler(ch) - - return logger + if logger_conf.get('log_registry') is None: + # Return a logger without any extra settings and avoid writing into files or other streams + return logging.getLogger(logger_name) + + if logger_conf.get('log_registry') == 'local': + logger = logging.getLogger(logger_name) + if logger_name in logger_conf.get('loggers'): + formatter = None + config = logger_conf.get('loggers').get(logger_name) + log_level = getattr(logging, config.get('level', 'INFO')) + if config.get('format'): + formatter = logging.Formatter(config.get('format')) + + logger.setLevel(log_level) + logger.propagate = config.get('propagate', False) + handler = config.get('handlers') + handlers = handler.get('class') + handlers = [handlers] if isinstance(handlers, str) else handlers + for handler_class in handlers: + if handler_class == 'logging.FileHandler': + logfile = handler.get('filename') + file_handler = logging.FileHandler(logfile) + file_handler.setLevel(log_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'): + ch = logging.StreamHandler() + ch.setLevel(log_level) + ch.setFormatter(formatter) + logger.addHandler(ch) + + return logger diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 00efe700e..8095cfad9 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -16,7 +16,7 @@ _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError -from sdv.logging import disable_single_table_logger, get_sdv_logger +from sdv.logging import get_sdv_logger from sdv.single_table.copulas import GaussianCopulaSynthesizer SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer') @@ -59,14 +59,14 @@ def _set_temp_numpy_seed(self): np.random.set_state(initial_state) def _initialize_models(self): - with disable_single_table_logger(): - for 
table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = self._table_parameters.get(table_name, {}) - self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, - locales=self.locales, - **synthesizer_parameters - ) + for table_name, table_metadata in self.metadata.tables.items(): + synthesizer_parameters = self._table_parameters.get(table_name, {}) + self._table_synthesizers[table_name] = self._synthesizer( + metadata=table_metadata, + locales=self.locales, + table_name=table_name, + **synthesizer_parameters + ) def _get_pbar_args(self, **kwargs): """Return a dictionary with the updated keyword args for a progress bar.""" @@ -199,6 +199,8 @@ def set_table_parameters(self, table_name, table_parameters): A dictionary with the parameters as keys and the values to be used to instantiate the table's synthesizer. """ + # Ensure that we set the name of the table no matter what + table_parameters.update({'table_name': table_name}) self._table_synthesizers[table_name] = self._synthesizer( metadata=self.metadata.tables[table_name], **table_parameters @@ -407,9 +409,8 @@ def fit_processed_data(self, processed_data): self._synthesizer_id, ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - with disable_single_table_logger(): - augmented_data = self._augment_tables(processed_data) - self._model_tables(augmented_data) + augmented_data = self._augment_tables(processed_data) + self._model_tables(augmented_data) self._fitted = True self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d') @@ -478,7 +479,7 @@ def sample(self, scale=1.0): raise SynthesizerInputError( f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.") - with self._set_temp_numpy_seed(), disable_single_table_logger(): + with self._set_temp_numpy_seed(): sampled_data = self._sample(scale=scale) total_rows = 0 diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 5b808eb8c..393c010e0 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -88,7 +88,7 @@ def _check_metadata_updated(self): ) def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US']): + locales=['en_US'], table_name=None): self._validate_inputs(enforce_min_max_values, enforce_rounding) self.metadata = metadata self.metadata.validate() @@ -96,11 +96,13 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self.enforce_min_max_values = enforce_min_max_values self.enforce_rounding = enforce_rounding self.locales = locales + self.table_name = table_name self._data_processor = DataProcessor( metadata=self.metadata, enforce_rounding=self.enforce_rounding, enforce_min_max_values=self.enforce_min_max_values, - locales=self.locales + locales=self.locales, + table_name=self.table_name ) self._fitted = False self._random_state_set = False @@ -110,15 +112,16 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) - SYNTHESIZER_LOGGER.info( - '\nInstance:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - self._synthesizer_id - ) + if self.table_name is None: + SYNTHESIZER_LOGGER.info( + '\nInstance:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + 
self.__class__.__name__, + self._synthesizer_id + ) def set_address_columns(self, column_names, anonymization_level='full'): """Set the address multi-column transformer.""" @@ -401,21 +404,23 @@ def fit_processed_data(self, processed_data): processed_data (pandas.DataFrame): The transformed data used to fit the model to. """ - SYNTHESIZER_LOGGER.info( - '\nFit processed data:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the fit processed data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - len(processed_data), - len(processed_data.columns), - self._synthesizer_id, - ) + if self.table_name is None: + SYNTHESIZER_LOGGER.info( + '\nFit processed data:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(processed_data), + len(processed_data.columns), + self._synthesizer_id, + ) + check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) if not processed_data.empty: self._fit(processed_data) @@ -432,21 +437,23 @@ def fit(self, data): data (pandas.DataFrame): The raw data (before any transformations) to fit the model to. """ - SYNTHESIZER_LOGGER.info( - '\nFit:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the fit data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - len(data), - len(data.columns), - self._synthesizer_id, - ) + if self.table_name is None: + SYNTHESIZER_LOGGER.info( + '\nFit:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(data), + len(data.columns), + self._synthesizer_id, + ) + check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) self._check_metadata_updated() self._fitted = False @@ -463,15 +470,17 @@ def save(self, filepath): Path where the synthesizer instance will be serialized. 
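
        Example (an illustrative sketch; the filepath is arbitrary):

            synthesizer.save('my_synthesizer.pkl')
            loaded = BaseSingleTableSynthesizer.load('my_synthesizer.pkl')

        Note that, per the guard below, only top-level synthesizers (those
        with ``table_name=None``) emit the Save log entry; table-level
        synthesizers created inside a multi-table model stay silent.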
""" synthesizer_id = getattr(self, '_synthesizer_id', None) - SYNTHESIZER_LOGGER.info( - '\nSave:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - synthesizer_id - ) + if self.table_name is None: + SYNTHESIZER_LOGGER.info( + '\nSave:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + synthesizer_id + ) + with open(filepath, 'wb') as output: cloudpickle.dump(self, output) @@ -495,15 +504,17 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) - SYNTHESIZER_LOGGER.info( - '\nLoad:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - synthesizer.__class__.__name__, - synthesizer._synthesizer_id, - ) + if synthesizer.table_name is None: + SYNTHESIZER_LOGGER.info( + '\nLoad:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + synthesizer.__class__.__name__, + synthesizer._synthesizer_id, + ) + return synthesizer @@ -879,21 +890,23 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file output_file_path, show_progress_bar=show_progress_bar ) - SYNTHESIZER_LOGGER.info( - '\nSample:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the sample size:\n' - ' Total number of tables: 1\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - sample_timestamp, - self.__class__.__name__, - len(sampled_data), - len(sampled_data.columns), - self._synthesizer_id, - ) + + if self.table_name is None: + SYNTHESIZER_LOGGER.info( + '\nSample:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + sample_timestamp, + self.__class__.__name__, + len(sampled_data), + len(sampled_data.columns), + self._synthesizer_id, + ) return sampled_data diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index c9309b45c..63b22d22b 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -121,7 +121,8 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True, numerical_distributions=None, default_distribution=None): + pac=10, cuda=True, numerical_distributions=None, default_distribution=None, + table_name=None): super().__init__( metadata, @@ -142,6 +143,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, epochs=epochs, pac=pac, cuda=cuda, + table_name=table_name ) validate_numerical_distributions(numerical_distributions, self.metadata.columns) diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index 4fc213949..c19b7536d 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -91,12 +91,14 @@ def get_distribution_class(cls, distribution): return cls._DISTRIBUTIONS[distribution] def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US'], numerical_distributions=None, 
default_distribution=None): + locales=['en_US'], numerical_distributions=None, default_distribution=None, + table_name=None): super().__init__( metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, locales=locales, + table_name=table_name ) validate_numerical_distributions(numerical_distributions, self.metadata.columns) self.numerical_distributions = numerical_distributions or {} diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py index c6c5d3d0c..d59c3fca0 100644 --- a/sdv/single_table/ctgan.py +++ b/sdv/single_table/ctgan.py @@ -155,13 +155,14 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True): + pac=10, cuda=True, table_name=None): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, - locales=locales + locales=locales, + table_name=table_name ) self.embedding_dim = embedding_dim @@ -338,12 +339,14 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), - l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True): + l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True, + table_name=None): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, + table_name=table_name ) self.embedding_dim = embedding_dim self.compress_dims = compress_dims diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 4cc189975..bf5405d04 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -148,7 +148,8 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, + 'table_name': 'characters' } families_params = hmasynthesizer.get_table_parameters('families') assert families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -157,7 +158,8 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, + 'table_name': 'families' } char_families_params = hmasynthesizer.get_table_parameters('character_families') assert char_families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -166,7 +168,8 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {} + 'numerical_distributions': {}, + 'table_name': 'character_families' } assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma' @@ -551,7 +554,7 @@ def test_synthesize_multiple_tables_using_hma(self, tmp_path): custom_synthesizer.set_table_parameters( table_name='hotels', table_parameters={ - 'default_distribution': 'truncnorm' + 'default_distribution': 'truncnorm', } ) diff --git a/tests/unit/logging/test_utils.py b/tests/unit/logging/test_utils.py index 1a7539cdb..2235779f6 100644 --- a/tests/unit/logging/test_utils.py +++ 
b/tests/unit/logging/test_utils.py @@ -2,7 +2,7 @@ import logging from unittest.mock import Mock, mock_open, patch -from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config +from sdv.logging.utils import get_sdv_logger, get_sdv_logger_config def test_get_sdv_logger_config(): @@ -12,6 +12,7 @@ def test_get_sdv_logger_config(): by the ``get_sdv_logger_config``. """ yaml_content = """ + log_registry: 'local' loggers: test_logger: level: DEBUG @@ -26,6 +27,7 @@ def test_get_sdv_logger_config(): # Assert assert isinstance(logger_conf, dict) assert logger_conf == { + 'log_registry': 'local', 'loggers': { 'test_logger': { 'level': 'DEBUG', @@ -37,30 +39,13 @@ def test_get_sdv_logger_config(): } -@patch('sdv.logging.utils.logging.getLogger') -def test_disable_single_table_logger(mock_getlogger): - # Setup - mock_logger = Mock() - handler = Mock() - mock_logger.handlers = [handler] - mock_logger.removeHandler.side_effect = lambda x: mock_logger.handlers.pop(0) - mock_logger.addHandler.side_effect = lambda x: mock_logger.handlers.append(x) - mock_getlogger.return_value = mock_logger - - # Run - with disable_single_table_logger(): - assert len(mock_logger.handlers) == 0 - - # Assert - assert len(mock_logger.handlers) == 1 - - @patch('sdv.logging.utils.logging.StreamHandler') @patch('sdv.logging.utils.logging.getLogger') @patch('sdv.logging.utils.get_sdv_logger_config') def test_get_sdv_logger(mock_get_sdv_logger_config, mock_getlogger, mock_streamhandler): # Setup mock_logger_conf = { + 'log_registry': 'local', 'loggers': { 'test_logger': { 'level': 'DEBUG', diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 0cad058b0..b4de98e6a 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -52,9 +52,10 @@ def test__initialize_models(self): } instance._synthesizer.assert_has_calls([ call(metadata=instance.metadata.tables['nesreca'], default_distribution='gamma', - locales=locales), - call(metadata=instance.metadata.tables['oseba'], locales=locales), - call(metadata=instance.metadata.tables['upravna_enota'], locales=locales) + locales=locales, table_name='nesreca'), + call(metadata=instance.metadata.tables['oseba'], locales=locales, table_name='oseba'), + call(metadata=instance.metadata.tables['upravna_enota'], locales=locales, + table_name='upravna_enota') ]) def test__get_pbar_args(self): @@ -279,6 +280,7 @@ def test_get_table_parameters_empty(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], + 'table_name': 'oseba', 'numerical_distributions': {} } } @@ -299,6 +301,7 @@ def test_get_table_parameters_has_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], + 'table_name': 'oseba', 'numerical_distributions': {} } @@ -330,13 +333,17 @@ def test_set_table_parameters(self): # Assert table_parameters = instance.get_table_parameters('oseba') - assert instance._table_parameters['oseba'] == {'default_distribution': 'gamma'} + assert instance._table_parameters['oseba'] == { + 'default_distribution': 'gamma', + 'table_name': 'oseba' + } assert table_parameters['synthesizer_name'] == 'GaussianCopulaSynthesizer' assert table_parameters['synthesizer_parameters'] == { 'default_distribution': 'gamma', 'enforce_min_max_values': True, 'locales': ['en_US'], 'enforce_rounding': True, + 'table_name': 'oseba', 'numerical_distributions': {} } diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py 
index 9aae341a9..289cf7519 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -89,7 +89,8 @@ def test___init__(self, mock_check_metadata_updated, mock_data_processor, metadata=metadata, enforce_rounding=instance.enforce_rounding, enforce_min_max_values=instance.enforce_min_max_values, - locales=instance.locales + locales=instance.locales, + table_name=None ) metadata.validate.assert_called_once_with() mock_check_metadata_updated.assert_called_once() @@ -123,7 +124,8 @@ def test___init__custom(self, mock_data_processor): metadata=metadata, enforce_rounding=instance.enforce_rounding, enforce_min_max_values=instance.enforce_min_max_values, - locales=instance.locales + locales=instance.locales, + table_name=None ) metadata.validate.assert_called_once_with() @@ -182,7 +184,8 @@ def test_get_parameters(self, mock_data_processor): assert parameters == { 'enforce_min_max_values': False, 'enforce_rounding': False, - 'locales': 'en_CA' + 'locales': 'en_CA', + 'table_name': None } @patch('sdv.single_table.base.DataProcessor') @@ -359,7 +362,8 @@ def test_fit_processed_data(self, mock_datetime, caplog): instance = Mock( _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, - _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + table_name=None ) processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) @@ -390,6 +394,7 @@ def test_fit_processed_data_raises_version_error(self): instance = Mock( _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, + table_name=None ) processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True @@ -416,7 +421,8 @@ def test_fit(self, mock_datetime, caplog): instance = Mock( _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, - _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + table_name=None ) data = pd.DataFrame({'column_a': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']}) instance._random_state_set = True @@ -453,6 +459,7 @@ def test_fit_raises_version_error(self): instance = Mock( _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, + table_name=None ) data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True @@ -1409,7 +1416,8 @@ def test_sample(self, mock_datetime, caplog): batch_size = 5 output_file_path = 'temp.csv' instance = Mock( - _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + table_name=None ) instance.get_metadata.return_value._constraints = False instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]}) @@ -1801,7 +1809,8 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): """Test that the synthesizer is saved correctly.""" # Setup synthesizer = Mock( - _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + table_name=None ) mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' @@ -1830,7 +1839,7 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_war mock_datetime, caplog): """Test that the ``load`` method loads a stored synthesizer.""" # Setup - 
synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None) + synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None, table_name=None) mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index a818dc85a..762e28f58 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -175,6 +175,7 @@ def test_get_params(self): 'cuda': True, 'numerical_distributions': {}, 'default_distribution': 'beta', + 'table_name': None } @patch('sdv.single_table.copulagan.rdt') diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index 3c96028c3..02ec24b14 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -130,7 +130,8 @@ def test_get_parameters(self): 'enforce_rounding': True, 'locales': ['en_US'], 'numerical_distributions': {}, - 'default_distribution': 'beta' + 'default_distribution': 'beta', + 'table_name': None } @patch('sdv.single_table.copulas.LOGGER') diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py index 967823094..e18e27552 100644 --- a/tests/unit/single_table/test_ctgan.py +++ b/tests/unit/single_table/test_ctgan.py @@ -151,6 +151,7 @@ def test_get_parameters(self): 'epochs': 300, 'pac': 10, 'cuda': True, + 'table_name': None } def test__estimate_num_columns(self): @@ -426,6 +427,7 @@ def test_get_parameters(self): 'epochs': 300, 'loss_factor': 2, 'cuda': True, + 'table_name': None } @patch('sdv.single_table.ctgan.TVAE') From 33f0f6ca01eac83d4bfb62d7fd525ba1bf7b01a4 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 24 Apr 2024 12:52:43 +0200 Subject: [PATCH 12/19] Fix table name propagation --- sdv/multi_table/hma.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 6ce1d37ab..8c91a0a63 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -314,6 +314,7 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc else: synthesizer = self._synthesizer( table_meta, + table_name=child_name, **self._table_parameters[child_name] ) synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) @@ -521,7 +522,11 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) table_meta = self.metadata.tables[child_name] - synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name]) + synthesizer = self._synthesizer( + table_meta, + table_name=child_name, + **self._table_parameters[child_name] + ) synthesizer._set_parameters(parameters, default_parameters) synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor @@ -615,7 +620,11 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): for parent_id, row in parent_rows.iterrows(): parameters = self._extract_parameters(row, table_name, foreign_key) table_meta = self._table_synthesizers[table_name].get_metadata() - synthesizer = self._synthesizer(table_meta, **self._table_parameters[table_name]) + synthesizer = self._synthesizer( + table_meta, + table_name=table_name, + **self._table_parameters[table_name] + ) 
synthesizer._set_parameters(parameters) try: likelihoods[parent_id] = synthesizer._get_likelihood(table_rows) From b3a28c0311b27030618acd73b0a12c3b786df401 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 24 Apr 2024 15:15:40 +0200 Subject: [PATCH 13/19] Fix tests --- pyproject.toml | 1 + sdv/multi_table/hma.py | 15 +++++++++------ tests/unit/multi_table/test_hma.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 901c68965..8e75e0b4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ 'deepecho>=0.6.0', 'rdt>=1.12.0', 'sdmetrics>=0.14.0', + 'platformdirs>=4.0' ] [project.urls] diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 8c91a0a63..1e9ce6c2a 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -312,10 +312,11 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc row = pd.Series({'num_rows': len(child_rows)}) row.index = f'__{child_name}__{foreign_key}__' + row.index else: + synthesizer_parameters = self._table_parameters[child_name] + synthesizer_parameters.update({'table_name': child_name}) synthesizer = self._synthesizer( table_meta, - table_name=child_name, - **self._table_parameters[child_name] + **synthesizer_parameters ) synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) row = synthesizer._get_parameters() @@ -522,10 +523,11 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) table_meta = self.metadata.tables[child_name] + synthesizer_parameters = self._table_parameters[child_name] + synthesizer_parameters.update({'table_name': child_name}) synthesizer = self._synthesizer( table_meta, - table_name=child_name, - **self._table_parameters[child_name] + **synthesizer_parameters ) synthesizer._set_parameters(parameters, default_parameters) synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor @@ -620,10 +622,11 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): for parent_id, row in parent_rows.iterrows(): parameters = self._extract_parameters(row, table_name, foreign_key) table_meta = self._table_synthesizers[table_name].get_metadata() + synthesizer_parameters = self._table_parameters[table_name] + synthesizer_parameters.update({'table_name': table_name}) synthesizer = self._synthesizer( table_meta, - table_name=table_name, - **self._table_parameters[table_name] + **synthesizer_parameters ) synthesizer._set_parameters(parameters) try: diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index c40e7b080..6bbf7ef6f 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -502,7 +502,7 @@ def test__recreate_child_synthesizer(self): # Assert assert synthesizer == instance._synthesizer.return_value assert synthesizer._data_processor == table_synthesizer._data_processor - instance._synthesizer.assert_called_once_with(table_meta, a=1) + instance._synthesizer.assert_called_once_with(table_meta, table_name='users', a=1) synthesizer._set_parameters.assert_called_once_with( instance._extract_parameters.return_value, {'colA': 'default_param', 'colB': 'default_param'} From d48aad6254ae4b377c932b5fd0ea714485fcef32 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 24 Apr 2024 18:50:53 +0200 Subject: [PATCH 14/19] Add integration tests --- sdv/logging/utils.py | 3 - 
tests/integration/multi_table/test_hma.py | 74 +++++++++++++++++++ tests/integration/single_table/test_base.py | 78 +++++++++++++++++++++ 3 files changed, 152 insertions(+), 3 deletions(-) diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index 80abcc2f9..a7aae8e04 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -1,8 +1,6 @@ """Utilities for configuring logging within the SDV library.""" import logging -import logging.config -from functools import lru_cache from pathlib import Path import platformdirs @@ -26,7 +24,6 @@ def get_sdv_logger_config(): return logger_conf -@lru_cache() def get_sdv_logger(logger_name): """Get a logger instance with the specified name and configuration. diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index bf5405d04..8c2a97ac6 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -2,9 +2,12 @@ import importlib.metadata import re import warnings +from pathlib import Path +from unittest.mock import patch import numpy as np import pandas as pd +import platformdirs import pytest from faker import Faker from rdt.transformers import FloatFormatter @@ -1667,3 +1670,74 @@ def test_hma_relationship_validity(): # Assert assert report.get_details('Relationship Validity')['Score'].mean() == 1.0 + + +@patch('sdv.multi_table.base.generate_synthesizer_id') +@patch('sdv.multi_table.base.datetime') +def test_synthesizer_logger(mock_datetime, mock_generate_id): + """Test that the synthesizer logger logs the expected messages.""" + # Setup + store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) + file_name = 'sdv_logs.log' + + synth_id = 'HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + mock_generate_id.return_value = synth_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' + data, metadata = download_demo('multi_table', 'fake_hotels') + + # Run + instance = HMASynthesizer(metadata) + + # Assert + with open(store_path / file_name) as f: + instance_lines = f.readlines()[-4:] + + assert ''.join(instance_lines) == ( + 'Instance:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.fit(data) + + # Assert + with open(store_path / file_name) as f: + fit_lines = f.readlines()[-17:] + + assert ''.join(fit_lines) == ( + 'Fit:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 668\n' + ' Total number of columns: 15\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + '\nFit processed data:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 668\n' + ' Total number of columns: 11\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.sample(1) + with open(store_path / file_name) as f: + sample_lines = f.readlines()[-8:] + + # Assert + assert ''.join(sample_lines) == ( + 'Sample:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 668\n' + ' Total number of columns: 15\n' + ' Synthesizer id: 
HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 73b25c97c..8c7ea2601 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -2,10 +2,12 @@ import importlib.metadata import re import warnings +from pathlib import Path from unittest.mock import patch import numpy as np import pandas as pd +import platformdirs import pytest from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder @@ -777,3 +779,79 @@ def test_fit_raises_version_error(): ) with pytest.raises(VersionError, match=expected_message): instance.fit(data) + + +@patch('sdv.single_table.base.generate_synthesizer_id') +@patch('sdv.single_table.base.datetime') +def test_synthesizer_logger(mock_datetime, mock_generate_id): + """Test that the synthesizer logger logs the expected messages.""" + # Setup + store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) + file_name = 'sdv_logs.log' + + synth_id = 'GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + mock_generate_id.return_value = synth_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' + data = pd.DataFrame({ + 'col 1': [1, 2, 3], + 'col 2': [4, 5, 6], + 'col 3': ['a', 'b', 'c'], + }) + metadata = SingleTableMetadata() + metadata.detect_from_dataframe(data) + + # Run + instance = GaussianCopulaSynthesizer(metadata) + + # Assert + with open(store_path / file_name) as f: + instance_lines = f.readlines()[-4:] + + assert ''.join(instance_lines) == ( + 'Instance:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.fit(data) + + # Assert + with open(store_path / file_name) as f: + fit_lines = f.readlines()[-17:] + + assert ''.join(fit_lines) == ( + 'Fit:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 3\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + '\nFit processed data:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 3\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.sample(100) + with open(store_path / file_name) as f: + sample_lines = f.readlines()[-8:] + + assert ''.join(sample_lines) == ( + 'Sample:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 100\n' + ' Total number of columns: 3\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) From 04431cdf986d27e5a72f16cd95c174843741cae8 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 24 Apr 2024 18:55:36 +0200 Subject: [PATCH 15/19] Add back lru cache --- sdv/logging/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index a7aae8e04..89469c21a 100644 --- 
a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -1,6 +1,7 @@ """Utilities for configuring logging within the SDV library.""" import logging +from functools import lru_cache from pathlib import Path import platformdirs @@ -24,6 +25,7 @@ def get_sdv_logger_config(): return logger_conf +@lru_cache() def get_sdv_logger(logger_name): """Get a logger instance with the specified name and configuration.
From 09f54c65b66cec1b26b7580214a50578821f67b3 Mon Sep 17 00:00:00 2001
From: Plamen Valentinov Kolev
Date: Wed, 24 Apr 2024 23:05:12 +0200
Subject: [PATCH 16/19] Add platformdirs to dependency checker
--- Makefile | 2 +- latest_requirements.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/Makefile b/Makefile index f31c000d1..e628be5d9 100644 --- a/Makefile +++ b/Makefile @@ -265,5 +265,5 @@ release-major: check-release bumpversion-major release .PHONY: check-deps check-deps: - $(eval allow_list='cloudpickle=|graphviz=|numpy=|pandas=|tqdm=|copulas=|ctgan=|deepecho=|rdt=|sdmetrics=') + $(eval allow_list='cloudpickle=|graphviz=|numpy=|pandas=|tqdm=|copulas=|ctgan=|deepecho=|rdt=|sdmetrics=|platformdirs=') pip freeze | grep -v "SDV.git" | grep -E $(allow_list) | sort > $(OUTPUT_FILEPATH)
diff --git a/latest_requirements.txt b/latest_requirements.txt index 725809553..ea41aa908 100644 --- a/latest_requirements.txt +++ b/latest_requirements.txt @@ -8,3 +8,4 @@ pandas==2.2.2 rdt==1.11.1 sdmetrics==0.14.0 tqdm==4.66.2 +platformdirs==4.2.0
From c5bb7bf3a29014b56ff3267489e60eb313047567 Mon Sep 17 00:00:00 2001
From: Plamen Valentinov Kolev
Date: Wed, 24 Apr 2024 23:38:40 +0200
Subject: [PATCH 17/19] Remove conditional logging and reinstate disable_single_table_logger
--- sdv/logging/__init__.py | 3 +- sdv/logging/utils.py | 21 +++++ sdv/multi_table/base.py | 26 +++--- sdv/single_table/base.py | 131 +++++++++++++++---------------- tests/unit/logging/test_utils.py | 20 ++++- 5 files changed, 119 insertions(+), 82 deletions(-)
diff --git a/sdv/logging/__init__.py b/sdv/logging/__init__.py index 1080bc9c1..436a1a442 100644 --- a/sdv/logging/__init__.py +++ b/sdv/logging/__init__.py @@ -1,8 +1,9 @@ """Module for configuring loggers within the SDV library.""" -from sdv.logging.utils import get_sdv_logger, get_sdv_logger_config +from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config __all__ = ( + 'disable_single_table_logger', 'get_sdv_logger', 'get_sdv_logger_config', )
diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index 89469c21a..2f6a13be4 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -1,5 +1,6 @@ """Utilities for configuring logging within the SDV library.""" +import contextlib import logging from functools import lru_cache from pathlib import Path import platformdirs @@ -25,6 +26,26 @@ def get_sdv_logger_config(): return logger_conf +@contextlib.contextmanager +def disable_single_table_logger(): + """Temporarily disable logging for the single table synthesizers. + + This context manager removes all handlers associated with + the ``SingleTableSynthesizer`` logger, disabling logging for that module + for the duration of the context. After the context exits, the + removed handlers are restored to the logger.
+ """ + # Logging without ``SingleTableSynthesizer`` + single_table_logger = logging.getLogger('SingleTableSynthesizer') + handlers = single_table_logger.handlers + single_table_logger.handlers = [] + try: + yield + finally: + for handler in handlers: + single_table_logger.addHandler(handler) + + @lru_cache() def get_sdv_logger(logger_name): """Get a logger instance with the specified name and configuration. diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 8095cfad9..4afbee0e9 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -16,7 +16,7 @@ _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id) from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError -from sdv.logging import get_sdv_logger +from sdv.logging import disable_single_table_logger, get_sdv_logger from sdv.single_table.copulas import GaussianCopulaSynthesizer SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer') @@ -59,14 +59,15 @@ def _set_temp_numpy_seed(self): np.random.set_state(initial_state) def _initialize_models(self): - for table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = self._table_parameters.get(table_name, {}) - self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, - locales=self.locales, - table_name=table_name, - **synthesizer_parameters - ) + with disable_single_table_logger(): + for table_name, table_metadata in self.metadata.tables.items(): + synthesizer_parameters = self._table_parameters.get(table_name, {}) + self._table_synthesizers[table_name] = self._synthesizer( + metadata=table_metadata, + locales=self.locales, + table_name=table_name, + **synthesizer_parameters + ) def _get_pbar_args(self, **kwargs): """Return a dictionary with the updated keyword args for a progress bar.""" @@ -409,8 +410,9 @@ def fit_processed_data(self, processed_data): self._synthesizer_id, ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - augmented_data = self._augment_tables(processed_data) - self._model_tables(augmented_data) + with disable_single_table_logger(): + augmented_data = self._augment_tables(processed_data) + self._model_tables(augmented_data) self._fitted = True self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d') @@ -479,7 +481,7 @@ def sample(self, scale=1.0): raise SynthesizerInputError( f"Invalid parameter for 'scale' ({scale}). 
Please provide a number that is >0.0.") - with self._set_temp_numpy_seed(): + with self._set_temp_numpy_seed(), disable_single_table_logger(): sampled_data = self._sample(scale=scale) total_rows = 0 diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 393c010e0..41d271fd1 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -112,16 +112,15 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None self._synthesizer_id = generate_synthesizer_id(self) - if self.table_name is None: - SYNTHESIZER_LOGGER.info( - '\nInstance:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - self._synthesizer_id - ) + SYNTHESIZER_LOGGER.info( + '\nInstance:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + self._synthesizer_id + ) def set_address_columns(self, column_names, anonymization_level='full'): """Set the address multi-column transformer.""" @@ -404,22 +403,21 @@ def fit_processed_data(self, processed_data): processed_data (pandas.DataFrame): The transformed data used to fit the model to. """ - if self.table_name is None: - SYNTHESIZER_LOGGER.info( - '\nFit processed data:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the fit processed data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - len(processed_data), - len(processed_data.columns), - self._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info( + '\nFit processed data:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(processed_data), + len(processed_data.columns), + self._synthesizer_id, + ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) if not processed_data.empty: @@ -437,22 +435,21 @@ def fit(self, data): data (pandas.DataFrame): The raw data (before any transformations) to fit the model to. """ - if self.table_name is None: - SYNTHESIZER_LOGGER.info( - '\nFit:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the fit data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - len(data), - len(data.columns), - self._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info( + '\nFit:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + len(data), + len(data.columns), + self._synthesizer_id, + ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) self._check_metadata_updated() @@ -470,16 +467,15 @@ def save(self, filepath): Path where the synthesizer instance will be serialized. 
""" synthesizer_id = getattr(self, '_synthesizer_id', None) - if self.table_name is None: - SYNTHESIZER_LOGGER.info( - '\nSave:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - self.__class__.__name__, - synthesizer_id - ) + SYNTHESIZER_LOGGER.info( + '\nSave:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + self.__class__.__name__, + synthesizer_id + ) with open(filepath, 'wb') as output: cloudpickle.dump(self, output) @@ -891,22 +887,21 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file show_progress_bar=show_progress_bar ) - if self.table_name is None: - SYNTHESIZER_LOGGER.info( - '\nSample:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Statistics of the sample size:\n' - ' Total number of tables: 1\n' - ' Total number of rows: %s\n' - ' Total number of columns: %s\n' - ' Synthesizer id: %s', - sample_timestamp, - self.__class__.__name__, - len(sampled_data), - len(sampled_data.columns), - self._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info( + '\nSample:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 1\n' + ' Total number of rows: %s\n' + ' Total number of columns: %s\n' + ' Synthesizer id: %s', + sample_timestamp, + self.__class__.__name__, + len(sampled_data), + len(sampled_data.columns), + self._synthesizer_id, + ) return sampled_data diff --git a/tests/unit/logging/test_utils.py b/tests/unit/logging/test_utils.py index 2235779f6..316ae9083 100644 --- a/tests/unit/logging/test_utils.py +++ b/tests/unit/logging/test_utils.py @@ -2,7 +2,7 @@ import logging from unittest.mock import Mock, mock_open, patch -from sdv.logging.utils import get_sdv_logger, get_sdv_logger_config +from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config def test_get_sdv_logger_config(): @@ -39,6 +39,24 @@ def test_get_sdv_logger_config(): } +@patch('sdv.logging.utils.logging.getLogger') +def test_disable_single_table_logger(mock_getlogger): + # Setup + mock_logger = Mock() + handler = Mock() + mock_logger.handlers = [handler] + mock_logger.removeHandler.side_effect = lambda x: mock_logger.handlers.pop(0) + mock_logger.addHandler.side_effect = lambda x: mock_logger.handlers.append(x) + mock_getlogger.return_value = mock_logger + + # Run + with disable_single_table_logger(): + assert len(mock_logger.handlers) == 0 + + # Assert + assert len(mock_logger.handlers) == 1 + + @patch('sdv.logging.utils.logging.StreamHandler') @patch('sdv.logging.utils.logging.getLogger') @patch('sdv.logging.utils.get_sdv_logger_config') From 4afdfa428fc0a9694b8258526be0896208efa589 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 24 Apr 2024 23:44:21 +0200 Subject: [PATCH 18/19] Remove tutorials tests that we no longer have --- Makefile | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/Makefile b/Makefile index e628be5d9..c953b15b6 100644 --- a/Makefile +++ b/Makefile @@ -123,12 +123,8 @@ test-integration: ## run tests quickly with the default Python test-readme: ## run the readme snippets invoke readme -.PHONY: test-tutorials -test-tutorials: ## run the tutorial notebooks - invoke tutorials - .PHONY: test -test: test-unit test-integration test-readme test-tutorials ## test everything that needs test dependencies +test: test-unit test-integration test-readme ## test everything that needs test dependencies 
.PHONY: test-all test-all: ## run tests on every Python version with tox From 97cf306f08ef1aacf209d4c97db3f12d778583ff Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 24 Apr 2024 23:45:38 +0200 Subject: [PATCH 19/19] Add logging to the __init__ of sdv --- sdv/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdv/__init__.py b/sdv/__init__.py index ef209a09a..4636bcf7d 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -16,8 +16,8 @@ from types import ModuleType from sdv import ( - constraints, data_processing, datasets, evaluation, io, lite, metadata, metrics, multi_table, - sampling, sequential, single_table, version) + constraints, data_processing, datasets, evaluation, io, lite, logging, metadata, metrics, + multi_table, sampling, sequential, single_table, version) __all__ = [ 'constraints', @@ -26,6 +26,7 @@ 'evaluation', 'io', 'lite', + 'logging', 'metadata', 'metrics', 'multi_table',