Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert metadata columns from integer to strings to ensure that SDV works properly #1989

Merged
merged 11 commits into from
May 9, 2024
4 changes: 3 additions & 1 deletion sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,8 +588,10 @@ def detect_from_dataframe(self, data):
'Metadata already exists. Create a new ``SingleTableMetadata`` '
'object to detect from other data sources.'
)

old_columns = data.columns
data.columns = data.columns.astype(str)
self._detect_columns(data)
data.columns = old_columns

LOGGER.info('Detected metadata:')
LOGGER.info(json.dumps(self.to_dict(), indent=4))
Expand Down
2 changes: 1 addition & 1 deletion sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _estimate_columns_traversal(cls, metadata, table_name,
columns_per_table[table_name] += \
cls._get_num_extended_columns(
metadata, child_name, table_name, columns_per_table, distributions
)
)
lajohn4747 marked this conversation as resolved.
Show resolved Hide resolved

visited.add(table_name)

Expand Down
81 changes: 81 additions & 0 deletions tests/unit/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,87 @@ def test_detect_from_dataframe(self, mock_log):
]
mock_log.info.assert_has_calls(expected_log_calls)

@patch('sdv.metadata.single_table.LOGGER')
def test_detect_from_dataframe_numerical_columns(self, mock_log):
    """Test ``detect_from_dataframe`` when the column labels are integers.

    The dataframe is built with the integer labels 1 through 20. The metadata
    dict produced afterwards is expected to list every column under the string
    form of its label, each with the ``numerical`` sdtype.
    """
    # Setup
    num_rows = 100
    num_cols = 20
    labels = range(1, num_cols + 1)
    data = pd.DataFrame({
        label: np.random.randint(0, 100, size=num_rows) for label in labels
    })
    # Expected keys are strings even though the dataframe labels are integers.
    expected_columns = {str(label): {'sdtype': 'numerical'} for label in labels}
    correct_metadata = {
        'columns': expected_columns,
        'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
    }

    # Run
    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(data)

    # Assert
    assert correct_metadata == metadata.to_dict()

def test_detect_from_csv_raises_error(self):
"""Test the ``detect_from_csv`` method.

Expand Down
Loading