From 8aa3b5ef86f807fbd519d9709c22740b041d8852 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Mon, 29 Apr 2024 12:50:04 +0200 Subject: [PATCH] Include `sdv_logger_config.yml` with the package (#1963) --- pyproject.toml | 9 +++++---- sdv/multi_table/base.py | 3 --- sdv/multi_table/hma.py | 12 +++--------- sdv/sequential/par.py | 2 +- sdv/single_table/base.py | 23 ++++++++++------------- sdv/single_table/copulagan.py | 4 +--- sdv/single_table/copulas.py | 4 +--- sdv/single_table/ctgan.py | 7 ++----- tests/integration/multi_table/test_hma.py | 9 +++------ tests/unit/multi_table/test_base.py | 15 ++++----------- tests/unit/multi_table/test_hma.py | 2 +- tests/unit/single_table/test_base.py | 19 +++++-------------- tests/unit/single_table/test_copulagan.py | 3 +-- tests/unit/single_table/test_copulas.py | 3 +-- tests/unit/single_table/test_ctgan.py | 6 ++---- 15 files changed, 40 insertions(+), 81 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8e75e0b4b..9326308d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,11 +25,11 @@ dependencies = [ 'botocore>=1.31', 'cloudpickle>=2.1.0', 'graphviz>=0.13.2', - "numpy>=1.20.0;python_version<'3.10'", + "numpy>=1.21.0;python_version<'3.10'", "numpy>=1.23.3,<2;python_version>='3.10' and python_version<'3.12'", "numpy>=1.26.0,<2;python_version>='3.12'", - "pandas>=1.1.3;python_version<'3.10'", - "pandas>=1.3.4;python_version>='3.10' and python_version<'3.11'", + "pandas>=1.4.0;python_version<'3.10'", + "pandas>=1.4.0;python_version>='3.10' and python_version<'3.11'", "pandas>=1.5.0;python_version>='3.11' and python_version<'3.12'", "pandas>=2.1.1;python_version>='3.12'", 'tqdm>=4.29', @@ -141,7 +141,8 @@ namespaces = false 'make.bat', '*.jpg', '*.png', - '*.gif' + '*.gif', + 'sdv_logger_config.yml' ] [tool.setuptools.exclude-package-data] diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 4afbee0e9..00efe700e 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -65,7 +65,6 @@ def _initialize_models(self): self._table_synthesizers[table_name] = self._synthesizer( metadata=table_metadata, locales=self.locales, - table_name=table_name, **synthesizer_parameters ) @@ -200,8 +199,6 @@ def set_table_parameters(self, table_name, table_parameters): A dictionary with the parameters as keys and the values to be used to instantiate the table's synthesizer. """ - # Ensure that we set the name of the table no matter what - table_parameters.update({'table_name': table_name}) self._table_synthesizers[table_name] = self._synthesizer( metadata=self.metadata.tables[table_name], **table_parameters diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 1e9ce6c2a..9f4d5da30 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -312,11 +312,9 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc row = pd.Series({'num_rows': len(child_rows)}) row.index = f'__{child_name}__{foreign_key}__' + row.index else: - synthesizer_parameters = self._table_parameters[child_name] - synthesizer_parameters.update({'table_name': child_name}) synthesizer = self._synthesizer( table_meta, - **synthesizer_parameters + **self._table_parameters[child_name] ) synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) row = synthesizer._get_parameters() @@ -523,11 +521,9 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row): default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {}) table_meta = self.metadata.tables[child_name] - synthesizer_parameters = self._table_parameters[child_name] - synthesizer_parameters.update({'table_name': child_name}) synthesizer = self._synthesizer( table_meta, - **synthesizer_parameters + **self._table_parameters[child_name] ) synthesizer._set_parameters(parameters, default_parameters) synthesizer._data_processor = self._table_synthesizers[child_name]._data_processor @@ -622,11 +618,9 @@ def _get_likelihoods(self, table_rows, parent_rows, table_name, foreign_key): for parent_id, row in parent_rows.iterrows(): parameters = self._extract_parameters(row, table_name, foreign_key) table_meta = self._table_synthesizers[table_name].get_metadata() - synthesizer_parameters = self._table_parameters[table_name] - synthesizer_parameters.update({'table_name': table_name}) synthesizer = self._synthesizer( table_meta, - **synthesizer_parameters + **self._table_parameters[table_name] ) synthesizer._set_parameters(parameters) try: diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index e06c96864..4c7a80a36 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -92,7 +92,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, - locales=locales + locales=locales, ) sequence_key = self.metadata.sequence_key diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 41d271fd1..60656a4a2 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -88,7 +88,7 @@ def _check_metadata_updated(self): ) def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US'], table_name=None): + locales=['en_US']): self._validate_inputs(enforce_min_max_values, enforce_rounding) self.metadata = metadata self.metadata.validate() @@ -96,13 +96,11 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self.enforce_min_max_values = enforce_min_max_values self.enforce_rounding = enforce_rounding self.locales = locales - self.table_name = table_name self._data_processor = DataProcessor( metadata=self.metadata, enforce_rounding=self.enforce_rounding, enforce_min_max_values=self.enforce_min_max_values, locales=self.locales, - table_name=self.table_name ) self._fitted = False self._random_state_set = False @@ -500,16 +498,15 @@ def load(cls, filepath): if getattr(synthesizer, '_synthesizer_id', None) is None: synthesizer._synthesizer_id = generate_synthesizer_id(synthesizer) - if synthesizer.table_name is None: - SYNTHESIZER_LOGGER.info( - '\nLoad:\n' - ' Timestamp: %s\n' - ' Synthesizer class name: %s\n' - ' Synthesizer id: %s', - datetime.datetime.now(), - synthesizer.__class__.__name__, - synthesizer._synthesizer_id, - ) + SYNTHESIZER_LOGGER.info( + '\nLoad:\n' + ' Timestamp: %s\n' + ' Synthesizer class name: %s\n' + ' Synthesizer id: %s', + datetime.datetime.now(), + synthesizer.__class__.__name__, + synthesizer._synthesizer_id, + ) return synthesizer diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index 63b22d22b..c9309b45c 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -121,8 +121,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True, numerical_distributions=None, default_distribution=None, - table_name=None): + pac=10, cuda=True, numerical_distributions=None, default_distribution=None): super().__init__( metadata, @@ -143,7 +142,6 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, epochs=epochs, pac=pac, cuda=cuda, - table_name=table_name ) validate_numerical_distributions(numerical_distributions, self.metadata.columns) diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index c19b7536d..4fc213949 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -91,14 +91,12 @@ def get_distribution_class(cls, distribution): return cls._DISTRIBUTIONS[distribution] def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, - locales=['en_US'], numerical_distributions=None, default_distribution=None, - table_name=None): + locales=['en_US'], numerical_distributions=None, default_distribution=None): super().__init__( metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, locales=locales, - table_name=table_name ) validate_numerical_distributions(numerical_distributions, self.metadata.columns) self.numerical_distributions = numerical_distributions or {} diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py index d59c3fca0..860c66487 100644 --- a/sdv/single_table/ctgan.py +++ b/sdv/single_table/ctgan.py @@ -155,14 +155,13 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=1e-6, batch_size=500, discriminator_steps=1, log_frequency=True, verbose=False, epochs=300, - pac=10, cuda=True, table_name=None): + pac=10, cuda=True): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, locales=locales, - table_name=table_name ) self.embedding_dim = embedding_dim @@ -339,14 +338,12 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), - l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True, - table_name=None): + l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True): super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, - table_name=table_name ) self.embedding_dim = embedding_dim self.compress_dims = compress_dims diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 8c2a97ac6..f7c499d28 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -151,8 +151,7 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {}, - 'table_name': 'characters' + 'numerical_distributions': {} } families_params = hmasynthesizer.get_table_parameters('families') assert families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -161,8 +160,7 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {}, - 'table_name': 'families' + 'numerical_distributions': {} } char_families_params = hmasynthesizer.get_table_parameters('character_families') assert char_families_params['synthesizer_name'] == 'GaussianCopulaSynthesizer' @@ -171,8 +169,7 @@ def test_hma_set_table_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'numerical_distributions': {}, - 'table_name': 'character_families' + 'numerical_distributions': {} } assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma' diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index b4de98e6a..0cad058b0 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -52,10 +52,9 @@ def test__initialize_models(self): } instance._synthesizer.assert_has_calls([ call(metadata=instance.metadata.tables['nesreca'], default_distribution='gamma', - locales=locales, table_name='nesreca'), - call(metadata=instance.metadata.tables['oseba'], locales=locales, table_name='oseba'), - call(metadata=instance.metadata.tables['upravna_enota'], locales=locales, - table_name='upravna_enota') + locales=locales), + call(metadata=instance.metadata.tables['oseba'], locales=locales), + call(metadata=instance.metadata.tables['upravna_enota'], locales=locales) ]) def test__get_pbar_args(self): @@ -280,7 +279,6 @@ def test_get_table_parameters_empty(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'table_name': 'oseba', 'numerical_distributions': {} } } @@ -301,7 +299,6 @@ def test_get_table_parameters_has_parameters(self): 'enforce_min_max_values': True, 'enforce_rounding': True, 'locales': ['en_US'], - 'table_name': 'oseba', 'numerical_distributions': {} } @@ -333,17 +330,13 @@ def test_set_table_parameters(self): # Assert table_parameters = instance.get_table_parameters('oseba') - assert instance._table_parameters['oseba'] == { - 'default_distribution': 'gamma', - 'table_name': 'oseba' - } + assert instance._table_parameters['oseba'] == {'default_distribution': 'gamma'} assert table_parameters['synthesizer_name'] == 'GaussianCopulaSynthesizer' assert table_parameters['synthesizer_parameters'] == { 'default_distribution': 'gamma', 'enforce_min_max_values': True, 'locales': ['en_US'], 'enforce_rounding': True, - 'table_name': 'oseba', 'numerical_distributions': {} } diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 6bbf7ef6f..c40e7b080 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -502,7 +502,7 @@ def test__recreate_child_synthesizer(self): # Assert assert synthesizer == instance._synthesizer.return_value assert synthesizer._data_processor == table_synthesizer._data_processor - instance._synthesizer.assert_called_once_with(table_meta, table_name='users', a=1) + instance._synthesizer.assert_called_once_with(table_meta, a=1) synthesizer._set_parameters.assert_called_once_with( instance._extract_parameters.return_value, {'colA': 'default_param', 'colB': 'default_param'} diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 289cf7519..197141e69 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -89,8 +89,7 @@ def test___init__(self, mock_check_metadata_updated, mock_data_processor, metadata=metadata, enforce_rounding=instance.enforce_rounding, enforce_min_max_values=instance.enforce_min_max_values, - locales=instance.locales, - table_name=None + locales=instance.locales ) metadata.validate.assert_called_once_with() mock_check_metadata_updated.assert_called_once() @@ -124,8 +123,7 @@ def test___init__custom(self, mock_data_processor): metadata=metadata, enforce_rounding=instance.enforce_rounding, enforce_min_max_values=instance.enforce_min_max_values, - locales=instance.locales, - table_name=None + locales=instance.locales ) metadata.validate.assert_called_once_with() @@ -184,8 +182,7 @@ def test_get_parameters(self, mock_data_processor): assert parameters == { 'enforce_min_max_values': False, 'enforce_rounding': False, - 'locales': 'en_CA', - 'table_name': None + 'locales': 'en_CA' } @patch('sdv.single_table.base.DataProcessor') @@ -362,8 +359,7 @@ def test_fit_processed_data(self, mock_datetime, caplog): instance = Mock( _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, - _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', - table_name=None + _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' ) processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) @@ -394,7 +390,6 @@ def test_fit_processed_data_raises_version_error(self): instance = Mock( _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, - table_name=None ) processed_data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True @@ -422,7 +417,6 @@ def test_fit(self, mock_datetime, caplog): _fitted_sdv_version=None, _fitted_sdv_enterprise_version=None, _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', - table_name=None ) data = pd.DataFrame({'column_a': [1, 2, 3], 'name': ['John', 'Doe', 'Johanna']}) instance._random_state_set = True @@ -459,7 +453,6 @@ def test_fit_raises_version_error(self): instance = Mock( _fitted_sdv_version='1.0.0', _fitted_sdv_enterprise_version=None, - table_name=None ) data = pd.DataFrame({'column_a': [1, 2, 3]}) instance._random_state_set = True @@ -1417,7 +1410,6 @@ def test_sample(self, mock_datetime, caplog): output_file_path = 'temp.csv' instance = Mock( _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', - table_name=None ) instance.get_metadata.return_value._constraints = False instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]}) @@ -1810,7 +1802,6 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): # Setup synthesizer = Mock( _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', - table_name=None ) mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' @@ -1839,7 +1830,7 @@ def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_war mock_datetime, caplog): """Test that the ``load`` method loads a stored synthesizer.""" # Setup - synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None, table_name=None) + synthesizer_mock = Mock(_fitted=False, _synthesizer_id=None) mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' synthesizer_id = 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' mock_generate_synthesizer_id.return_value = synthesizer_id diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index 762e28f58..6909c86d2 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -174,8 +174,7 @@ def test_get_params(self): 'pac': 10, 'cuda': True, 'numerical_distributions': {}, - 'default_distribution': 'beta', - 'table_name': None + 'default_distribution': 'beta' } @patch('sdv.single_table.copulagan.rdt') diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index 02ec24b14..3c96028c3 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -130,8 +130,7 @@ def test_get_parameters(self): 'enforce_rounding': True, 'locales': ['en_US'], 'numerical_distributions': {}, - 'default_distribution': 'beta', - 'table_name': None + 'default_distribution': 'beta' } @patch('sdv.single_table.copulas.LOGGER') diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py index e18e27552..ddbdfc91c 100644 --- a/tests/unit/single_table/test_ctgan.py +++ b/tests/unit/single_table/test_ctgan.py @@ -150,8 +150,7 @@ def test_get_parameters(self): 'verbose': False, 'epochs': 300, 'pac': 10, - 'cuda': True, - 'table_name': None + 'cuda': True } def test__estimate_num_columns(self): @@ -426,8 +425,7 @@ def test_get_parameters(self): 'batch_size': 500, 'epochs': 300, 'loss_factor': 2, - 'cuda': True, - 'table_name': None + 'cuda': True } @patch('sdv.single_table.ctgan.TVAE')