Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the logger settings #1981

Merged
merged 6 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sdv/logging/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module for configuring loggers within the SDV library."""

from sdv.logging.utils import disable_single_table_logger, get_sdv_logger, get_sdv_logger_config
from sdv.logging.logger import get_sdv_logger
from sdv.logging.utils import disable_single_table_logger, get_sdv_logger_config

__all__ = (
'disable_single_table_logger',
Expand Down
62 changes: 62 additions & 0 deletions sdv/logging/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""SDV Logger."""

import logging
from functools import lru_cache

from sdv.logging.utils import get_sdv_logger_config


@lru_cache()
def get_sdv_logger(logger_name):
    """Get a logger instance with the specified name and configuration.

    This function retrieves or creates a logger instance with the specified name
    and applies configuration settings based on the logger's name and the logging
    configuration. Results are cached per ``logger_name`` (``lru_cache``), so the
    handler setup below runs at most once for a given name.

    Args:
        logger_name (str):
            The name of the logger to retrieve or create.

    Returns:
        logging.Logger:
            A logger instance configured according to the logging configuration
            and the specific settings for the given logger name.
    """
    logger_conf = get_sdv_logger_config()
    logger = logging.getLogger(logger_name)
    if logger_conf.get('log_registry') is None:
        # Return a logger without any extra settings and avoid writing into files or other streams
        return logger

    if logger_conf.get('log_registry') == 'local':
        # Iterate over a copy: calling ``removeHandler`` while iterating
        # ``logger.handlers`` directly mutates the list mid-iteration and
        # skips every other handler.
        for handler in list(logger.handlers):
            # Remove handlers that could exist previously
            logger.removeHandler(handler)

        # Guard against a config file with no ``loggers`` section; the bare
        # ``logger_name in logger_conf.get('loggers')`` raised TypeError on None.
        loggers_conf = logger_conf.get('loggers') or {}
        if logger_name in loggers_conf:
            formatter = None
            config = loggers_conf.get(logger_name)
            log_level = getattr(logging, config.get('level', 'INFO'))
            if config.get('format'):
                formatter = logging.Formatter(config.get('format'))

            logger.setLevel(log_level)
            logger.propagate = config.get('propagate', False)
            handler = config.get('handlers')
            # ``class`` may be a single dotted name or a list of them.
            handlers = handler.get('class')
            handlers = [handlers] if isinstance(handlers, str) else handlers
            for handler_class in handlers:
                if handler_class == 'logging.FileHandler':
                    logfile = handler.get('filename')
                    file_handler = logging.FileHandler(logfile)
                    file_handler.setLevel(log_level)
                    file_handler.setFormatter(formatter)
                    logger.addHandler(file_handler)
                elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'):
                    ch = logging.StreamHandler()
                    ch.setLevel(log_level)
                    ch.setFormatter(formatter)
                    logger.addHandler(ch)

    return logger
2 changes: 1 addition & 1 deletion sdv/logging/sdv_logger_config.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
log_registry: 'local'
log_registry: null
version: 1
loggers:
SingleTableSynthesizer:
Expand Down
69 changes: 11 additions & 58 deletions sdv/logging/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import contextlib
import logging
from functools import lru_cache
import shutil
from pathlib import Path

import platformdirs
Expand All @@ -11,13 +11,18 @@

def get_sdv_logger_config():
"""Return a dictionary with the logging configuration."""
logging_path = Path(__file__).parent
with open(logging_path / 'sdv_logger_config.yml', 'r') as f:
logger_conf = yaml.safe_load(f)

# Logfile to be in this same directory
store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev'))
store_path.mkdir(parents=True, exist_ok=True)
config_path = Path(__file__).parent / 'sdv_logger_config.yml'

if (store_path / 'sdv_logger_config.yml').exists():
config_path = store_path / 'sdv_logger_config.yml'
else:
shutil.copyfile(config_path, store_path / 'sdv_logger_config.yml')

with open(config_path, 'r') as f:
logger_conf = yaml.safe_load(f)

for logger in logger_conf.get('loggers', {}).values():
handler = logger.get('handlers', {})
if handler.get('filename') == 'sdv_logs.log':
Expand All @@ -44,55 +49,3 @@ def disable_single_table_logger():
finally:
for handler in handlers:
single_table_logger.addHandler(handler)


