From 9fa8fac4413cbd3717cc6e053db3485cd1b05b28 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 1 May 2024 11:10:41 -0500 Subject: [PATCH 01/26] WIP trying to convert everything to string before processing --- sdv/data_processing/data_processor.py | 1 + sdv/metadata/multi_table.py | 5 +++-- sdv/metadata/single_table.py | 3 ++- sdv/multi_table/base.py | 11 ++++++++--- sdv/single_table/base.py | 2 +- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 5c1bff886..b89b4cf7a 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -862,6 +862,7 @@ def reverse_transform(self, data, reset_keys=False): for column in self.metadata.columns.keys() - set(sampled_columns + self._keys) if self._hyper_transformer.field_transformers.get(column) ] + print(missing_columns) if missing_columns and num_rows: anonymized_data = self._hyper_transformer.create_anonymized_columns( num_rows=num_rows, diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index b7ccb7046..a9a002ac6 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -1025,10 +1025,11 @@ def _set_metadata_dict(self, metadata): Python dictionary representing a ``MultiTableMetadata`` object. """ for table_name, table_dict in metadata.get('tables', {}).items(): - self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict) + self.tables[str(table_name)] = SingleTableMetadata.load_from_dict(table_dict) for relationship in metadata.get('relationships', []): - self.relationships.append(relationship) + type_safe_relationships = {key: str(value) if not isinstance(value, str) else value for key, value in relationship.items()} + self.relationships.append(type_safe_relationships) @classmethod def load_from_dict(cls, metadata_dict): diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 4f8b1db94..2109b71d0 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -1238,7 +1238,8 @@ def load_from_dict(cls, metadata_dict): for key in instance._KEYS: value = deepcopy(metadata_dict.get(key)) if value: - setattr(instance, f'{key}', value) + type_safe_value = {str(key) if not isinstance(key, str) else key: value for key, value in value.items()} + setattr(instance, f'{key}', type_safe_value) return instance diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 114f40739..6da9212aa 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -417,9 +417,14 @@ def fit(self, data): Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format (before any transformations). """ + type_safe_data = {str(key) if not isinstance(key, str) else key: value for key, value in data.items()} total_rows = 0 total_columns = 0 - for table in data.values(): + for table, dataframe in type_safe_data.items(): + dataframe.columns = dataframe.columns.astype(str) + type_safe_data[table] = dataframe + + for table in type_safe_data.values(): total_rows += len(table) total_columns += len(table.columns) @@ -440,10 +445,10 @@ def fit(self, data): self._synthesizer_id, ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - _validate_foreign_keys_not_null(self.metadata, data) + _validate_foreign_keys_not_null(self.metadata, type_safe_data) self._check_metadata_updated() self._fitted = False - processed_data = self.preprocess(data) + processed_data = self.preprocess(type_safe_data) self._print(text='\n', end='') self.fit_processed_data(processed_data) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 60656a4a2..b0f53a81c 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -448,7 +448,7 @@ def fit(self, data): len(data.columns), self._synthesizer_id, ) - + data.columns = data.columns.astype(str) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) self._check_metadata_updated() self._fitted = False From 00ebb3a7c365c49faacf9a35099d169a0ad13337 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 1 May 2024 15:06:58 -0500 Subject: [PATCH 02/26] Fix test --- sdv/data_processing/data_processor.py | 10 ++++++++-- sdv/metadata/single_table.py | 5 +++-- sdv/single_table/base.py | 2 ++ tests/unit/version/test_version.py | 2 +- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index b89b4cf7a..b36a87549 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -841,8 +841,11 @@ def reverse_transform(self, data, reset_keys=False): for column in self._hyper_transformer._output_columns if column in data.columns ] - + print(f'All columns: {self._hyper_transformer._output_columns}') + print(f'columns in data: {data.columns}') + print(f'data into reverse_transformer: {data}') reversed_data = data + print(f'reversible_columns before : {reversible_columns}') try: if not data.empty: reversed_data = self._hyper_transformer.reverse_transform_subset( @@ -851,6 +854,8 @@ def reverse_transform(self, data, reset_keys=False): except rdt.errors.NotFittedError: LOGGER.info(f'HyperTransformer has not been fitted for table {self.table_name}') + print(f'reversed_data.columns : {reversed_data.columns}') + for transformer in self.grouped_columns_to_transformers.values(): if not transformer.output_columns: reversed_data = transformer.reverse_transform(reversed_data) @@ -862,7 +867,8 @@ def reverse_transform(self, data, reset_keys=False): for column in self.metadata.columns.keys() - set(sampled_columns + self._keys) if self._hyper_transformer.field_transformers.get(column) ] - print(missing_columns) + for col in missing_columns: + print(f'Col: {col} : {self._hyper_transformer.field_transformers.get(col)}') if missing_columns and num_rows: anonymized_data = self._hyper_transformer.create_anonymized_columns( num_rows=num_rows, diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 2109b71d0..fc3e4c78e 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -1238,8 +1238,9 @@ def load_from_dict(cls, metadata_dict): for key in instance._KEYS: value = deepcopy(metadata_dict.get(key)) if value: - type_safe_value = {str(key) if not isinstance(key, str) else key: value for key, value in value.items()} - setattr(instance, f'{key}', type_safe_value) + if key == 'columns': + value = {str(key) if not isinstance(key, str) else key: col for key, col in value.items()} + setattr(instance, f'{key}', value) return instance diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index b0f53a81c..aa4e186e1 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -606,10 +606,12 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, self._set_random_state(FIXED_RNG_SEED) need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns + print(f'need_sample: {need_sample}') if self._model and need_sample: if conditions is None: raw_sampled = self._sample(num_rows) + print(f'sampled rows: {raw_sampled.columns}') else: try: raw_sampled = self._sample(num_rows, transformed_conditions) diff --git a/tests/unit/version/test_version.py b/tests/unit/version/test_version.py index 201bbbe52..17b4dff3e 100644 --- a/tests/unit/version/test_version.py +++ b/tests/unit/version/test_version.py @@ -4,5 +4,5 @@ def test_sdv_versions(): """Test version for SDV.""" assert sdv.version.__all__ == ('public', 'enterprise') - assert sdv.version.public == sdv.__version__ + # assert sdv.version.public == sdv.__version__ assert sdv.version.enterprise is None From 68f876141be3ba29ad1014ca48b109986608285e Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 1 May 2024 16:55:46 -0500 Subject: [PATCH 03/26] Remove all prints --- sdv/data_processing/data_processor.py | 8 -------- sdv/single_table/base.py | 2 -- sdv/single_table/utils.py | 2 +- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index b36a87549..630f7deab 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -841,11 +841,7 @@ def reverse_transform(self, data, reset_keys=False): for column in self._hyper_transformer._output_columns if column in data.columns ] - print(f'All columns: {self._hyper_transformer._output_columns}') - print(f'columns in data: {data.columns}') - print(f'data into reverse_transformer: {data}') reversed_data = data - print(f'reversible_columns before : {reversible_columns}') try: if not data.empty: reversed_data = self._hyper_transformer.reverse_transform_subset( @@ -854,8 +850,6 @@ def reverse_transform(self, data, reset_keys=False): except rdt.errors.NotFittedError: LOGGER.info(f'HyperTransformer has not been fitted for table {self.table_name}') - print(f'reversed_data.columns : {reversed_data.columns}') - for transformer in self.grouped_columns_to_transformers.values(): if not transformer.output_columns: reversed_data = transformer.reverse_transform(reversed_data) @@ -867,8 +861,6 @@ def reverse_transform(self, data, reset_keys=False): for column in self.metadata.columns.keys() - set(sampled_columns + self._keys) if self._hyper_transformer.field_transformers.get(column) ] - for col in missing_columns: - print(f'Col: {col} : {self._hyper_transformer.field_transformers.get(col)}') if missing_columns and num_rows: anonymized_data = self._hyper_transformer.create_anonymized_columns( num_rows=num_rows, diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index aa4e186e1..b0f53a81c 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -606,12 +606,10 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, self._set_random_state(FIXED_RNG_SEED) need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns - print(f'need_sample: {need_sample}') if self._model and need_sample: if conditions is None: raw_sampled = self._sample(num_rows) - print(f'sampled rows: {raw_sampled.columns}') else: try: raw_sampled = self._sample(num_rows, transformed_conditions) diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 11043ca2f..d9bdb72d7 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -303,7 +303,7 @@ def unflatten_dict(flat): else: subdict = unflattened.setdefault(key, {}) - if subkey.isdigit(): + if subkey.isdigit() and key != 'univariates': subkey = int(subkey) inner = subdict.setdefault(subkey, {}) From 4d89c583542bbc4bbd6374239cf6cd7d586ea8ed Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 1 May 2024 17:00:39 -0500 Subject: [PATCH 04/26] Fix lint --- sdv/metadata/multi_table.py | 6 +++++- sdv/metadata/single_table.py | 6 +++++- sdv/multi_table/base.py | 4 +++- tests/unit/version/test_version.py | 2 +- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index a9a002ac6..587c168a9 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -1028,7 +1028,11 @@ def _set_metadata_dict(self, metadata): self.tables[str(table_name)] = SingleTableMetadata.load_from_dict(table_dict) for relationship in metadata.get('relationships', []): - type_safe_relationships = {key: str(value) if not isinstance(value, str) else value for key, value in relationship.items()} + type_safe_relationships = { + key: str(value) + if not isinstance(value, str) + else value for key, value in relationship.items() + } self.relationships.append(type_safe_relationships) @classmethod diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index fc3e4c78e..806bc3561 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -1239,7 +1239,11 @@ def load_from_dict(cls, metadata_dict): value = deepcopy(metadata_dict.get(key)) if value: if key == 'columns': - value = {str(key) if not isinstance(key, str) else key: col for key, col in value.items()} + value = { + str(key) + if not isinstance(key, str) + else key: col for key, col in value.items() + } setattr(instance, f'{key}', value) return instance diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 6da9212aa..0a5eb584b 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -417,7 +417,9 @@ def fit(self, data): Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format (before any transformations). """ - type_safe_data = {str(key) if not isinstance(key, str) else key: value for key, value in data.items()} + type_safe_data = { + str(key) if not isinstance(key, str) else key: value for key, value in data.items() + } total_rows = 0 total_columns = 0 for table, dataframe in type_safe_data.items(): diff --git a/tests/unit/version/test_version.py b/tests/unit/version/test_version.py index 17b4dff3e..201bbbe52 100644 --- a/tests/unit/version/test_version.py +++ b/tests/unit/version/test_version.py @@ -4,5 +4,5 @@ def test_sdv_versions(): """Test version for SDV.""" assert sdv.version.__all__ == ('public', 'enterprise') - # assert sdv.version.public == sdv.__version__ + assert sdv.version.public == sdv.__version__ assert sdv.version.enterprise is None From d881306eb5709b3a91f42beba9f9921c426a7c59 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 1 May 2024 17:01:15 -0500 Subject: [PATCH 05/26] Fix lint --- sdv/data_processing/data_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 630f7deab..5c1bff886 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -841,6 +841,7 @@ def reverse_transform(self, data, reset_keys=False): for column in self._hyper_transformer._output_columns if column in data.columns ] + reversed_data = data try: if not data.empty: From 79b685358c5357c97c4d67800f13432f813b2519 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 1 May 2024 17:44:46 -0500 Subject: [PATCH 06/26] Add tests --- tests/integration/multi_table/test_hma.py | 45 +++++++++++++++++++++ tests/integration/single_table/test_base.py | 34 ++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index f7c499d28..564d59042 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1738,3 +1738,48 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id): ' Total number of columns: 15\n' ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' ) + + +def test_fit_and_sample_numerical_col_names(): + """Test fit and sampling when column names are integers""" + # Setup data + num_rows = 50 + num_cols = 10 + num_tables = 2 + data = {} + for i in range(num_tables): + values = {j: np.random.randint(0, 100, size=num_rows) for j in range(num_cols)} + data[i] = pd.DataFrame(values) + + primary_key = pd.DataFrame({1: range(num_rows)}) + primary_key_2 = pd.DataFrame({2: range(num_rows)}) + data[0][1] = primary_key + data[1][1] = primary_key + data[1][2] = primary_key_2 + metadata = MultiTableMetadata() + metadata_dict = {'tables': {}} + for table_idx in range(num_tables): + metadata_dict['tables'][table_idx] = {'columns': {}} + for i in range(num_cols): + metadata_dict['tables'][table_idx]['columns'][i] = {'sdtype': 'numerical'} + metadata_dict['tables'][0]['columns'][1] = {'sdtype': 'id'} + metadata_dict['tables'][1]['columns'][2] = {'sdtype': 'id'} + metadata_dict['relationships'] = [ + { + 'parent_table_name': 0, + 'parent_primary_key': 1, + 'child_table_name': 1, + 'child_foreign_key': 2 + } + ] + metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata.set_primary_key('0', '1') + + # Run + synth = HMASynthesizer(metadata) + synth.fit(data) + first_sample = synth.sample() + second_sample = synth.sample() + + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(first_sample['0'], second_sample['0']) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 8c7ea2601..40ef26ba7 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -855,3 +855,37 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id): ' Total number of columns: 3\n' ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' ) + + +def test_fit_and_sample_numerical_col_names(): + """Test fit and sampling when column names are integers""" + # Setup + num_rows = 50 + num_cols = 10 + values = { + i: np.random.randint(0, 100, size=num_rows) for i in range(num_cols) + } + data = pd.DataFrame(values) + metadata = SingleTableMetadata() + metadata_dict = {'columns': {}} + for i in range(num_cols): + metadata_dict['columns'][i] = {'sdtype': 'numerical'} + metadata = SingleTableMetadata.load_from_dict(metadata_dict) + + # Run + + synthesizers = [ + CTGANSynthesizer, + TVAESynthesizer, + GaussianCopulaSynthesizer, + CopulaGANSynthesizer + ] + for synthesizer_class in synthesizers: + synth = synthesizer_class(metadata) + synth.fit(data) + sample_1 = synth.sample(10) + sample_2 = synth.sample(10) + + # Assert + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(sample_1, sample_2) From 67df9eb911dc018a71fe43ff28bc44a56f8b9e5e Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 11:26:18 -0500 Subject: [PATCH 07/26] Update Test --- sdv/data_processing/data_processor.py | 1 + sdv/metadata/single_table.py | 6 ++-- sdv/single_table/base.py | 11 +++++-- tests/integration/single_table/test_base.py | 35 +++++++++++---------- tests/unit/single_table/test_base.py | 1 + 5 files changed, 34 insertions(+), 20 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 5c1bff886..da0494ac7 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -755,6 +755,7 @@ def fit(self, data): if data.empty: raise ValueError('The fit dataframe is empty, synthesizer will not be fitted.') self._prepared_for_fitting = False + print(f'Data: {data.columns}') self.prepare_for_fitting(data) constrained = self._transform_constraints(data) if constrained.empty: diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 806bc3561..4414158b4 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -964,9 +964,10 @@ def validate(self): + '\n'.join([str(e) for e in errors]) ) - def _validate_metadata_matches_data(self, columns): + def _validate_metadata_matches_data(self, data_columns): errors = [] metadata_columns = self.columns or {} + columns = data_columns.astype(str) missing_data_columns = set(columns).difference(metadata_columns) if missing_data_columns: errors.append( @@ -1055,7 +1056,8 @@ def _validate_column_data(self, column, sdtype_warnings): list: A list containing any validation error messages found during the process. """ - column_metadata = self.columns[column.name] + column_metadata = self.columns[str(column.name)] + print(column_metadata) sdtype = column_metadata['sdtype'] invalid_values = None diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index b0f53a81c..9a462028b 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -106,6 +106,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self._random_state_set = False self._update_default_transformers() self._creation_date = datetime.datetime.today().strftime('%Y-%m-%d') + self._original_columns = None self._fitted_date = None self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None @@ -176,6 +177,8 @@ def validate(self, data): * context columns vary for a sequence key * values of a column don't satisfy their sdtype """ + # self._original_columns = data.columns + # data.columns = data.columns.astype(str) self._validate_metadata(data) self._validate_constraints(data) @@ -184,6 +187,7 @@ def validate(self, data): synthesizer_errors = self._validate(data) # Validate rules specific to each synthesizer if synthesizer_errors: raise InvalidDataError(synthesizer_errors) + # data.columns = self._original_columns def _validate_transformers(self, column_name_to_transformer): primary_and_alternate_keys = self.metadata._get_primary_and_alternate_keys() @@ -416,7 +420,8 @@ def fit_processed_data(self, processed_data): len(processed_data.columns), self._synthesizer_id, ) - + self._original_columns = processed_data.columns + processed_data.columns = processed_data.columns.astype(str) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) if not processed_data.empty: self._fit(processed_data) @@ -448,7 +453,7 @@ def fit(self, data): len(data.columns), self._synthesizer_id, ) - data.columns = data.columns.astype(str) + check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) self._check_metadata_updated() self._fitted = False @@ -884,6 +889,8 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file show_progress_bar=show_progress_bar ) + sampled_data.columns = self._original_columns + SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 40ef26ba7..c4467e830 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -857,7 +857,15 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id): ) -def test_fit_and_sample_numerical_col_names(): +SYNTHESIZERS_CLASSES = [ + pytest.param(CTGANSynthesizer, id='CTGANSynthesizer'), + pytest.param(TVAESynthesizer, id='TVAESynthesizer'), + pytest.param(GaussianCopulaSynthesizer, id='GaussianCopulaSynthesizer'), + pytest.param(CopulaGANSynthesizer, id='CopulaGANSynthesizer'), +] + +@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES) +def test_fit_and_sample_numerical_col_names(synthesizer_class): """Test fit and sampling when column names are integers""" # Setup num_rows = 50 @@ -873,19 +881,14 @@ def test_fit_and_sample_numerical_col_names(): metadata = SingleTableMetadata.load_from_dict(metadata_dict) # Run + synth = synthesizer_class(metadata) + synth.fit(data) + sample_1 = synth.sample(10) + sample_2 = synth.sample(10) - synthesizers = [ - CTGANSynthesizer, - TVAESynthesizer, - GaussianCopulaSynthesizer, - CopulaGANSynthesizer - ] - for synthesizer_class in synthesizers: - synth = synthesizer_class(metadata) - synth.fit(data) - sample_1 = synth.sample(10) - sample_2 = synth.sample(10) - - # Assert - with pytest.raises(AssertionError): - pd.testing.assert_frame_equal(sample_1, sample_2) + assert sample_1.columns.tolist() == data.columns.tolist() + assert sample_2.columns.tolist() == data.columns.tolist() + + # Assert + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(sample_1, sample_2) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 197141e69..bf667717f 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1410,6 +1410,7 @@ def test_sample(self, mock_datetime, caplog): output_file_path = 'temp.csv' instance = Mock( _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', + _original_columns=pd.Index(['col']) ) instance.get_metadata.return_value._constraints = False instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]}) From 6f8c497175279989a3fdc9ab964aa0ff8cced794 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 12:05:28 -0500 Subject: [PATCH 08/26] Merge --- sdv/metadata/single_table.py | 10 +++++++--- sdv/multi_table/hma.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 4414158b4..8d3c72a98 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -618,7 +618,7 @@ def detect_from_csv(self, filepath, read_csv_parameters=None): @staticmethod def _validate_key_datatype(column_name): """Check whether column_name is a string.""" - return isinstance(column_name, str) + return isinstance(column_name, str) or isinstance(column_name, int) def _validate_keys_sdtype(self, keys, key_type): """Validate that each key is of type 'id' or a valid Faker function.""" @@ -638,9 +638,13 @@ def _validate_key(self, column_name, key_type): if column_name is not None: if not self._validate_key_datatype(column_name): raise InvalidMetadataError( - f"'{key_type}_key' must be a string.") + f"'{key_type}_key' must be a string or integer.") - keys = {column_name} if isinstance(column_name, str) else set(column_name) + keys = ( + {column_name} + if isinstance(column_name, str) or isinstance(column_name, int) + else set(column_name) + ) invalid_ids = keys - set(self.columns) if invalid_ids: raise InvalidMetadataError( diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index a769c3e13..83b8a7524 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -494,7 +494,7 @@ def _extract_parameters(self, parent_row, table_name, foreign_key): parent child relationship. """ prefix = f'__{table_name}__{foreign_key}__' - keys = [key for key in parent_row.keys() if key.startswith(prefix)] + keys = [key for key in parent_row.keys() if str(key).startswith(prefix)] new_keys = {key: key[len(prefix):] for key in keys} flat_parameters = parent_row[keys].astype(float).fillna(1e-6) From c64da83de5191da2ea812ef51f22d01d0f2f691f Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 12:08:47 -0500 Subject: [PATCH 09/26] Remove string changes --- sdv/data_processing/data_processor.py | 1 - sdv/metadata/multi_table.py | 9 ++------- sdv/metadata/single_table.py | 7 ------- sdv/multi_table/base.py | 13 +++---------- sdv/single_table/base.py | 7 ------- 5 files changed, 5 insertions(+), 32 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index da0494ac7..5c1bff886 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -755,7 +755,6 @@ def fit(self, data): if data.empty: raise ValueError('The fit dataframe is empty, synthesizer will not be fitted.') self._prepared_for_fitting = False - print(f'Data: {data.columns}') self.prepare_for_fitting(data) constrained = self._transform_constraints(data) if constrained.empty: diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index d9bb59560..c7aa10d30 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -1025,15 +1025,10 @@ def _set_metadata_dict(self, metadata): Python dictionary representing a ``MultiTableMetadata`` object. """ for table_name, table_dict in metadata.get('tables', {}).items(): - self.tables[str(table_name)] = SingleTableMetadata.load_from_dict(table_dict) + self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict) for relationship in metadata.get('relationships', []): - type_safe_relationships = { - key: str(value) - if not isinstance(value, str) - else value for key, value in relationship.items() - } - self.relationships.append(type_safe_relationships) + self.relationships.append(relationship) @classmethod def load_from_dict(cls, metadata_dict): diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index 8d3c72a98..d810a577d 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -1061,7 +1061,6 @@ def _validate_column_data(self, column, sdtype_warnings): A list containing any validation error messages found during the process. """ column_metadata = self.columns[str(column.name)] - print(column_metadata) sdtype = column_metadata['sdtype'] invalid_values = None @@ -1244,12 +1243,6 @@ def load_from_dict(cls, metadata_dict): for key in instance._KEYS: value = deepcopy(metadata_dict.get(key)) if value: - if key == 'columns': - value = { - str(key) - if not isinstance(key, str) - else key: col for key, col in value.items() - } setattr(instance, f'{key}', value) return instance diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 0a5eb584b..114f40739 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -417,16 +417,9 @@ def fit(self, data): Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format (before any transformations). """ - type_safe_data = { - str(key) if not isinstance(key, str) else key: value for key, value in data.items() - } total_rows = 0 total_columns = 0 - for table, dataframe in type_safe_data.items(): - dataframe.columns = dataframe.columns.astype(str) - type_safe_data[table] = dataframe - - for table in type_safe_data.values(): + for table in data.values(): total_rows += len(table) total_columns += len(table.columns) @@ -447,10 +440,10 @@ def fit(self, data): self._synthesizer_id, ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - _validate_foreign_keys_not_null(self.metadata, type_safe_data) + _validate_foreign_keys_not_null(self.metadata, data) self._check_metadata_updated() self._fitted = False - processed_data = self.preprocess(type_safe_data) + processed_data = self.preprocess(data) self._print(text='\n', end='') self.fit_processed_data(processed_data) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 9a462028b..97a4a3b2a 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -106,7 +106,6 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, self._random_state_set = False self._update_default_transformers() self._creation_date = datetime.datetime.today().strftime('%Y-%m-%d') - self._original_columns = None self._fitted_date = None self._fitted_sdv_version = None self._fitted_sdv_enterprise_version = None @@ -177,8 +176,6 @@ def validate(self, data): * context columns vary for a sequence key * values of a column don't satisfy their sdtype """ - # self._original_columns = data.columns - # data.columns = data.columns.astype(str) self._validate_metadata(data) self._validate_constraints(data) @@ -187,7 +184,6 @@ def validate(self, data): synthesizer_errors = self._validate(data) # Validate rules specific to each synthesizer if synthesizer_errors: raise InvalidDataError(synthesizer_errors) - # data.columns = self._original_columns def _validate_transformers(self, column_name_to_transformer): primary_and_alternate_keys = self.metadata._get_primary_and_alternate_keys() @@ -420,7 +416,6 @@ def fit_processed_data(self, processed_data): len(processed_data.columns), self._synthesizer_id, ) - self._original_columns = processed_data.columns processed_data.columns = processed_data.columns.astype(str) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) if not processed_data.empty: @@ -889,8 +884,6 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file show_progress_bar=show_progress_bar ) - sampled_data.columns = self._original_columns - SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' From 8d617bd5bbf264d483fc3150c61696b6bbad6600 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 12:09:45 -0500 Subject: [PATCH 10/26] remove old code --- sdv/single_table/base.py | 1 - tests/unit/single_table/test_base.py | 1 - 2 files changed, 2 deletions(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 97a4a3b2a..7469b3c18 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -416,7 +416,6 @@ def fit_processed_data(self, processed_data): len(processed_data.columns), self._synthesizer_id, ) - processed_data.columns = processed_data.columns.astype(str) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) if not processed_data.empty: self._fit(processed_data) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index bf667717f..197141e69 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1410,7 +1410,6 @@ def test_sample(self, mock_datetime, caplog): output_file_path = 'temp.csv' instance = Mock( _synthesizer_id='BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5', - _original_columns=pd.Index(['col']) ) instance.get_metadata.return_value._constraints = False instance._sample_with_progress_bar.return_value = pd.DataFrame({'col': [1, 2, 3]}) From 8982a70784367370eda74b05a76aea79263ad6db Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 12:25:14 -0500 Subject: [PATCH 11/26] Remove str conversions --- sdv/metadata/single_table.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index d810a577d..fa65b27c8 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -968,10 +968,9 @@ def validate(self): + '\n'.join([str(e) for e in errors]) ) - def _validate_metadata_matches_data(self, data_columns): + def _validate_metadata_matches_data(self, columns): errors = [] metadata_columns = self.columns or {} - columns = data_columns.astype(str) missing_data_columns = set(columns).difference(metadata_columns) if missing_data_columns: errors.append( @@ -1060,7 +1059,7 @@ def _validate_column_data(self, column, sdtype_warnings): list: A list containing any validation error messages found during the process. """ - column_metadata = self.columns[str(column.name)] + column_metadata = self.columns[column.name] sdtype = column_metadata['sdtype'] invalid_values = None From 099dad9ea2cba7bc10567372ad00eb9732eff358 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 13:35:23 -0500 Subject: [PATCH 12/26] Debug --- sdv/single_table/base.py | 12 +++++- .../data_processing/test_data_processor.py | 2 + tests/integration/multi_table/test_hma.py | 5 ++- tests/integration/single_table/test_base.py | 37 ++++++++++--------- 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index b0f53a81c..3da732eb4 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -102,6 +102,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, enforce_min_max_values=self.enforce_min_max_values, locales=self.locales, ) + self._original_columns = pd.Index([]) self._fitted = False self._random_state_set = False self._update_default_transformers() @@ -362,6 +363,11 @@ def get_info(self): return info def _preprocess(self, data): + # for column in data.columns: + # if isinstance(column, int): + # self._original_columns = data.columns + # data.columns = data.columns.astype(str) + # break self.validate(data) self._data_processor.fit(data) return self._data_processor.transform(data) @@ -448,7 +454,6 @@ def fit(self, data): len(data.columns), self._synthesizer_id, ) - data.columns = data.columns.astype(str) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) self._check_metadata_updated() self._fitted = False @@ -456,6 +461,8 @@ def fit(self, data): self._random_state_set = False processed_data = self._preprocess(data) self.fit_processed_data(processed_data) + if not self._original_columns.empty: + data.columns = self._original_columns def save(self, filepath): """Save this model instance to the given path using cloudpickle. @@ -884,6 +891,9 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file show_progress_bar=show_progress_bar ) + if not self._original_columns.empty: + sampled_data.columns = self._original_columns + SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index 00f933f39..6db5b71c6 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -269,7 +269,9 @@ def test_prepare_for_fitting(self): 'degree_perc': FloatFormatter } for column_name, transformer_class in expected_transformers.items(): + print(f'Transformer Class: {transformer_class}, {column_name}={field_transformers[column_name]}') if transformer_class is not None: + print("Before Check") assert isinstance(field_transformers[column_name], transformer_class) else: assert field_transformers[column_name] is None diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 564d59042..b51d82ea3 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1741,7 +1741,7 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id): def test_fit_and_sample_numerical_col_names(): - """Test fit and sampling when column names are integers""" + """Test fitting/sampling when column names are integers""" # Setup data num_rows = 50 num_cols = 10 @@ -1780,6 +1780,9 @@ def test_fit_and_sample_numerical_col_names(): synth.fit(data) first_sample = synth.sample() second_sample = synth.sample() + assert first_sample.columns.tolist() == data.columns.tolist() + assert second_sample.columns.tolist() == data.columns.tolist() + # Assert with pytest.raises(AssertionError): pd.testing.assert_frame_equal(first_sample['0'], second_sample['0']) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 40ef26ba7..6427c509b 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -857,8 +857,16 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id): ) -def test_fit_and_sample_numerical_col_names(): - """Test fit and sampling when column names are integers""" +SYNTHESIZERS_CLASSES = [ + pytest.param(CTGANSynthesizer, id='CTGANSynthesizer'), + pytest.param(TVAESynthesizer, id='TVAESynthesizer'), + pytest.param(GaussianCopulaSynthesizer, id='GaussianCopulaSynthesizer'), + pytest.param(CopulaGANSynthesizer, id='CopulaGANSynthesizer'), +] + +@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES) +def test_fit_and_sample_numerical_col_names(synthesizer_class): + """Test fitting/sampling when column names are integers""" # Setup num_rows = 50 num_cols = 10 @@ -873,19 +881,14 @@ def test_fit_and_sample_numerical_col_names(): metadata = SingleTableMetadata.load_from_dict(metadata_dict) # Run + synth = synthesizer_class(metadata) + synth.fit(data) + sample_1 = synth.sample(10) + sample_2 = synth.sample(10) - synthesizers = [ - CTGANSynthesizer, - TVAESynthesizer, - GaussianCopulaSynthesizer, - CopulaGANSynthesizer - ] - for synthesizer_class in synthesizers: - synth = synthesizer_class(metadata) - synth.fit(data) - sample_1 = synth.sample(10) - sample_2 = synth.sample(10) - - # Assert - with pytest.raises(AssertionError): - pd.testing.assert_frame_equal(sample_1, sample_2) + assert sample_1.columns.tolist() == data.columns.tolist() + assert sample_2.columns.tolist() == data.columns.tolist() + + # Assert + with pytest.raises(AssertionError): + pd.testing.assert_frame_equal(sample_1, sample_2) From 9ec0b15f10c05d6d8c04694405cefa65cdd38a38 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 13:58:30 -0500 Subject: [PATCH 13/26] Fixed processing --- sdv/single_table/base.py | 12 ++++++------ tests/unit/single_table/test_base.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 3da732eb4..faad09648 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -363,11 +363,6 @@ def get_info(self): return info def _preprocess(self, data): - # for column in data.columns: - # if isinstance(column, int): - # self._original_columns = data.columns - # data.columns = data.columns.astype(str) - # break self.validate(data) self._data_processor.fit(data) return self._data_processor.transform(data) @@ -389,6 +384,11 @@ def preprocess(self, data): "please refit the model using 'fit' or 'fit_processed_data'." ) + for column in data.columns: + if isinstance(column, int): + self._original_columns = data.columns + data.columns = data.columns.astype(str) + break return self._preprocess(data) def _fit(self, processed_data): @@ -459,7 +459,7 @@ def fit(self, data): self._fitted = False self._data_processor.reset_sampling() self._random_state_set = False - processed_data = self._preprocess(data) + processed_data = self.preprocess(data) self.fit_processed_data(processed_data) if not self._original_columns.empty: data.columns = self._original_columns diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 197141e69..abaa7c791 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -429,8 +429,8 @@ def test_fit(self, mock_datetime, caplog): # Assert assert instance._random_state_set is False instance._data_processor.reset_sampling.assert_called_once_with() - instance._preprocess.assert_called_once_with(data) - instance.fit_processed_data.assert_called_once_with(instance._preprocess.return_value) + instance.preprocess.assert_called_once_with(data) + instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) instance._check_metadata_updated.assert_called_once() assert caplog.messages[0] == ( '\nFit:\n' From f821a95501da2c2a6b884129f3474638890d16f9 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 14:53:07 -0500 Subject: [PATCH 14/26] Fix tests --- sdv/_utils.py | 2 ++ sdv/multi_table/base.py | 34 +++++++++++++++-------- tests/integration/multi_table/test_hma.py | 19 +++++++------ 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/sdv/_utils.py b/sdv/_utils.py index 577600b8d..5db3955ec 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -214,6 +214,8 @@ def _validate_foreign_keys_not_null(metadata, data): invalid_tables = defaultdict(list) for table_name, table_data in data.items(): for foreign_key in metadata._get_all_foreign_keys(table_name): + if foreign_key not in table_data and int(foreign_key) in table_data: + foreign_key = int(foreign_key) if table_data[foreign_key].isna().any(): invalid_tables[table_name].append(foreign_key) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 0a5eb584b..6d376319b 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -99,6 +99,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self.extended_columns = defaultdict(dict) self._table_synthesizers = {} self._table_parameters = defaultdict(dict) + self._original_table_columns = {} if synthesizer_kwargs is not None: warn_message = ( 'The `synthesizer_kwargs` parameter is deprecated as of SDV 1.2.0 and does not ' @@ -325,17 +326,26 @@ def update_transformers(self, table_name, column_name_to_transformer): self._validate_table_name(table_name) self._table_synthesizers[table_name].update_transformers(column_name_to_transformer) - def preprocess(self, data): + def preprocess(self, unprocessed_data): """Transform the raw data to numerical space. Args: - data (dict): + unprocessed_data (dict): Dictionary mapping each table name to a ``pandas.DataFrame``. Returns: dict: A dictionary with the preprocessed data. """ + data = { + str(key) if not isinstance(key, str) else key: value for key, value in unprocessed_data.items() + } + + for table, dataframe in data.items(): + self._original_table_columns[table] = dataframe.columns + dataframe.columns = dataframe.columns.astype(str) + data[table] = dataframe + self.validate(data) if self._fitted: warnings.warn( @@ -350,6 +360,9 @@ def preprocess(self, data): self._assign_table_transformers(synthesizer, table_name, table_data) processed_data[table_name] = synthesizer._preprocess(table_data) + for table, dataframe in data.items(): + dataframe.columns = self._original_table_columns[table] + return processed_data def _model_tables(self, augmented_data): @@ -417,16 +430,9 @@ def fit(self, data): Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format (before any transformations). """ - type_safe_data = { - str(key) if not isinstance(key, str) else key: value for key, value in data.items() - } total_rows = 0 total_columns = 0 - for table, dataframe in type_safe_data.items(): - dataframe.columns = dataframe.columns.astype(str) - type_safe_data[table] = dataframe - - for table in type_safe_data.values(): + for table in data.values(): total_rows += len(table) total_columns += len(table.columns) @@ -447,10 +453,10 @@ def fit(self, data): self._synthesizer_id, ) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - _validate_foreign_keys_not_null(self.metadata, type_safe_data) + _validate_foreign_keys_not_null(self.metadata, data) self._check_metadata_updated() self._fitted = False - processed_data = self.preprocess(type_safe_data) + processed_data = self.preprocess(data) self._print(text='\n', end='') self.fit_processed_data(processed_data) @@ -487,6 +493,10 @@ def sample(self, scale=1.0): total_rows += len(table) total_columns += len(table.columns) + for table in sampled_data: + if table in self._original_table_columns: + sampled_data[table].columns = self._original_table_columns[table] + SYNTHESIZER_LOGGER.info( '\nSample:\n' ' Timestamp: %s\n' diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index b51d82ea3..0173340cf 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1749,13 +1749,13 @@ def test_fit_and_sample_numerical_col_names(): data = {} for i in range(num_tables): values = {j: np.random.randint(0, 100, size=num_rows) for j in range(num_cols)} - data[i] = pd.DataFrame(values) + data[str(i)] = pd.DataFrame(values) primary_key = pd.DataFrame({1: range(num_rows)}) primary_key_2 = pd.DataFrame({2: range(num_rows)}) - data[0][1] = primary_key - data[1][1] = primary_key - data[1][2] = primary_key_2 + data['0'][1] = primary_key + data['1'][1] = primary_key + data['1'][2] = primary_key_2 metadata = MultiTableMetadata() metadata_dict = {'tables': {}} for table_idx in range(num_tables): @@ -1766,9 +1766,9 @@ def test_fit_and_sample_numerical_col_names(): metadata_dict['tables'][1]['columns'][2] = {'sdtype': 'id'} metadata_dict['relationships'] = [ { - 'parent_table_name': 0, + 'parent_table_name': '0', 'parent_primary_key': 1, - 'child_table_name': 1, + 'child_table_name': '1', 'child_foreign_key': 2 } ] @@ -1780,9 +1780,10 @@ def test_fit_and_sample_numerical_col_names(): synth.fit(data) first_sample = synth.sample() second_sample = synth.sample() - assert first_sample.columns.tolist() == data.columns.tolist() - assert second_sample.columns.tolist() == data.columns.tolist() - + assert first_sample['0'].columns.tolist() == data['0'].columns.tolist() + assert first_sample['1'].columns.tolist() == data['1'].columns.tolist() + assert second_sample['0'].columns.tolist() == data['0'].columns.tolist() + assert second_sample['1'].columns.tolist() == data['1'].columns.tolist() # Assert with pytest.raises(AssertionError): pd.testing.assert_frame_equal(first_sample['0'], second_sample['0']) From 9f521f8bbc607447a8f9efa14f102d0c6c144a52 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 14:54:17 -0500 Subject: [PATCH 15/26] Remove lint --- sdv/multi_table/base.py | 5 ++++- tests/integration/data_processing/test_data_processor.py | 2 -- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 6d376319b..d135fe4ae 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -338,7 +338,10 @@ def preprocess(self, unprocessed_data): A dictionary with the preprocessed data. """ data = { - str(key) if not isinstance(key, str) else key: value for key, value in unprocessed_data.items() + str(key) + if not isinstance(key, str) + else key: value + for key, value in unprocessed_data.items() } for table, dataframe in data.items(): diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index f0dd40cbc..358415625 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -269,9 +269,7 @@ def test_prepare_for_fitting(self): 'degree_perc': FloatFormatter } for column_name, transformer_class in expected_transformers.items(): - print(f'Transformer Class: {transformer_class}, {column_name}={field_transformers[column_name]}') if transformer_class is not None: - print("Before Check") assert isinstance(field_transformers[column_name], transformer_class) else: assert field_transformers[column_name] is None From 971f685f7b9a6652ab23922591fe6aaef488be17 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 14:55:16 -0500 Subject: [PATCH 16/26] Fix --- tests/integration/single_table/test_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 6427c509b..f478cae68 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -864,6 +864,7 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id): pytest.param(CopulaGANSynthesizer, id='CopulaGANSynthesizer'), ] + @pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES) def test_fit_and_sample_numerical_col_names(synthesizer_class): """Test fitting/sampling when column names are integers""" From 9abb32b7bcbd49f12b461a94b6747e20e45066e4 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 15:21:06 -0500 Subject: [PATCH 17/26] Fix merge --- sdv/metadata/multi_table.py | 7 ++++++- sdv/metadata/single_table.py | 16 +++++++++------- tests/integration/multi_table/test_hma.py | 8 ++++---- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index c7aa10d30..8aebf65c5 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -1028,7 +1028,12 @@ def _set_metadata_dict(self, metadata): self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict) for relationship in metadata.get('relationships', []): - self.relationships.append(relationship) + type_safe_relationships = { + key: str(value) + if not isinstance(value, str) + else value for key, value in relationship.items() + } + self.relationships.append(type_safe_relationships) @classmethod def load_from_dict(cls, metadata_dict): diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index fa65b27c8..806bc3561 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -618,7 +618,7 @@ def detect_from_csv(self, filepath, read_csv_parameters=None): @staticmethod def _validate_key_datatype(column_name): """Check whether column_name is a string.""" - return isinstance(column_name, str) or isinstance(column_name, int) + return isinstance(column_name, str) def _validate_keys_sdtype(self, keys, key_type): """Validate that each key is of type 'id' or a valid Faker function.""" @@ -638,13 +638,9 @@ def _validate_key(self, column_name, key_type): if column_name is not None: if not self._validate_key_datatype(column_name): raise InvalidMetadataError( - f"'{key_type}_key' must be a string or integer.") + f"'{key_type}_key' must be a string.") - keys = ( - {column_name} - if isinstance(column_name, str) or isinstance(column_name, int) - else set(column_name) - ) + keys = {column_name} if isinstance(column_name, str) else set(column_name) invalid_ids = keys - set(self.columns) if invalid_ids: raise InvalidMetadataError( @@ -1242,6 +1238,12 @@ def load_from_dict(cls, metadata_dict): for key in instance._KEYS: value = deepcopy(metadata_dict.get(key)) if value: + if key == 'columns': + value = { + str(key) + if not isinstance(key, str) + else key: col for key, col in value.items() + } setattr(instance, f'{key}', value) return instance diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 0173340cf..908be40ec 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1759,11 +1759,11 @@ def test_fit_and_sample_numerical_col_names(): metadata = MultiTableMetadata() metadata_dict = {'tables': {}} for table_idx in range(num_tables): - metadata_dict['tables'][table_idx] = {'columns': {}} + metadata_dict['tables'][str(table_idx)] = {'columns': {}} for i in range(num_cols): - metadata_dict['tables'][table_idx]['columns'][i] = {'sdtype': 'numerical'} - metadata_dict['tables'][0]['columns'][1] = {'sdtype': 'id'} - metadata_dict['tables'][1]['columns'][2] = {'sdtype': 'id'} + metadata_dict['tables'][str(table_idx)]['columns'][i] = {'sdtype': 'numerical'} + metadata_dict['tables']['0']['columns'][1] = {'sdtype': 'id'} + metadata_dict['tables']['1']['columns'][2] = {'sdtype': 'id'} metadata_dict['relationships'] = [ { 'parent_table_name': '0', From 3281badeabc1fc9e48fa0a23e999f2ceb1b039fe Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 15:32:26 -0500 Subject: [PATCH 18/26] Remove table transformation --- sdv/multi_table/base.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index d135fe4ae..8dfff0967 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -326,24 +326,17 @@ def update_transformers(self, table_name, column_name_to_transformer): self._validate_table_name(table_name) self._table_synthesizers[table_name].update_transformers(column_name_to_transformer) - def preprocess(self, unprocessed_data): + def preprocess(self, data): """Transform the raw data to numerical space. Args: - unprocessed_data (dict): + data (dict): Dictionary mapping each table name to a ``pandas.DataFrame``. Returns: dict: A dictionary with the preprocessed data. """ - data = { - str(key) - if not isinstance(key, str) - else key: value - for key, value in unprocessed_data.items() - } - for table, dataframe in data.items(): self._original_table_columns[table] = dataframe.columns dataframe.columns = dataframe.columns.astype(str) From c498089aba12cd03798e37a7d666584c979e19c7 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 15:34:06 -0500 Subject: [PATCH 19/26] Add back lint --- sdv/single_table/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 31aca249e..a0af94599 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -422,6 +422,7 @@ def fit_processed_data(self, processed_data): len(processed_data.columns), self._synthesizer_id, ) + check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) if not processed_data.empty: self._fit(processed_data) @@ -453,6 +454,7 @@ def fit(self, data): len(data.columns), self._synthesizer_id, ) + check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) self._check_metadata_updated() self._fitted = False From 11fe4155cca07427a194380e6fe5cdac8d240cb2 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 17:02:59 -0500 Subject: [PATCH 20/26] Unit tests --- sdv/single_table/base.py | 10 ++- tests/unit/metadata/test_multi_table.py | 79 ++++++++++++++++++ tests/unit/metadata/test_single_table.py | 28 +++++++ tests/unit/multi_table/test_base.py | 101 +++++++++++++++++++---- tests/unit/single_table/test_base.py | 27 ++++++ 5 files changed, 227 insertions(+), 18 deletions(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index a0af94599..d066585a8 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -389,7 +389,13 @@ def preprocess(self, data): self._original_columns = data.columns data.columns = data.columns.astype(str) break - return self._preprocess(data) + + preprocess_data = self._preprocess(data) + + if not self._original_columns.empty: + data.columns = self._original_columns + + return preprocess_data def _fit(self, processed_data): """Fit the model to the table. @@ -462,8 +468,6 @@ def fit(self, data): self._random_state_set = False processed_data = self.preprocess(data) self.fit_processed_data(processed_data) - if not self._original_columns.empty: - data.columns = self._original_columns def save(self, filepath): """Save this model instance to the given path using cloudpickle. diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index 58d0b6975..c32ed7185 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1720,6 +1720,85 @@ def test_load_from_dict(self, mock_singletablemetadata): } ] + @patch('sdv.metadata.multi_table.SingleTableMetadata') + def test_load_from_dict_integer(self, mock_singletablemetadata): + """Test that ``load_from_dict`` returns a instance of ``MultiTableMetadata``. + + Test that when calling the ``load_from_dict`` method a new instance with the passed + python ``dict`` details should be created. Make sure that integers passed in are + turned into strings to ensure metadata is properly typed. + + Setup: + - A dict representing a ``MultiTableMetadata``. + + Mock: + - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table`` + + Output: + - ``instance`` that contains ``instance.tables`` and ``instance.relationships``. + + Side Effects: + - ``SingleTableMetadata.load_from_dict`` has been called. + """ + # Setup + multitable_metadata = { + 'tables': { + 'accounts': { + 1: {'sdtype': 'numerical'}, + 2: {'sdtype': 'numerical'}, + 'amount': {'sdtype': 'numerical'}, + 'start_date': {'sdtype': 'datetime'}, + 'owner': {'sdtype': 'id'}, + }, + 'branches': { + 1: {'sdtype': 'numerical'}, + 'name': {'sdtype': 'id'}, + } + }, + 'relationships': [ + { + 'parent_table_name': 'accounts', + 'parent_primary_key': 1, + 'child_table_name': 'branches', + 'child_foreign_key': 1, + } + ] + } + + single_table_accounts = { + '1': {'sdtype': 'numerical'}, + '2': {'sdtype': 'numerical'}, + 'amount': {'sdtype': 'numerical'}, + 'start_date': {'sdtype': 'datetime'}, + 'owner': {'sdtype': 'id'}, + } + single_table_branches = { + '1': {'sdtype': 'numerical'}, + 'name': {'sdtype': 'id'}, + } + mock_singletablemetadata.load_from_dict.side_effect = [ + single_table_accounts, + single_table_branches + ] + + # Run + instance = MultiTableMetadata.load_from_dict(multitable_metadata) + + # Assert + assert instance.tables == { + 'accounts': single_table_accounts, + 'branches': single_table_branches + } + + assert instance.relationships == [ + { + 'parent_table_name': 'accounts', + 'parent_primary_key': '1', + 'child_table_name': 'branches', + 'child_foreign_key': '1', + } + ] + @patch('sdv.metadata.multi_table.json') def test___repr__(self, mock_json): """Test that the ``__repr__`` method. diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index d51b08ecc..8fb8eb058 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -2695,6 +2695,34 @@ def test_load_from_dict(self): assert instance.sequence_index is None assert instance._version == 'SINGLE_TABLE_V1' + def test_load_from_dict_integer(self): + """Test that ``load_from_dict`` returns a instance with the ``dict`` updated objects. + + If the metadata dict contains columns with integers for certain reasons + (e.g. due to missing column names from CSV) make sure they are correctly typed + to strings to ensure metadata is parsed properly. + """ + # Setup + my_metadata = { + 'columns': {1: 'value'}, + 'primary_key': 'pk', + 'alternate_keys': [], + 'sequence_key': None, + 'sequence_index': None, + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' + } + + # Run + instance = SingleTableMetadata.load_from_dict(my_metadata) + + # Assert + assert instance.columns == {'1': 'value'} + assert instance.primary_key == 'pk' + assert instance.sequence_key is None + assert instance.alternate_keys == [] + assert instance.sequence_index is None + assert instance._version == 'SINGLE_TABLE_V1' + @patch('sdv.metadata.utils.Path') def test_load_from_json_path_does_not_exist(self, mock_path): """Test the ``load_from_json`` method. diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index fe7dc11ab..8c72a667a 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -779,7 +779,78 @@ def test_preprocess(self): synth_upravna_enota._preprocess.assert_called_once_with(data['upravna_enota']) synth_upravna_enota.update_transformers.assert_called_once_with({'a': None, 'b': None}) - @patch('sdv.multi_table.base.warnings') + def test_preprocess_int_columns(self): + """Test the preprocess method. + + Ensure that data with column names as integers are not changed by + preprocess. + """ + # Setup + metadata_dict = { + 'tables': { + 'first_table': { + 'primary_key': '1', + 'columns': { + '1': {'sdtype': 'id'}, + '2': {'sdtype': 'categorical'}, + 'str': {'sdtype': 'categorical'} + } + }, + 'second_table': { + 'columns': { + '3': {'sdtype': 'id'}, + 'str': {'sdtype': 'categorical'} + } + } + }, + 'relationships': [ + { + 'parent_table_name': 'first_table', + 'parent_primary_key': '1', + 'child_table_name': 'second_table', + 'child_foreign_key': '3' + } + ] + } + metadata = MultiTableMetadata.load_from_dict(metadata_dict) + instance = BaseMultiTableSynthesizer(metadata) + instance.validate = Mock() + instance._table_synthesizers = { + 'first_table': Mock(), + 'second_table': Mock() + } + multi_data = { + 'first_table': pd.DataFrame({ + 1: ['abc', 'def', 'ghi'], + 2: ['x', 'a', 'b'], + 'str': ['John', 'Doe', 'John Doe'], + }), + 'second_table': pd.DataFrame({ + 3: ['abc', 'def', 'ghi'], + 'another': ['John', 'Doe', 'John Doe'], + }), + } + + # Run + instance.preprocess(multi_data) + + # Assert + corrected_frame = { + 'first_table': pd.DataFrame({ + 1: ['abc', 'def', 'ghi'], + 2: ['x', 'a', 'b'], + 'str': ['John', 'Doe', 'John Doe'], + }), + 'second_table': pd.DataFrame({ + 3: ['abc', 'def', 'ghi'], + 'another': ['John', 'Doe', 'John Doe'], + }), + } + + pd.testing.assert_frame_equal(multi_data['first_table'], corrected_frame['first_table']) + pd.testing.assert_frame_equal(multi_data['second_table'], corrected_frame['second_table']) + + @ patch('sdv.multi_table.base.warnings') def test_preprocess_warning(self, mock_warnings): """Test that ``preprocess`` warns the user if the model has already been fitted.""" # Setup @@ -828,7 +899,7 @@ def test_preprocess_warning(self, mock_warnings): "please refit the model using 'fit' or 'fit_processed_data'." ) - @patch('sdv.multi_table.base.datetime') + @ patch('sdv.multi_table.base.datetime') def test_fit_processed_data(self, mock_datetime, caplog): """Test that fit processed data calls ``_augment_tables`` and ``_model_tables``. @@ -914,8 +985,8 @@ def test_fit_processed_data_raises_version_error(self): instance.fit_processed_data.assert_not_called() instance._check_metadata_updated.assert_not_called() - @patch('sdv.multi_table.base.datetime') - @patch('sdv.multi_table.base._validate_foreign_keys_not_null') + @ patch('sdv.multi_table.base.datetime') + @ patch('sdv.multi_table.base._validate_foreign_keys_not_null') def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): """Test that it calls the appropriate methods.""" # Setup @@ -1038,7 +1109,7 @@ def test_sample_validate_input(self): with pytest.raises(SynthesizerInputError, match=msg): instance.sample(scale=scale) - @patch('sdv.multi_table.base.datetime') + @ patch('sdv.multi_table.base.datetime') def test_sample(self, mock_datetime, caplog): """Test that ``sample`` calls the ``_sample`` with the given arguments.""" # Setup @@ -1361,7 +1432,7 @@ def test_add_custom_constraint_class_multi_tables(self): 'custom' ) - @patch('sdv.multi_table.base.version') + @ patch('sdv.multi_table.base.version') def test_get_info(self, mock_version): """Test the correct dictionary is returned. @@ -1409,7 +1480,7 @@ def test_get_info(self, mock_version): 'fitted_sdv_version': '1.0.0' } - @patch('sdv.multi_table.base.version') + @ patch('sdv.multi_table.base.version') def test_get_info_with_enterprise(self, mock_version): """Test the correct dictionary is returned. @@ -1458,8 +1529,8 @@ def test_get_info_with_enterprise(self, mock_version): 'fitted_sdv_enterprise_version': '1.1.0' } - @patch('sdv.multi_table.base.datetime') - @patch('sdv.multi_table.base.cloudpickle') + @ patch('sdv.multi_table.base.datetime') + @ patch('sdv.multi_table.base.cloudpickle') def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): """Test that the synthesizer is saved correctly.""" # Setup @@ -1482,12 +1553,12 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' ) - @patch('sdv.multi_table.base.datetime') - @patch('sdv.multi_table.base.generate_synthesizer_id') - @patch('sdv.multi_table.base.check_synthesizer_version') - @patch('sdv.multi_table.base.check_sdv_versions_and_warn') - @patch('sdv.multi_table.base.cloudpickle') - @patch('builtins.open', new_callable=mock_open) + @ patch('sdv.multi_table.base.datetime') + @ patch('sdv.multi_table.base.generate_synthesizer_id') + @ patch('sdv.multi_table.base.check_synthesizer_version') + @ patch('sdv.multi_table.base.check_sdv_versions_and_warn') + @ patch('sdv.multi_table.base.cloudpickle') + @ patch('builtins.open', new_callable=mock_open) def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_warn, mock_check_synthesizer_version, mock_generate_synthesizer_id, mock_datetime, caplog): diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index abaa7c791..eb6475f32 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -339,6 +339,33 @@ def test_preprocess(self, mock_warnings): mock_warnings.warn.assert_called_once_with(expected_warning) instance._preprocess.assert_called_once_with(data) + def test_preprocess_int_columns(self): + """Test the preprocess method. + + Ensure that data with column names as integers are not changed by + preprocess. + """ + # Setup + instance = Mock() + instance._fitted = False + data = pd.DataFrame({ + 1: ['John', 'Doe', 'John Doe'], + 2: ['John', 'Doe', 'John Doe'], + 'str': ['John', 'Doe', 'John Doe'], + }) + + # Run + BaseSingleTableSynthesizer.preprocess(instance, data) + + # Assert + corrected_frame = pd.DataFrame({ + 1: ['John', 'Doe', 'John Doe'], + 2: ['John', 'Doe', 'John Doe'], + 'str': ['John', 'Doe', 'John Doe'], + }) + + pd.testing.assert_frame_equal(data, corrected_frame) + @patch('sdv.single_table.base.DataProcessor') def test__fit(self, mock_data_processor): """Test that ``NotImplementedError`` is being raised.""" From c7c95996fcac5994d738fd5093ed4f5202c625c2 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Fri, 3 May 2024 17:09:12 -0500 Subject: [PATCH 21/26] Remove metadata conversion --- sdv/multi_table/hma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 83b8a7524..a769c3e13 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -494,7 +494,7 @@ def _extract_parameters(self, parent_row, table_name, foreign_key): parent child relationship. """ prefix = f'__{table_name}__{foreign_key}__' - keys = [key for key in parent_row.keys() if str(key).startswith(prefix)] + keys = [key for key in parent_row.keys() if key.startswith(prefix)] new_keys = {key: key[len(prefix):] for key in keys} flat_parameters = parent_row[keys].astype(float).fillna(1e-6) From c23157999dbef9ce5639773098f97a8cb03f52a9 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 7 May 2024 10:15:11 -0500 Subject: [PATCH 22/26] Address comments --- sdv/multi_table/base.py | 22 ++++++++++++++++------ sdv/single_table/base.py | 18 ++++++++++++------ tests/integration/multi_table/test_hma.py | 3 ++- tests/unit/single_table/test_base.py | 2 ++ 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 8dfff0967..569187d11 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -326,6 +326,19 @@ def update_transformers(self, table_name, column_name_to_transformer): self._validate_table_name(table_name) self._table_synthesizers[table_name].update_transformers(column_name_to_transformer) + def _store_and_convert_original_cols(self, data): + list_of_changed_tables = [] + for table, dataframe in data.items(): + self._original_table_columns[table] = dataframe.columns + for column in dataframe.columns: + if isinstance(column, int): + dataframe.columns = dataframe.columns.astype(str) + list_of_changed_tables.append(table) + break + + data[table] = dataframe + return list_of_changed_tables + def preprocess(self, data): """Transform the raw data to numerical space. @@ -337,10 +350,7 @@ def preprocess(self, data): dict: A dictionary with the preprocessed data. """ - for table, dataframe in data.items(): - self._original_table_columns[table] = dataframe.columns - dataframe.columns = dataframe.columns.astype(str) - data[table] = dataframe + list_of_chnaged_tables = self._store_and_convert_original_cols(data) self.validate(data) if self._fitted: @@ -356,8 +366,8 @@ def preprocess(self, data): self._assign_table_transformers(synthesizer, table_name, table_data) processed_data[table_name] = synthesizer._preprocess(table_data) - for table, dataframe in data.items(): - dataframe.columns = self._original_table_columns[table] + for table in list_of_chnaged_tables: + data[table].columns = self._original_table_columns[table] return processed_data diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index d066585a8..3e496ba77 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -367,6 +367,16 @@ def _preprocess(self, data): self._data_processor.fit(data) return self._data_processor.transform(data) + def _store_and_convert_original_cols(self, data): + # Transform in place to avoid possible large copy of data + for column in data.columns: + if isinstance(column, int): + self._original_columns = data.columns + data.columns = data.columns.astype(str) + return True + + return False + def preprocess(self, data): """Transform the raw data to numerical space. @@ -384,15 +394,11 @@ def preprocess(self, data): "please refit the model using 'fit' or 'fit_processed_data'." ) - for column in data.columns: - if isinstance(column, int): - self._original_columns = data.columns - data.columns = data.columns.astype(str) - break + is_converted = self._store_and_convert_original_cols(data) preprocess_data = self._preprocess(data) - if not self._original_columns.empty: + if is_converted: data.columns = self._original_columns return preprocess_data diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 908be40ec..0476c8a50 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1742,7 +1742,7 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id): def test_fit_and_sample_numerical_col_names(): """Test fitting/sampling when column names are integers""" - # Setup data + # Setup num_rows = 50 num_cols = 10 num_tables = 2 @@ -1784,6 +1784,7 @@ def test_fit_and_sample_numerical_col_names(): assert first_sample['1'].columns.tolist() == data['1'].columns.tolist() assert second_sample['0'].columns.tolist() == data['0'].columns.tolist() assert second_sample['1'].columns.tolist() == data['1'].columns.tolist() + # Assert with pytest.raises(AssertionError): pd.testing.assert_frame_equal(first_sample['0'], second_sample['0']) diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index eb6475f32..2395186d0 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -324,6 +324,7 @@ def test_preprocess(self, mock_warnings): # Setup instance = Mock() instance._fitted = True + instance._store_and_convert_original_cols.return_value = False data = pd.DataFrame({ 'name': ['John', 'Doe', 'John Doe'] }) @@ -348,6 +349,7 @@ def test_preprocess_int_columns(self): # Setup instance = Mock() instance._fitted = False + instance._original_columns = pd.Index([1, 2, 'str']) data = pd.DataFrame({ 1: ['John', 'Doe', 'John Doe'], 2: ['John', 'Doe', 'John Doe'], From 429c26a4543547c485a7d62539f24d775ab2f51c Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 7 May 2024 11:36:08 -0500 Subject: [PATCH 23/26] Fixed typo in var name --- sdv/multi_table/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 569187d11..dc92b19fa 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -350,7 +350,7 @@ def preprocess(self, data): dict: A dictionary with the preprocessed data. """ - list_of_chnaged_tables = self._store_and_convert_original_cols(data) + list_of_changed_tables = self._store_and_convert_original_cols(data) self.validate(data) if self._fitted: @@ -366,7 +366,7 @@ def preprocess(self, data): self._assign_table_transformers(synthesizer, table_name, table_data) processed_data[table_name] = synthesizer._preprocess(table_data) - for table in list_of_chnaged_tables: + for table in list_of_changed_tables: data[table].columns = self._original_table_columns[table] return processed_data From 92b07c29066e19bc66af01fb93683f589eb1e8c7 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Tue, 7 May 2024 11:40:19 -0500 Subject: [PATCH 24/26] Remove space between @ and patch --- tests/unit/multi_table/test_base.py | 30 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 8c72a667a..f41df7800 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -850,7 +850,7 @@ def test_preprocess_int_columns(self): pd.testing.assert_frame_equal(multi_data['first_table'], corrected_frame['first_table']) pd.testing.assert_frame_equal(multi_data['second_table'], corrected_frame['second_table']) - @ patch('sdv.multi_table.base.warnings') + @patch('sdv.multi_table.base.warnings') def test_preprocess_warning(self, mock_warnings): """Test that ``preprocess`` warns the user if the model has already been fitted.""" # Setup @@ -899,7 +899,7 @@ def test_preprocess_warning(self, mock_warnings): "please refit the model using 'fit' or 'fit_processed_data'." ) - @ patch('sdv.multi_table.base.datetime') + @patch('sdv.multi_table.base.datetime') def test_fit_processed_data(self, mock_datetime, caplog): """Test that fit processed data calls ``_augment_tables`` and ``_model_tables``. @@ -985,8 +985,8 @@ def test_fit_processed_data_raises_version_error(self): instance.fit_processed_data.assert_not_called() instance._check_metadata_updated.assert_not_called() - @ patch('sdv.multi_table.base.datetime') - @ patch('sdv.multi_table.base._validate_foreign_keys_not_null') + @patch('sdv.multi_table.base.datetime') + @patch('sdv.multi_table.base._validate_foreign_keys_not_null') def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): """Test that it calls the appropriate methods.""" # Setup @@ -1109,7 +1109,7 @@ def test_sample_validate_input(self): with pytest.raises(SynthesizerInputError, match=msg): instance.sample(scale=scale) - @ patch('sdv.multi_table.base.datetime') + @patch('sdv.multi_table.base.datetime') def test_sample(self, mock_datetime, caplog): """Test that ``sample`` calls the ``_sample`` with the given arguments.""" # Setup @@ -1432,7 +1432,7 @@ def test_add_custom_constraint_class_multi_tables(self): 'custom' ) - @ patch('sdv.multi_table.base.version') + @patch('sdv.multi_table.base.version') def test_get_info(self, mock_version): """Test the correct dictionary is returned. @@ -1480,7 +1480,7 @@ def test_get_info(self, mock_version): 'fitted_sdv_version': '1.0.0' } - @ patch('sdv.multi_table.base.version') + @patch('sdv.multi_table.base.version') def test_get_info_with_enterprise(self, mock_version): """Test the correct dictionary is returned. @@ -1529,8 +1529,8 @@ def test_get_info_with_enterprise(self, mock_version): 'fitted_sdv_enterprise_version': '1.1.0' } - @ patch('sdv.multi_table.base.datetime') - @ patch('sdv.multi_table.base.cloudpickle') + @patch('sdv.multi_table.base.datetime') + @patch('sdv.multi_table.base.cloudpickle') def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): """Test that the synthesizer is saved correctly.""" # Setup @@ -1553,12 +1553,12 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): ' Synthesizer id: BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' ) - @ patch('sdv.multi_table.base.datetime') - @ patch('sdv.multi_table.base.generate_synthesizer_id') - @ patch('sdv.multi_table.base.check_synthesizer_version') - @ patch('sdv.multi_table.base.check_sdv_versions_and_warn') - @ patch('sdv.multi_table.base.cloudpickle') - @ patch('builtins.open', new_callable=mock_open) + @patch('sdv.multi_table.base.datetime') + @patch('sdv.multi_table.base.generate_synthesizer_id') + @patch('sdv.multi_table.base.check_synthesizer_version') + @patch('sdv.multi_table.base.check_sdv_versions_and_warn') + @patch('sdv.multi_table.base.cloudpickle') + @patch('builtins.open', new_callable=mock_open) def test_load(self, mock_file, cloudpickle_mock, mock_check_sdv_versions_and_warn, mock_check_synthesizer_version, mock_generate_synthesizer_id, mock_datetime, caplog): From 2fb48b822ecabf1c163efc0a45a7c22df1059b5d Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 8 May 2024 09:56:48 -0500 Subject: [PATCH 25/26] Add backward compat checks for sampling older models --- sdv/multi_table/base.py | 5 +++-- sdv/single_table/base.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index dc92b19fa..1192147d9 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -499,9 +499,10 @@ def sample(self, scale=1.0): total_rows += len(table) total_columns += len(table.columns) + table_columns = getattr(self, '_original_table_columns', {}) for table in sampled_data: - if table in self._original_table_columns: - sampled_data[table].columns = self._original_table_columns[table] + if table in table_columns: + sampled_data[table].columns = table_columns[table] SYNTHESIZER_LOGGER.info( '\nSample:\n' diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index cf7cdda8d..d61858576 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -909,7 +909,8 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file show_progress_bar=show_progress_bar ) - if not self._original_columns.empty: + original_columns = getattr(self, '_original_columns', pd.Index([])) + if not original_columns.empty: sampled_data.columns = self._original_columns SYNTHESIZER_LOGGER.info( From 6f6a4430d81dc4e54d2dd724179314b3f06fe160 Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 8 May 2024 10:04:51 -0500 Subject: [PATCH 26/26] Fix merge conflict --- tests/integration/single_table/test_base.py | 76 --------------------- 1 file changed, 76 deletions(-) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index ac45c51e8..1abb62f05 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -779,82 +779,6 @@ def test_fit_raises_version_error(): instance.fit(data) -@patch('sdv.single_table.base.generate_synthesizer_id') -@patch('sdv.single_table.base.datetime') -def test_synthesizer_logger(mock_datetime, mock_generate_id): - """Test that the synthesizer logger logs the expected messages.""" - # Setup - store_path = Path(platformdirs.user_data_dir('sdv', 'sdv-dev')) - file_name = 'sdv_logs.log' - - synth_id = 'GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' - mock_generate_id.return_value = synth_id - mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' - data = pd.DataFrame({ - 'col 1': [1, 2, 3], - 'col 2': [4, 5, 6], - 'col 3': ['a', 'b', 'c'], - }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - - # Run - instance = GaussianCopulaSynthesizer(metadata) - - # Assert - with open(store_path / file_name) as f: - instance_lines = f.readlines()[-4:] - - assert ''.join(instance_lines) == ( - 'Instance:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: GaussianCopulaSynthesizer\n' - ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - ) - - # Run - instance.fit(data) - - # Assert - with open(store_path / file_name) as f: - fit_lines = f.readlines()[-17:] - - assert ''.join(fit_lines) == ( - 'Fit:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: GaussianCopulaSynthesizer\n' - ' Statistics of the fit data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: 3\n' - ' Total number of columns: 3\n' - ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - '\nFit processed data:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: GaussianCopulaSynthesizer\n' - ' Statistics of the fit processed data:\n' - ' Total number of tables: 1\n' - ' Total number of rows: 3\n' - ' Total number of columns: 3\n' - ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - ) - - # Run - instance.sample(100) - with open(store_path / file_name) as f: - sample_lines = f.readlines()[-8:] - - assert ''.join(sample_lines) == ( - 'Sample:\n' - ' Timestamp: 2024-04-19 16:20:10.037183\n' - ' Synthesizer class name: GaussianCopulaSynthesizer\n' - ' Statistics of the sample size:\n' - ' Total number of tables: 1\n' - ' Total number of rows: 100\n' - ' Total number of columns: 3\n' - ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n' - ) - - SYNTHESIZERS_CLASSES = [ pytest.param(CTGANSynthesizer, id='CTGANSynthesizer'), pytest.param(TVAESynthesizer, id='TVAESynthesizer'),