Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert metadata columns from integer to strings to ensure that SDV works properly #1989

Merged
merged 11 commits into from
May 9, 2024
6 changes: 6 additions & 0 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,9 @@ def _detect_relationships(self):
def detect_table_from_dataframe(self, table_name, data):
"""Detect the metadata for a table from a dataframe.

This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``,
for a specified table. All data column names are converted to strings.

Args:
table_name (str):
Name of the table to detect.
Expand All @@ -539,6 +542,9 @@ def detect_table_from_dataframe(self, table_name, data):
def detect_from_dataframes(self, data):
"""Detect the metadata for all tables in a dictionary of dataframes.

This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
All data column names are converted to strings.

Args:
data (dict):
Dictionary of table names to dataframes.
Expand Down
7 changes: 6 additions & 1 deletion sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,8 @@ def _detect_columns(self, data):
data (pandas.DataFrame):
The data to be analyzed.
"""
old_columns = data.columns
data.columns = data.columns.astype(str)
first_pii_field = None
for field in data:
column_data = data[field]
Expand Down Expand Up @@ -573,11 +575,13 @@ def _detect_columns(self, data):
self.primary_key = first_pii_field

self._updated = True
data.columns = old_columns

def detect_from_dataframe(self, data):
"""Detect the metadata from a ``pd.DataFrame`` object.

This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``.
All data column names are converted to strings.

Args:
data (pandas.DataFrame):
Expand Down Expand Up @@ -1232,7 +1236,8 @@ def load_from_dict(cls, metadata_dict):
Python dictionary representing a ``SingleTableMetadata`` object.

Returns:
Instance of ``SingleTableMetadata``.
Instance of ``SingleTableMetadata``. Column names are converted to
string type.
"""
instance = cls()
for key in instance._KEYS:
Expand Down
59 changes: 59 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,3 +1729,62 @@ def test_fit_and_sample_numerical_col_names():
# Assert
with pytest.raises(AssertionError):
pd.testing.assert_frame_equal(first_sample['0'], second_sample['0'])


def test_detect_from_dataframe_numerical_col():
"""Test that metadata detection of integer columns work."""
# Setup
parent_data = pd.DataFrame({
1: [1000, 1001, 1002],
2: [2, 3, 4],
'categorical_col': ['a', 'b', 'a'],
})
child_data = pd.DataFrame({
3: [1000, 1001, 1000],
4: [1, 2, 3]
})
data = {
'parent_data': parent_data,
'child_data': child_data,
}
metadata = MultiTableMetadata()
metadata.detect_table_from_dataframe('parent_data', parent_data)
metadata.detect_table_from_dataframe('child_data', child_data)
metadata.update_column('parent_data', '1', sdtype='id')
metadata.update_column('child_data', '3', sdtype='id')
metadata.update_column('child_data', '4', sdtype='id')
metadata.set_primary_key('parent_data', '1')
metadata.set_primary_key('child_data', '4')
metadata.add_relationship(
parent_primary_key='1',
parent_table_name='parent_data',
child_foreign_key='3',
child_table_name='child_data'
)

test_metadata = MultiTableMetadata()
test_metadata.detect_from_dataframes(data)
test_metadata.update_column('parent_data', '1', sdtype='id')
test_metadata.update_column('child_data', '3', sdtype='id')
test_metadata.update_column('child_data', '4', sdtype='id')
test_metadata.set_primary_key('parent_data', '1')
test_metadata.set_primary_key('child_data', '4')
test_metadata.add_relationship(
parent_primary_key='1',
parent_table_name='parent_data',
child_foreign_key='3',
child_table_name='child_data'
)

# Run
instance = HMASynthesizer(metadata)
instance.fit(data)
sample = instance.sample(5)

# Assert
assert test_metadata.to_dict() == metadata.to_dict()
assert sample['parent_data'].columns.tolist() == data['parent_data'].columns.tolist()
assert sample['child_data'].columns.tolist() == data['child_data'].columns.tolist()

test_metadata = MultiTableMetadata()
test_metadata.detect_from_dataframes(data)
17 changes: 17 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,3 +831,20 @@ def test_sample_not_fitted(synthesizer):
# Run and Assert
with pytest.raises(SamplingError, match=expected_message):
synthesizer.sample(10)


@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES)
def test_detect_from_dataframe_numerical_col(synthesizer_class):
"""Test that metadata detection of integer columns work."""
# Setup
data = pd.DataFrame({
1: [1, 2, 3],
2: [4, 5, 6],
3: ['a', 'b', 'c'],
})
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)
instance = synthesizer_class(metadata)
instance.fit(data)
sample = instance.sample(5)
assert sample.columns.tolist() == data.columns.tolist()
lajohn4747 marked this conversation as resolved.
Show resolved Hide resolved
81 changes: 81 additions & 0 deletions tests/unit/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,87 @@ def test_detect_from_dataframe(self, mock_log):
]
mock_log.info.assert_has_calls(expected_log_calls)

@patch('sdv.metadata.single_table.LOGGER')
def test_detect_from_dataframe_numerical_columns(self, mock_log):
"""Test the detect from dataframe with columns that are integers"""
# Setup
num_rows = 100
num_cols = 20
values = {i + 1: np.random.randint(0, 100, size=num_rows) for i in range(num_cols)}
data = pd.DataFrame(values)
correct_metadata = {
'columns': {
'1': {
'sdtype': 'numerical'
},
'2': {
'sdtype': 'numerical'
},
'3': {
'sdtype': 'numerical'
},
'4': {
'sdtype': 'numerical'
},
'5': {
'sdtype': 'numerical'
},
'6': {
'sdtype': 'numerical'
},
'7': {
'sdtype': 'numerical'
},
'8': {
'sdtype': 'numerical'
},
'9': {
'sdtype': 'numerical'
},
'10': {
'sdtype': 'numerical'
},
'11': {
'sdtype': 'numerical'
},
'12': {
'sdtype': 'numerical'
},
'13': {
'sdtype': 'numerical'
},
'14': {
'sdtype': 'numerical'
},
'15': {
'sdtype': 'numerical'
},
'16': {
'sdtype': 'numerical'
},
'17': {
'sdtype': 'numerical'
},
'18': {
'sdtype': 'numerical'
},
'19': {
'sdtype': 'numerical'
},
'20': {
'sdtype': 'numerical'
}
},
'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1'
}

# Run
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(data)

# Assert
assert correct_metadata == metadata.to_dict()

def test_detect_from_csv_raises_error(self):
"""Test the ``detect_from_csv`` method.

Expand Down
Loading