Skip to content

Commit

Permalink
Add integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Apr 24, 2024
1 parent b3a28c0 commit d48aad6
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 3 deletions.
3 changes: 0 additions & 3 deletions sdv/logging/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Utilities for configuring logging within the SDV library."""

import logging
import logging.config
from functools import lru_cache
from pathlib import Path

import platformdirs
Expand All @@ -26,7 +24,6 @@ def get_sdv_logger_config():
return logger_conf


@lru_cache()
def get_sdv_logger(logger_name):
"""Get a logger instance with the specified name and configuration.
Expand Down
74 changes: 74 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import importlib.metadata
import re
import warnings
from pathlib import Path
from unittest.mock import patch

import numpy as np
import pandas as pd
import platformdirs
import pytest
from faker import Faker
from rdt.transformers import FloatFormatter
Expand Down Expand Up @@ -1667,3 +1670,74 @@ def test_hma_relationship_validity():

# Assert
assert report.get_details('Relationship Validity')['Score'].mean() == 1.0


@patch('sdv.multi_table.base.generate_synthesizer_id')
@patch('sdv.multi_table.base.datetime')
def test_synthesizer_logger(mock_datetime, mock_generate_id):
"""Test that the synthesizer logger logs the expected messages."""
# Setup
store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev'))
file_name = 'sdv_logs.log'

synth_id = 'HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
mock_generate_id.return_value = synth_id
mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
data, metadata = download_demo('multi_table', 'fake_hotels')

# Run
instance = HMASynthesizer(metadata)

# Assert
with open(store_path / file_name) as f:
instance_lines = f.readlines()[-4:]

assert ''.join(instance_lines) == (
'Instance:\n'
' Timestamp: 2024-04-19 16:20:10.037183\n'
' Synthesizer class name: HMASynthesizer\n'
' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
)

# Run
instance.fit(data)

# Assert
with open(store_path / file_name) as f:
fit_lines = f.readlines()[-17:]

assert ''.join(fit_lines) == (
'Fit:\n'
' Timestamp: 2024-04-19 16:20:10.037183\n'
' Synthesizer class name: HMASynthesizer\n'
' Statistics of the fit data:\n'
' Total number of tables: 2\n'
' Total number of rows: 668\n'
' Total number of columns: 15\n'
' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
'\nFit processed data:\n'
' Timestamp: 2024-04-19 16:20:10.037183\n'
' Synthesizer class name: HMASynthesizer\n'
' Statistics of the fit processed data:\n'
' Total number of tables: 2\n'
' Total number of rows: 668\n'
' Total number of columns: 11\n'
' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
)

# Run
instance.sample(1)
with open(store_path / file_name) as f:
sample_lines = f.readlines()[-8:]

# Assert
assert ''.join(sample_lines) == (
'Sample:\n'
' Timestamp: 2024-04-19 16:20:10.037183\n'
' Synthesizer class name: HMASynthesizer\n'
' Statistics of the sample size:\n'
' Total number of tables: 2\n'
' Total number of rows: 668\n'
' Total number of columns: 15\n'
' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
)
78 changes: 78 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import importlib.metadata
import re
import warnings
from pathlib import Path
from unittest.mock import patch

import numpy as np
import pandas as pd
import platformdirs
import pytest
from rdt.transformers import AnonymizedFaker, FloatFormatter, RegexGenerator, UniformEncoder

Expand Down Expand Up @@ -777,3 +779,79 @@ def test_fit_raises_version_error():
)
with pytest.raises(VersionError, match=expected_message):
instance.fit(data)


@patch('sdv.single_table.base.generate_synthesizer_id')
@patch('sdv.single_table.base.datetime')
def test_synthesizer_logger(mock_datetime, mock_generate_id):
"""Test that the synthesizer logger logs the expected messages."""
# Setup
store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev'))
file_name = 'sdv_logs.log'

synth_id = 'GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
mock_generate_id.return_value = synth_id
mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
data = pd.DataFrame({
'col 1': [1, 2, 3],
'col 2': [4, 5, 6],
'col 3': ['a', 'b', 'c'],
})
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

# Run
instance = GaussianCopulaSynthesizer(metadata)

# Assert
with open(store_path / file_name) as f:
instance_lines = f.readlines()[-4:]

assert ''.join(instance_lines) == (
'Instance:\n'
' Timestamp: 2024-04-19 16:20:10.037183\n'
' Synthesizer class name: GaussianCopulaSynthesizer\n'
' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
)

# Run
instance.fit(data)

# Assert
with open(store_path / file_name) as f:
fit_lines = f.readlines()[-17:]

assert ''.join(fit_lines) == (
'Fit:\n'
' Timestamp: 2024-04-19 16:20:10.037183\n'
' Synthesizer class name: GaussianCopulaSynthesizer\n'
' Statistics of the fit data:\n'
' Total number of tables: 1\n'
' Total number of rows: 3\n'
' Total number of columns: 3\n'
' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
'\nFit processed data:\n'
' Timestamp: 2024-04-19 16:20:10.037183\n'
' Synthesizer class name: GaussianCopulaSynthesizer\n'
' Statistics of the fit processed data:\n'
' Total number of tables: 1\n'
' Total number of rows: 3\n'
' Total number of columns: 3\n'
' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
)

# Run
instance.sample(100)
with open(store_path / file_name) as f:
sample_lines = f.readlines()[-8:]

assert ''.join(sample_lines) == (
'Sample:\n'
' Timestamp: 2024-04-19 16:20:10.037183\n'
' Synthesizer class name: GaussianCopulaSynthesizer\n'
' Statistics of the sample size:\n'
' Total number of tables: 1\n'
' Total number of rows: 100\n'
' Total number of columns: 3\n'
' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
)

0 comments on commit d48aad6

Please sign in to comment.