Skip to content

Commit

Permalink
Include sdv_logger_config.yml with the package (#1963)
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer authored Apr 29, 2024
1 parent 275955d commit 8aa3b5e
Show file tree
Hide file tree
Showing 15 changed files with 40 additions and 81 deletions.
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ dependencies = [
'botocore>=1.31',
'cloudpickle>=2.1.0',
'graphviz>=0.13.2',
"numpy>=1.20.0;python_version<'3.10'",
"numpy>=1.21.0;python_version<'3.10'",
"numpy>=1.23.3,<2;python_version>='3.10' and python_version<'3.12'",
"numpy>=1.26.0,<2;python_version>='3.12'",
"pandas>=1.1.3;python_version<'3.10'",
"pandas>=1.3.4;python_version>='3.10' and python_version<'3.11'",
"pandas>=1.4.0;python_version<'3.10'",
"pandas>=1.4.0;python_version>='3.10' and python_version<'3.11'",
"pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'",
"pandas>=2.1.1;python_version>='3.12'",
'tqdm>=4.29',
Expand Down Expand Up @@ -141,7 +141,8 @@ namespaces = false
'make.bat',
'*.jpg',
'*.png',
'*.gif'
'*.gif',
'sdv_logger_config.yml'
]

[tool.setuptools.exclude-package-data]
Expand Down
3 changes: 0 additions & 3 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def _initialize_models(self):
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata,
locales=self.locales,
table_name=table_name,
**synthesizer_parameters
)

