From d48aad6254ae4b377c932b5fd0ea714485fcef32 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev Date: Wed, 24 Apr 2024 18:50:53 +0200 Subject: [PATCH] Add integration tests --- sdv/logging/utils.py | 3 - tests/integration/multi_table/test_hma.py | 74 +++++++++++++++++++ tests/integration/single_table/test_base.py | 78 +++++++++++++++++++++ 3 files changed, 152 insertions(+), 3 deletions(-) diff --git a/sdv/logging/utils.py b/sdv/logging/utils.py index 80abcc2f9..a7aae8e04 100644 --- a/sdv/logging/utils.py +++ b/sdv/logging/utils.py @@ -1,8 +1,6 @@ """Utilities for configuring logging within the SDV library.""" import logging -import logging.config -from functools import lru_cache from pathlib import Path import platformdirs @@ -26,7 +24,6 @@ def get_sdv_logger_config(): return logger_conf -@lru_cache() def get_sdv_logger(logger_name): """Get a logger instance with the specified name and configuration. diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index bf5405d04..8c2a97ac6 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -2,9 +2,12 @@ import importlib.metadata import re import warnings +from pathlib import Path +from unittest.mock import patch import numpy as np import pandas as pd +import platformdirs import pytest from faker import Faker from rdt.transformers import FloatFormatter @@ -1667,3 +1670,74 @@ def test_hma_relationship_validity(): # Assert assert report.get_details('Relationship Validity')['Score'].mean() == 1.0 + + +@patch('sdv.multi_table.base.generate_synthesizer_id') +@patch('sdv.multi_table.base.datetime') +def test_synthesizer_logger(mock_datetime, mock_generate_id): + """Test that the synthesizer logger logs the expected messages.""" + # Setup + store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) + file_name = 'sdv_logs.log' + + synth_id = 'HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + mock_generate_id.return_value = synth_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' + data, metadata = download_demo('multi_table', 'fake_hotels') + + # Run + instance = HMASynthesizer(metadata) + + # Assert + with open(store_path / file_name) as f: + instance_lines = f.readlines()[-4:] + + assert ''.join(instance_lines) == ( + 'Instance:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.fit(data) + + # Assert + with open(store_path / file_name) as f: + fit_lines = f.readlines()[-17:] + + assert ''.join(fit_lines) == ( + 'Fit:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 668\n' + ' Total number of columns: 15\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + '\nFit processed data:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 668\n' + ' Total number of columns: 11\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.sample(1) + with open(store_path / file_name) as f: + sample_lines = f.readlines()[-8:] + + # Assert + assert ''.join(sample_lines) == ( + 'Sample:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: HMASynthesizer\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 2\n' + ' Total number of rows: 668\n' + ' Total number of columns: 15\n' + ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 73b25c97c..8c7ea2601 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -2,10 +2,12 @@ import importlib.metadata import re import warnings +from pathlib import Path from unittest.mock import patch import numpy as np import pandas as pd +import platformdirs import pytest from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder @@ -777,3 +779,79 @@ def test_fit_raises_version_error(): ) with pytest.raises(VersionError, match=expected_message): instance.fit(data) + + +@patch('sdv.single_table.base.generate_synthesizer_id') +@patch('sdv.single_table.base.datetime') +def test_synthesizer_logger(mock_datetime, mock_generate_id): + """Test that the synthesizer logger logs the expected messages.""" + # Setup + store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) + file_name = 'sdv_logs.log' + + synth_id = 'GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' + mock_generate_id.return_value = synth_id + mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' + data = pd.DataFrame({ + 'col 1': [1, 2, 3], + 'col 2': [4, 5, 6], + 'col 3': ['a', 'b', 'c'], + }) + metadata = SingleTableMetadata() + metadata.detect_from_dataframe(data) + + # Run + instance = GaussianCopulaSynthesizer(metadata) + + # Assert + with open(store_path / file_name) as f: + instance_lines = f.readlines()[-4:] + + assert ''.join(instance_lines) == ( + 'Instance:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.fit(data) + + # Assert + with open(store_path / file_name) as f: + fit_lines = f.readlines()[-17:] + + assert ''.join(fit_lines) == ( + 'Fit:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Statistics of the fit data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 3\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + '\nFit processed data:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Statistics of the fit processed data:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 3\n' + ' Total number of columns: 3\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + ) + + # Run + instance.sample(100) + with open(store_path / file_name) as f: + sample_lines = f.readlines()[-8:] + + assert ''.join(sample_lines) == ( + 'Sample:\n' + ' Timestamp: 2024-04-19 16:20:10.037183\n' + ' Synthesizer class name: GaussianCopulaSynthesizer\n' + ' Statistics of the sample size:\n' + ' Total number of tables: 1\n' + ' Total number of rows: 100\n' + ' Total number of columns: 3\n' + ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' + )