@lru_cache()
def get_sdv_logger(logger_name):
    """Get a logger instance with the specified name and configuration.

    This function retrieves or creates a logger instance with the specified name
    and applies configuration settings based on the logger's name and the logging
    configuration. Results are cached per ``logger_name`` via ``lru_cache``.

    Args:
        logger_name (str):
            The name of the logger to retrieve or create.

    Returns:
        logging.Logger:
            A logger instance configured according to the logging configuration
            and the specific settings for the given logger name.
    """
    logger_conf = get_sdv_logger_config()
    # Bind the logger before branching: the original only assigned it inside
    # the ``== 'local'`` branch, so any other non-None ``log_registry`` value
    # reached ``return logger`` with an unbound local (UnboundLocalError).
    logger = logging.getLogger(logger_name)
    if logger_conf.get('log_registry') is None:
        # Return a logger without any extra settings and avoid writing into files or other streams
        return logger

    if logger_conf.get('log_registry') == 'local':
        # Guard against a config with no ``loggers`` section; membership test
        # against ``None`` would raise TypeError.
        loggers_conf = logger_conf.get('loggers') or {}
        if logger_name in loggers_conf:
            formatter = None
            config = loggers_conf.get(logger_name)
            log_level = getattr(logging, config.get('level', 'INFO'))
            if config.get('format'):
                formatter = logging.Formatter(config.get('format'))

            logger.setLevel(log_level)
            logger.propagate = config.get('propagate', False)
            handler = config.get('handlers')
            # ``class`` may be a single dotted name or a list of them.
            handlers = handler.get('class')
            handlers = [handlers] if isinstance(handlers, str) else handlers
            for handler_class in handlers:
                if handler_class == 'logging.FileHandler':
                    logfile = handler.get('filename')
                    file_handler = logging.FileHandler(logfile)
                    file_handler.setLevel(log_level)
                    file_handler.setFormatter(formatter)
                    logger.addHandler(file_handler)
                elif handler_class in ('logging.consoleHandler', 'logging.StreamHandler'):
                    ch = logging.StreamHandler()
                    ch.setLevel(log_level)
                    ch.setFormatter(formatter)
                    logger.addHandler(ch)

    return logger
2 changes: 1 addition & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from sdv.data_processing.data_processor import DataProcessor
from sdv.errors import (
ConstraintsNotMetError, InvalidDataError, SamplingError, SynthesizerInputError)
from sdv.logging.utils import get_sdv_logger
from sdv.logging import get_sdv_logger
from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path

LOGGER = logging.getLogger(__name__)
Expand Down
74 changes: 0 additions & 74 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
import importlib.metadata
import re
import warnings
from pathlib import Path
from unittest.mock import patch

import numpy as np
import pandas as pd
import platformdirs
import pytest
from faker import Faker
from rdt.transformers import FloatFormatter
Expand Down Expand Up @@ -1682,74 +1679,3 @@ def test_hma_not_fit_raises_sampling_error():
)
with pytest.raises(SamplingError, match=error_msg):
synthesizer.sample(1)


@patch('sdv.multi_table.base.generate_synthesizer_id')
@patch('sdv.multi_table.base.datetime')
def test_synthesizer_logger(mock_datetime, mock_generate_id):
    """Test that the synthesizer logger logs the expected messages."""
    # Setup
    # NOTE(review): this reads the shared 'sdv_logs.log' in the platformdirs
    # user data dir — assumes the 'local' log_registry config is active and
    # the file is writable; confirm against the logging configuration.
    store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev'))
    file_name = 'sdv_logs.log'

    # Pin the synthesizer id and timestamp so the logged lines are deterministic.
    synth_id = 'HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
    mock_generate_id.return_value = synth_id
    mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
    data, metadata = download_demo('multi_table', 'fake_hotels')

    # Run
    instance = HMASynthesizer(metadata)

    # Assert
    # Instantiation appends a 4-line 'Instance' entry; read only the file tail.
    with open(store_path / file_name) as f:
        instance_lines = f.readlines()[-4:]

    assert ''.join(instance_lines) == (
        'Instance:\n'
        '  Timestamp: 2024-04-19 16:20:10.037183\n'
        '  Synthesizer class name: HMASynthesizer\n'
        '  Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
    )

    # Run
    instance.fit(data)

    # Assert
    # Fitting appends two entries ('Fit' and 'Fit processed data'), 17 lines total.
    with open(store_path / file_name) as f:
        fit_lines = f.readlines()[-17:]

    assert ''.join(fit_lines) == (
        'Fit:\n'
        '  Timestamp: 2024-04-19 16:20:10.037183\n'
        '  Synthesizer class name: HMASynthesizer\n'
        '  Statistics of the fit data:\n'
        '    Total number of tables: 2\n'
        '    Total number of rows: 668\n'
        '    Total number of columns: 15\n'
        '  Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
        '\nFit processed data:\n'
        '  Timestamp: 2024-04-19 16:20:10.037183\n'
        '  Synthesizer class name: HMASynthesizer\n'
        '  Statistics of the fit processed data:\n'
        '    Total number of tables: 2\n'
        '    Total number of rows: 668\n'
        '    Total number of columns: 11\n'
        '  Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
    )

    # Run
    # Sampling with scale=1 appends an 8-line 'Sample' entry.
    instance.sample(1)
    with open(store_path / file_name) as f:
        sample_lines = f.readlines()[-8:]

    # Assert
    assert ''.join(sample_lines) == (
        'Sample:\n'
        '  Timestamp: 2024-04-19 16:20:10.037183\n'
        '  Synthesizer class name: HMASynthesizer\n'
        '  Statistics of the sample size:\n'
        '    Total number of tables: 2\n'
        '    Total number of rows: 668\n'
        '    Total number of columns: 15\n'
        '  Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
    )
