Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert metadata columns from integer to strings to ensure that SDV works properly #1989

Merged
merged 11 commits into from
May 9, 2024
6 changes: 6 additions & 0 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,9 @@ def _detect_relationships(self):
def detect_table_from_dataframe(self, table_name, data):
"""Detect the metadata for a table from a dataframe.

This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``,
for a specified table. All column names are converted to strings.
lajohn4747 marked this conversation as resolved.
Show resolved Hide resolved

Args:
table_name (str):
Name of the table to detect.
Expand All @@ -539,6 +542,9 @@ def detect_table_from_dataframe(self, table_name, data):
def detect_from_dataframes(self, data):
"""Detect the metadata for all tables in a dictionary of dataframes.

This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
All column names are converted to strings.

Args:
data (dict):
Dictionary of table names to dataframes.
Expand Down
11 changes: 7 additions & 4 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,8 @@ def _detect_columns(self, data):
data (pandas.DataFrame):
The data to be analyzed.
"""
old_columns = data.columns
data.columns = data.columns.astype(str)
first_pii_field = None
for field in data:
column_data = data[field]
Expand Down Expand Up @@ -573,11 +575,13 @@ def _detect_columns(self, data):
self.primary_key = first_pii_field

self._updated = True
data.columns = old_columns

def detect_from_dataframe(self, data):
"""Detect the metadata from a ``pd.DataFrame`` object.

This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
All column names are converted to strings.

Args:
data (pandas.DataFrame):
Expand All @@ -588,10 +592,8 @@ def detect_from_dataframe(self, data):
'Metadata already exists. Create a new ``SingleTableMetadata`` '
'object to detect from other data sources.'
)
old_columns = data.columns
data.columns = data.columns.astype(str)

self._detect_columns(data)
data.columns = old_columns

LOGGER.info('Detected metadata:')
LOGGER.info(json.dumps(self.to_dict(), indent=4))
Expand Down Expand Up @@ -1234,7 +1236,8 @@ def load_from_dict(cls, metadata_dict):
Python dictionary representing a ``SingleTableMetadata`` object.

Returns:
Instance of ``SingleTableMetadata``.
Instance of ``SingleTableMetadata``. Column names are converted to
string type.
"""
instance = cls()
for key in instance._KEYS:
Expand Down
2 changes: 1 addition & 1 deletion sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _estimate_columns_traversal(cls, metadata, table_name,
columns_per_table[table_name] += \
cls._get_num_extended_columns(
metadata, child_name, table_name, columns_per_table, distributions
)
)

visited.add(table_name)

Expand Down
59 changes: 59 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,3 +1729,62 @@ def test_fit_and_sample_numerical_col_names():
# Assert
with pytest.raises(AssertionError):
pd.testing.assert_frame_equal(first_sample['0'], second_sample['0'])


def test_detect_from_dataframe_numerical_col():
    """Test that metadata detection works when column names are integers.

    Metadata is detected twice, once table by table and once from the whole
    dictionary of dataframes, and both results are configured identically.
    The two must serialize to the same dictionary, and data sampled from an
    ``HMASynthesizer`` built on the detected metadata must keep the original
    (integer) column names.
    """
    # Setup
    def _configure(multi_metadata):
        """Mark the id columns, set the primary keys and link child to parent."""
        id_columns = (('parent_data', '1'), ('child_data', '3'), ('child_data', '4'))
        for table, column in id_columns:
            multi_metadata.update_column(table, column, sdtype='id')

        primary_keys = (('parent_data', '1'), ('child_data', '4'))
        for table, key in primary_keys:
            multi_metadata.set_primary_key(table, key)

        multi_metadata.add_relationship(
            parent_table_name='parent_data',
            child_table_name='child_data',
            parent_primary_key='1',
            child_foreign_key='3',
        )

    data = {
        'parent_data': pd.DataFrame({
            1: [1000, 1001, 1002],
            2: [2, 3, 4],
            'categorical_col': ['a', 'b', 'a'],
        }),
        'child_data': pd.DataFrame({
            3: [1000, 1001, 1000],
            4: [1, 2, 3],
        }),
    }

    # Metadata detected one table at a time; columns are referenced by their
    # string names ('1', '3', '4') even though the dataframes use integers.
    metadata = MultiTableMetadata()
    for table_name, table in data.items():
        metadata.detect_table_from_dataframe(table_name, table)

    _configure(metadata)

    # Metadata detected from the full dictionary of dataframes in one call.
    test_metadata = MultiTableMetadata()
    test_metadata.detect_from_dataframes(data)
    _configure(test_metadata)

    # Run
    synthesizer = HMASynthesizer(metadata)
    synthesizer.fit(data)
    sampled = synthesizer.sample(5)

    # Assert
    assert test_metadata.to_dict() == metadata.to_dict()
    for table_name, table in data.items():
        assert sampled[table_name].columns.tolist() == table.columns.tolist()

    # Detection is run once more after fitting and sampling. Nothing is
    # asserted on the result, so this only checks that the call does not raise.
    # NOTE(review): looks like leftover code -- confirm whether an assertion
    # was intended here.
    test_metadata = MultiTableMetadata()
    test_metadata.detect_from_dataframes(data)
17 changes: 17 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,3 +831,20 @@ def test_sample_not_fitted(synthesizer):
# Run and Assert
with pytest.raises(SamplingError, match=expected_message):
synthesizer.sample(10)


@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES)
def test_detect_from_dataframe_numerical_col(synthesizer_class):
    """Test that metadata detection works when column names are integers.

    The metadata is detected from a dataframe whose column names are all
    integers, the synthesizer is fitted on that same dataframe, and the
    sampled data must come back with the original column names.
    """
    # Setup
    data = pd.DataFrame({
        1: [1, 2, 3],
        2: [4, 5, 6],
        3: ['a', 'b', 'c'],
    })
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(data)
    synthesizer = synthesizer_class(metadata)

    # Run
    synthesizer.fit(data)
    sampled = synthesizer.sample(5)

    # Assert
    sampled_columns = sampled.columns.tolist()
    assert sampled_columns == data.columns.tolist()
lajohn4747 marked this conversation as resolved.
Show resolved Hide resolved
Loading