Expand Down Expand Up @@ -200,8 +199,6 @@ def set_table_parameters(self, table_name, table_parameters):
A dictionary with the parameters as keys and the values to be used to instantiate
the table's synthesizer.
"""
# Ensure that we set the name of the table no matter what
table_parameters.update({'table_name': table_name})
self._table_synthesizers[table_name] = self._synthesizer(
metadata=self.metadata.tables[table_name],
**table_parameters
Expand Down
12 changes: 3 additions & 9 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,11 +312,9 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
row = pd.Series({'num_rows': len(child_rows)})
row.index = f'__{child_name}__{foreign_key}__' + row.index
else:
synthesizer_parameters = self._table_parameters[child_name]
synthesizer_parameters.update({'table_name': child_name})
synthesizer = self._synthesizer(
table_meta,
**synthesizer_parameters
**self._table_parameters[child_name]
)
synthesizer.fit_processed_data(child_rows.reset_index(drop=True))
row = synthesizer._get_parameters()
Expand Down Expand Up @@ -523,11 +521,9 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row):
default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {})

table_meta = self.metadata.tables[child_name]
synthesizer_parameters = self._table_parameters[child_name]
synthesizer_parameters.update({'table_name': child_name})
synthesizer = self._synthesizer(
table_meta,
**synthesizer_parameters
**self._table_parameters[child_name]
)
synthesizer._set_parameters(parameters, default_parameters)
synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor
Expand Down Expand Up @@ -622,11 +618,9 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key):
for parent_id, row in parent_rows.iterrows():
parameters = self._extract_parameters(row, table_name, foreign_key)
table_meta = self._table_synthesizers[table_name].get_metadata()
synthesizer_parameters = self._table_parameters[table_name]
synthesizer_parameters.update({'table_name': table_name})
synthesizer = self._synthesizer(
table_meta,
**synthesizer_parameters
**self._table_parameters[table_name]
)
synthesizer._set_parameters(parameters)
try:
Expand Down
2 changes: 1 addition & 1 deletion sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False
metadata=metadata,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
locales=locales
locales=locales,
)

sequence_key = self.metadata.sequence_key
Expand Down
23 changes: 10 additions & 13 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,19 @@ def _check_metadata_updated(self):
)

def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
locales=['en_US'], table_name=None):
locales=['en_US']):
self._validate_inputs(enforce_min_max_values, enforce_rounding)
self.metadata = metadata
self.metadata.validate()
self._check_metadata_updated()
self.enforce_min_max_values = enforce_min_max_values
self.enforce_rounding = enforce_rounding
self.locales = locales
self.table_name = table_name
self._data_processor = DataProcessor(
metadata=self.metadata,
enforce_rounding=self.enforce_rounding,
enforce_min_max_values=self.enforce_min_max_values,
locales=self.locales,
table_name=self.table_name
)
self._fitted = False
self._random_state_set = False
Expand Down Expand Up @@ -500,16 +498,15 @@ def load(cls, filepath):
if getattr(synthesizer, '_synthesizer_id', None) is None:
synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer)

if synthesizer.table_name is None:
SYNTHESIZER_LOGGER.info(
'\nLoad:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
synthesizer.__class__.__name__,
synthesizer._synthesizer_id,
)
SYNTHESIZER_LOGGER.info(
'\nLoad:\n'
' Timestamp: %s\n'
' Synthesizer class name: %s\n'
' Synthesizer id: %s',
datetime.datetime.now(),
synthesizer.__class__.__name__,
synthesizer._synthesizer_id,
)

return synthesizer

Expand Down
4 changes: 1 addition & 3 deletions sdv/single_table/copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6,
discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500,
discriminator_steps=1, log_frequency=True, verbose=False, epochs=300,
pac=10, cuda=True, numerical_distributions=None, default_distribution=None,
table_name=None):
pac=10, cuda=True, numerical_distributions=None, default_distribution=None):

super().__init__(
metadata,
Expand All @@ -143,7 +142,6 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
epochs=epochs,
pac=pac,
cuda=cuda,
table_name=table_name
)

validate_numerical_distributions(numerical_distributions, self.metadata.columns)
Expand Down
4 changes: 1 addition & 3 deletions sdv/single_table/copulas.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,12 @@ def get_distribution_class(cls, distribution):
return cls._DISTRIBUTIONS[distribution]

def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
locales=['en_US'], numerical_distributions=None, default_distribution=None,
table_name=None):
locales=['en_US'], numerical_distributions=None, default_distribution=None):
super().__init__(
metadata,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
locales=locales,
table_name=table_name
)
validate_numerical_distributions(numerical_distributions, self.metadata.columns)
self.numerical_distributions = numerical_distributions or {}
Expand Down
7 changes: 2 additions & 5 deletions sdv/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,13 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6,
discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500,
discriminator_steps=1, log_frequency=True, verbose=False, epochs=300,
pac=10, cuda=True, table_name=None):
pac=10, cuda=True):

super().__init__(
metadata=metadata,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
locales=locales,
table_name=table_name
)

self.embedding_dim = embedding_dim
Expand Down Expand Up @@ -339,14 +338,12 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer):

def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128),
l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True,
table_name=None):
l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True):

super().__init__(
metadata=metadata,
enforce_min_max_values=enforce_min_max_values,
enforce_rounding=enforce_rounding,
table_name=table_name
)
self.embedding_dim = embedding_dim
self.compress_dims = compress_dims
Expand Down
9 changes: 3 additions & 6 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,7 @@ def test_hma_set_table_parameters(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'numerical_distributions': {},
'table_name': 'characters'
'numerical_distributions': {}
}
families_params = hmasynthesizer.get_table_parameters('families')
assert families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer'
Expand All @@ -161,8 +160,7 @@ def test_hma_set_table_parameters(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'numerical_distributions': {},
'table_name': 'families'
'numerical_distributions': {}
}
char_families_params = hmasynthesizer.get_table_parameters('character_families')
assert char_families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer'
Expand All @@ -171,8 +169,7 @@ def test_hma_set_table_parameters(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'numerical_distributions': {},
'table_name': 'character_families'
'numerical_distributions': {}
}

assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma'
Expand Down
15 changes: 4 additions & 11 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ def test__initialize_models(self):
}
instance._synthesizer.assert_has_calls([
call(metadata=instance.metadata.tables['nesreca'], default_distribution='gamma',
locales=locales, table_name='nesreca'),
call(metadata=instance.metadata.tables['oseba'], locales=locales, table_name='oseba'),
call(metadata=instance.metadata.tables['upravna_enota'], locales=locales,
table_name='upravna_enota')
locales=locales),
call(metadata=instance.metadata.tables['oseba'], locales=locales),
call(metadata=instance.metadata.tables['upravna_enota'], locales=locales)
])

def test__get_pbar_args(self):
Expand Down Expand Up @@ -280,7 +279,6 @@ def test_get_table_parameters_empty(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'table_name': 'oseba',
'numerical_distributions': {}
}
}
Expand All @@ -301,7 +299,6 @@ def test_get_table_parameters_has_parameters(self):
'enforce_min_max_values': True,
'enforce_rounding': True,
'locales': ['en_US'],
'table_name': 'oseba',
'numerical_distributions': {}
}

Expand Down Expand Up @@ -333,17 +330,13 @@ def test_set_table_parameters(self):

# Assert
table_parameters = instance.get_table_parameters('oseba')
assert instance._table_parameters['oseba'] == {
'default_distribution': 'gamma',
'table_name': 'oseba'
}
assert instance._table_parameters['oseba'] == {'default_distribution': 'gamma'}
assert table_parameters['synthesizer_name'] == 'GaussianCopulaSynthesizer'
assert table_parameters['synthesizer_parameters'] == {
'default_distribution': 'gamma',
'enforce_min_max_values': True,
'locales': ['en_US'],
'enforce_rounding': True,
'table_name': 'oseba',
'numerical_distributions': {}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def test__recreate_child_synthesizer(self):
# Assert
assert synthesizer == instance._synthesizer.return_value
assert synthesizer._data_processor == table_synthesizer._data_processor
instance._synthesizer.assert_called_once_with(table_meta, table_name='users', a=1)
instance._synthesizer.assert_called_once_with(table_meta, a=1)
synthesizer._set_parameters.assert_called_once_with(
instance._extract_parameters.return_value,
{'colA': 'default_param', 'colB': 'default_param'}
Expand Down
19 changes: 5 additions & 14 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ def test___init__(self, mock_check_metadata_updated, mock_data_processor,
metadata=metadata,
enforce_rounding=instance.enforce_rounding,
enforce_min_max_values=instance.enforce_min_max_values,
locales=instance.locales,
table_name=None
locales=instance.locales
)
metadata.validate.assert_called_once_with()
mock_check_metadata_updated.assert_called_once()
Expand Down Expand Up @@ -124,8 +123,7 @@ def test___init__custom(self, mock_data_processor):
metadata=metadata,
enforce_rounding=instance.enforce_rounding,
enforce_min_max_values=instance.enforce_min_max_values,
locales=instance.locales,
table_name=None
locales=instance.locales
)
metadata.validate.assert_called_once_with()

Expand Down Expand Up @@ -184,8 +182,7 @@ def test_get_parameters(self, mock_data_processor):
assert parameters == {
'enforce_min_max_values': False,
'enforce_rounding': False,
'locales': 'en_CA',
'table_name': None
'locales': 'en_CA'
}

@patch('sdv.single_table.base.DataProcessor')
Expand Down Expand Up @@ -362,8 +359,7 @@ def test_fit_processed_data(self, mock_datetime, caplog):
instance = Mock(
_fitted_sdv_version=None,
_fitted_sdv_enterprise_version=None,
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
table_name=None
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
)
processed_data = pd.DataFrame({'column_a': [1, 2, 3]})

Expand Down Expand Up @@ -394,7 +390,6 @@ def test_fit_processed_data_raises_version_error(self):
instance = Mock(
_fitted_sdv_version='1.0.0',
_fitted_sdv_enterprise_version=None,
table_name=None
)
processed_data = pd.DataFrame({'column_a': [1, 2, 3]})
instance._random_state_set = True
Expand Down Expand Up @@ -422,7 +417,6 @@ def test_fit(self, mock_datetime, caplog):
_fitted_sdv_version=None,
_fitted_sdv_enterprise_version=None,
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
table_name=None
)
data = pd.DataFrame({'column_a': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']})
instance._random_state_set = True
Expand Down Expand Up @@ -459,7 +453,6 @@ def test_fit_raises_version_error(self):
instance = Mock(
_fitted_sdv_version='1.0.0',
_fitted_sdv_enterprise_version=None,
table_name=None
)
data = pd.DataFrame({'column_a': [1, 2, 3]})
instance._random_state_set = True
Expand Down Expand Up @@ -1417,7 +1410,6 @@ def test_sample(self, mock_datetime, caplog):
output_file_path = 'temp.csv'
instance = Mock(
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
table_name=None
)
instance.get_metadata.return_value._constraints = False
instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]})
Expand Down Expand Up @@ -1810,7 +1802,6 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog):
# Setup
synthesizer = Mock(
_synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
table_name=None
)
mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'

Expand Down Expand Up @@ -1839,7 +1830,7 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_war
mock_datetime, caplog):
"""Test that the ``load`` method loads a stored synthesizer."""
# Setup
synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None, table_name=None)
synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None)
mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5'
mock_generate_synthesizer_id.return_value = synthesizer_id
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/single_table/test_copulagan.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ def test_get_params(self):
'pac': 10,
'cuda': True,
'numerical_distributions': {},
'default_distribution': 'beta',
'table_name': None
'default_distribution': 'beta'
}

@patch('sdv.single_table.copulagan.rdt')
Expand Down
Loading

0 comments on commit 8aa3b5e

Please sign in to comment.