Commit d27e3f0

Fix unit tests and move to logging module
pvk-developer committed Apr 18, 2024
1 parent 07a8bae commit d27e3f0
Showing 10 changed files with 182 additions and 74 deletions.
2 changes: 2 additions & 0 deletions sdv/_utils.py
@@ -1,4 +1,5 @@
"""Miscellaneous utility functions."""
import contextlib
import operator
import uuid
import warnings
@@ -8,6 +9,7 @@
from pathlib import Path

import pandas as pd
import yaml
from pandas.core.tools.datetimes import _guess_datetime_format_for_array

from sdv import version
Empty file added sdv/logging/__init__.py
15 changes: 15 additions & 0 deletions sdv/logging/sdv_logger_config.yml
@@ -0,0 +1,15 @@
log_registry: 'local'
version: 1
loggers:
    SingleTableSynthesizer:
        level: INFO
        propagate: false
        handlers:
            class: logging.FileHandler
            filename: sdv_logs.log
    MultiTableSynthesizer:
        level: INFO
        propagate: false
        handlers:
            class: logging.FileHandler
            filename: sdv_logs.log
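
For reference, a quick sketch (not part of the commit) of how this file parses with yaml.safe_load. Note that each logger's handlers key is a single mapping holding class and filename, not the list of handler ids the stdlib logging.config.dictConfig schema expects, which is why the new sdv/logging/utils.py below consumes it with custom code instead of dictConfig:

import yaml

CONFIG = """
log_registry: 'local'
version: 1
loggers:
    SingleTableSynthesizer:
        level: INFO
        propagate: false
        handlers:
            class: logging.FileHandler
            filename: sdv_logs.log
"""

parsed = yaml.safe_load(CONFIG)
print(parsed['loggers']['SingleTableSynthesizer'])
# {'level': 'INFO', 'propagate': False,
#  'handlers': {'class': 'logging.FileHandler', 'filename': 'sdv_logs.log'}}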
91 changes: 91 additions & 0 deletions sdv/logging/utils.py
@@ -0,0 +1,91 @@
"""Utility functions for configuring SDV loggers."""
import contextlib
import logging
from functools import lru_cache
from pathlib import Path

import yaml


def get_logger_config():
    """Return a dictionary with the logging configuration."""
    logging_path = Path(__file__).parent
    with open(logging_path / 'sdv_logger_config.yml', 'r') as f:
        logger_conf = yaml.safe_load(f)

    # Store the log file in this same directory.
    for logger in logger_conf.get('loggers', {}).values():
        handler = logger.get('handlers', {})
        if handler.get('filename') == 'sdv_logs.log':
            handler['filename'] = logging_path / handler['filename']

    return logger_conf


@contextlib.contextmanager
def disable_single_table_logger():
    """Temporarily disable logging for the single table synthesizers.

    This context manager temporarily removes all handlers associated with
    the ``SingleTableSynthesizer`` logger, disabling logging for that module
    within the current context. After the context exits, the removed
    handlers are restored to the logger.
    """
    # Copy the handler list first: ``removeHandler`` mutates
    # ``logger.handlers`` in place, and removing entries while iterating
    # over the live list would skip some of them.
    single_table_logger = logging.getLogger('SingleTableSynthesizer')
    handlers = list(single_table_logger.handlers)
    for handler in handlers:
        single_table_logger.removeHandler(handler)

    try:
        yield
    finally:
        for handler in handlers:
            single_table_logger.addHandler(handler)


# Cache loggers by name so repeated calls return the already-configured
# instance instead of attaching duplicate handlers to it.
@lru_cache()
def get_logger(logger_name):
    """Get a logger instance with the specified name and configuration.

    This function retrieves or creates a logger instance with the specified
    name and applies configuration settings based on the logger's name and
    the logging configuration.

    Args:
        logger_name (str):
            The name of the logger to retrieve or create.

    Returns:
        logging.Logger:
            A logger instance configured according to the logging
            configuration and the specific settings for the given
            logger name.
    """
    logger_conf = get_logger_config()
    logger = logging.getLogger(logger_name)
    if logger_name in logger_conf.get('loggers', {}):
        formatter = None
        config = logger_conf['loggers'][logger_name]
        log_level = getattr(logging, config.get('level', 'INFO'))
        if config.get('format'):
            formatter = logging.Formatter(config.get('format'))

        logger.setLevel(log_level)
        logger.propagate = config.get('propagate', False)
        handler = config.get('handlers')
        handler_classes = handler.get('class')
        if isinstance(handler_classes, str):
            handler_classes = [handler_classes]
        for handler_class in handler_classes:
            if handler_class == 'logging.FileHandler':
                logfile = handler.get('filename')
                file_handler = logging.FileHandler(logfile)
                file_handler.setLevel(log_level)
                file_handler.setFormatter(formatter)
                logger.addHandler(file_handler)
            elif handler_class in ('consoleHandler', 'StreamingHandler'):
                ch = logging.StreamHandler()
                ch.setLevel(log_level)
                ch.setFormatter(formatter)
                logger.addHandler(ch)

    return logger
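
A short usage sketch of the two helpers above; the calls are hypothetical, but the imports and names match the new module:

from sdv.logging.utils import disable_single_table_logger, get_logger

logger = get_logger('SingleTableSynthesizer')
logger.info('This line is written to sdv_logs.log')

with disable_single_table_logger():
    # No handlers are attached inside this block, so the record goes nowhere.
    logger.info('This line is dropped')

logger.info('Handlers are restored, so logging works again')

# Because of lru_cache, repeated lookups return the same configured object.
assert get_logger('SingleTableSynthesizer') is logger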
43 changes: 21 additions & 22 deletions sdv/multi_table/base.py
@@ -19,14 +19,10 @@
_validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version,
generate_synthesizer_id)
from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError
from sdv.logging.utils import disable_single_table_logger, get_logger
from sdv.single_table.copulas import GaussianCopulaSynthesizer

logger_config_file = Path(__file__).parent.parent
with open(logger_config_file / 'sdv_logger.yml', 'r') as f:
logger_conf = yaml.safe_load(f)

logging.config.dictConfig(logger_conf)
LOGGER = logging.getLogger('BaseMultiTableSynthesizer')
SYNTHESIZER_LOGGER = get_logger('MultiTableSynthesizer')


class BaseMultiTableSynthesizer:
@@ -66,13 +62,14 @@ def _set_temp_numpy_seed(self):
np.random.set_state(initial_state)

def _initialize_models(self):
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata,
locales=self.locales,
**synthesizer_parameters
)
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata,
locales=self.locales,
**synthesizer_parameters
)

def _get_pbar_args(self, **kwargs):
"""Return a dictionary with the updated keyword args for a progress bar."""
@@ -123,7 +120,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self._fitted_sdv_version = None
self._fitted_sdv_enterprise_version = None
self._synthesizer_id = generate_synthesizer_id(self)
LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nInstance:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -396,7 +393,7 @@ def fit_processed_data(self, processed_data):
total_rows += len(table)
total_columns += len(table.columns)

LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nFit processed data\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -413,8 +410,10 @@
self._synthesizer_id,
)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
augmented_data = self._augment_tables(processed_data)
self._model_tables(augmented_data)
with disable_single_table_logger():
augmented_data = self._augment_tables(processed_data)
self._model_tables(augmented_data)

self._fitted = True
self._fitted_date = datetime.datetime.today().strftime('%Y-%m-%d')
self._fitted_sdv_version = getattr(version, 'public', None)
@@ -434,7 +433,7 @@ def fit(self, data):
total_rows += len(table)
total_columns += len(table.columns)

LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nFit\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -482,7 +481,7 @@ def sample(self, scale=1.0):
raise SynthesizerInputError(
f"Invalid parameter for 'scale' ({scale}). Please provide a number that is >0.0.")

with self._set_temp_numpy_seed():
with self._set_temp_numpy_seed(), disable_single_table_logger():
sampled_data = self._sample(scale=scale)

total_rows = 0
@@ -491,7 +490,7 @@ def sample(self, scale=1.0):
total_rows += len(table)
total_columns += len(table.columns)

LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nSample:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -674,7 +673,7 @@ def save(self, filepath):
with open(filepath, 'wb') as output:
cloudpickle.dump(self, output)

LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nSave:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -704,7 +703,7 @@ def load(cls, filepath):
if getattr(synthesizer, '_synthesizer_id', None) is None:
synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer)

LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nLoad\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
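
The load() excerpt above is truncated in this view, but the pattern across the file is consistent: the module-level dictConfig setup is gone, SYNTHESIZER_LOGGER = get_logger('MultiTableSynthesizer') replaces the old LOGGER, and every step that creates or fits the per-table single-table synthesizers is wrapped in disable_single_table_logger() so only multi-table events reach sdv_logs.log. A schematic sketch of that wrapping (the function and its arguments are hypothetical; the logger name and context manager come from the diff):

from sdv.logging.utils import disable_single_table_logger, get_logger

SYNTHESIZER_LOGGER = get_logger('MultiTableSynthesizer')

def fit_all_tables(table_synthesizers, processed_tables):
    """Fit one single-table synthesizer per table, logging once overall."""
    with disable_single_table_logger():
        # Each .fit() below would emit its own 'Fit' record through the
        # SingleTableSynthesizer logger; the context manager suppresses them.
        for name, synthesizer in table_synthesizers.items():
            synthesizer.fit(processed_tables[name])

    SYNTHESIZER_LOGGER.info('Fit complete for %d tables', len(table_synthesizers))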
15 changes: 0 additions & 15 deletions sdv/sdv_logger.yml

This file was deleted.

24 changes: 11 additions & 13 deletions sdv/single_table/base.py
@@ -4,7 +4,6 @@
import functools
import inspect
import logging
import logging.config
import math
import operator
import os
@@ -18,7 +17,6 @@
import numpy as np
import pandas as pd
import tqdm
import yaml
from copulas.multivariate import GaussianMultivariate

from sdv import version
@@ -27,14 +25,12 @@
from sdv.constraints.errors import AggregateConstraintsError
from sdv.data_processing.data_processor import DataProcessor
from sdv.errors import ConstraintsNotMetError, InvalidDataError, SynthesizerInputError
from sdv.logging.utils import get_logger
from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path

logger_config_file = Path(__file__).parent.parent
with open(logger_config_file / 'sdv_logger.yml', 'r') as f:
logger_conf = yaml.safe_load(f)

logging.config.dictConfig(logger_conf)
LOGGER = logging.getLogger('BaseSingleTableSynthesizer')
LOGGER = logging.getLogger(__name__)
SYNTHESIZER_LOGGER = get_logger('SingleTableSynthesizer')

COND_IDX = str(uuid.uuid4())
FIXED_RNG_SEED = 73251
@@ -116,7 +112,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
self._fitted_sdv_version = None
self._fitted_sdv_enterprise_version = None
self._synthesizer_id = generate_synthesizer_id(self)
LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nInstance:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -407,7 +403,7 @@ def fit_processed_data(self, processed_data):
processed_data (pandas.DataFrame):
The transformed data used to fit the model to.
"""
LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nFit processed data\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -438,7 +434,7 @@ def fit(self, data):
data (pandas.DataFrame):
The raw data (before any transformations) to fit the model to.
"""
LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nFit\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -468,7 +464,7 @@ def save(self, filepath):
filepath (str):
Path where the synthesizer instance will be serialized.
"""
LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nSave:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -500,7 +496,7 @@ def load(cls, filepath):
if getattr(synthesizer, '_synthesizer_id', None) is None:
synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer)

LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nLoad\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -884,7 +880,7 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
output_file_path,
show_progress_bar=show_progress_bar
)
LOGGER.info(
SYNTHESIZER_LOGGER.info(
'\nSample:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
@@ -900,6 +896,8 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
self._synthesizer_id,
)

return sampled_data

def _sample_with_conditions(self, conditions, max_tries_per_batch, batch_size,
progress_bar=None, output_file_path=None):
"""Sample rows with conditions.
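
Note the split this commit introduces in single_table/base.py: a plain module logger (logging.getLogger(__name__)) remains for ordinary diagnostics, while the YAML-configured SYNTHESIZER_LOGGER receives the structured Instance/Fit/Save/Load/Sample records. Since the config defines no format key, the FileHandler writes bare messages. A minimal standalone sketch of what one such record looks like (this bypasses sdv entirely; the message layout mirrors the %-style info() calls in the diff):

import datetime
import logging

logger = logging.getLogger('SingleTableSynthesizer')
logger.setLevel(logging.INFO)
logger.addHandler(logging.FileHandler('sdv_logs.log'))

logger.info(
    '\nInstance:\n'
    '  Timestamp: %s\n'
    '  Synthesizer class name: %s',
    datetime.datetime.now(),
    'GaussianCopulaSynthesizer',
)
# sdv_logs.log now ends with:
#
# Instance:
#   Timestamp: <timestamp>
#   Synthesizer class name: GaussianCopulaSynthesizer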