78 changes: 0 additions & 78 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
import importlib.metadata
import re
import warnings
from pathlib import Path
from unittest.mock import patch

import numpy as np
import pandas as pd
import platformdirs
import pytest
from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder

Expand Down Expand Up @@ -781,82 +779,6 @@ def test_fit_raises_version_error():
instance.fit(data)


@patch('sdv.single_table.base.generate_synthesizer_id')
@patch('sdv.single_table.base.datetime')
def test_synthesizer_logger(mock_datetime, mock_generate_id):
    """Test that the synthesizer logger logs the expected messages."""
    # Setup
    # NOTE(review): this reads the shared 'sdv_logs.log' in the platformdirs
    # user data dir — assumes the 'local' log_registry config is active and
    # the file is writable; confirm against the logging configuration.
    store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev'))
    file_name = 'sdv_logs.log'

    # Pin the synthesizer id and timestamp so the logged lines are deterministic.
    synth_id = 'GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
    mock_generate_id.return_value = synth_id
    mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
    data = pd.DataFrame({
        'col 1': [1, 2, 3],
        'col 2': [4, 5, 6],
        'col 3': ['a', 'b', 'c'],
    })
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(data)

    # Run
    instance = GaussianCopulaSynthesizer(metadata)

    # Assert
    # Instantiation appends a 4-line 'Instance' entry; read only the file tail.
    with open(store_path / file_name) as f:
        instance_lines = f.readlines()[-4:]

    assert ''.join(instance_lines) == (
        'Instance:\n'
        '  Timestamp: 2024-04-19 16:20:10.037183\n'
        '  Synthesizer class name: GaussianCopulaSynthesizer\n'
        '  Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
    )

    # Run
    instance.fit(data)

    # Assert
    # Fitting appends two entries ('Fit' and 'Fit processed data'), 17 lines total.
    with open(store_path / file_name) as f:
        fit_lines = f.readlines()[-17:]

    assert ''.join(fit_lines) == (
        'Fit:\n'
        '  Timestamp: 2024-04-19 16:20:10.037183\n'
        '  Synthesizer class name: GaussianCopulaSynthesizer\n'
        '  Statistics of the fit data:\n'
        '    Total number of tables: 1\n'
        '    Total number of rows: 3\n'
        '    Total number of columns: 3\n'
        '  Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
        '\nFit processed data:\n'
        '  Timestamp: 2024-04-19 16:20:10.037183\n'
        '  Synthesizer class name: GaussianCopulaSynthesizer\n'
        '  Statistics of the fit processed data:\n'
        '    Total number of tables: 1\n'
        '    Total number of rows: 3\n'
        '    Total number of columns: 3\n'
        '  Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
    )

    # Run
    # Sampling 100 rows appends an 8-line 'Sample' entry.
    instance.sample(100)
    with open(store_path / file_name) as f:
        sample_lines = f.readlines()[-8:]

    assert ''.join(sample_lines) == (
        'Sample:\n'
        '  Timestamp: 2024-04-19 16:20:10.037183\n'
        '  Synthesizer class name: GaussianCopulaSynthesizer\n'
        '  Statistics of the sample size:\n'
        '    Total number of tables: 1\n'
        '    Total number of rows: 100\n'
        '    Total number of columns: 3\n'
        '  Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
    )


@pytest.mark.parametrize('synthesizer', SYNTHESIZERS)
def test_sample_not_fitted(synthesizer):
"""Test that a synthesizer raises an error when trying to sample without fitting."""
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/logging/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Test ``SDV`` logger."""
import logging
from unittest.mock import Mock, patch

from sdv.logging.logger import get_sdv_logger


def test_get_sdv_logger():
    """``get_sdv_logger`` applies the configured level and attaches a handler.

    Patches the config loader, ``logging.getLogger`` and ``logging.StreamHandler``
    so no real logging state is touched, then checks that the returned logger is
    configured from the ``test_logger`` entry.
    """
    # Setup
    fake_config = {
        'log_registry': 'local',
        'loggers': {
            'test_logger': {
                'level': 'DEBUG',
                'handlers': {
                    'class': 'logging.StreamHandler'
                }
            }
        }
    }
    fake_logger = Mock()
    fake_logger.handlers = []

    # Run
    with patch('sdv.logging.logger.get_sdv_logger_config', return_value=fake_config), \
            patch('sdv.logging.logger.logging.getLogger', return_value=fake_logger), \
            patch('sdv.logging.logger.logging.StreamHandler'):
        get_sdv_logger('test_logger')

    # Assert
    fake_logger.setLevel.assert_called_once_with(logging.DEBUG)
    fake_logger.addHandler.assert_called_once()
Loading
Loading