Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert integer column names to strings to allow for default column names #1976

Merged
merged 31 commits into main from issue_1935_column_integer_type_error
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9fa8fac
WIP trying to convert everything to string before processing
lajohn4747 May 1, 2024
00ebb3a
Fix test
lajohn4747 May 1, 2024
68f8761
Remove all prints
lajohn4747 May 1, 2024
4d89c58
Fix lint
lajohn4747 May 1, 2024
d881306
Fix lint
lajohn4747 May 1, 2024
79b6853
Add tests
lajohn4747 May 1, 2024
e3a6cd2
Merge branch 'main' into issue_1935_column_integer_type_error
lajohn4747 May 1, 2024
67df9eb
Update Test
lajohn4747 May 3, 2024
6f8c497
Merge
lajohn4747 May 3, 2024
c64da83
Remove string changes
lajohn4747 May 3, 2024
8d617bd
remove old code
lajohn4747 May 3, 2024
8982a70
Remove str conversions
lajohn4747 May 3, 2024
099dad9
Debug
lajohn4747 May 3, 2024
70cbcd8
Merge branch 'main' into fix
lajohn4747 May 3, 2024
9ec0b15
Fixed processing
lajohn4747 May 3, 2024
f821a95
Fix tests
lajohn4747 May 3, 2024
9f521f8
Remove lint
lajohn4747 May 3, 2024
971f685
Fix
lajohn4747 May 3, 2024
058f55b
Merge branch 'fix' into issue_1935_column_integer_type_error
lajohn4747 May 3, 2024
9abb32b
Fix merge
lajohn4747 May 3, 2024
3281bad
Remove table transformation
lajohn4747 May 3, 2024
c498089
Add back lint
lajohn4747 May 3, 2024
11fe415
Unit tests
lajohn4747 May 3, 2024
c7c9599
Remove metadata conversion
lajohn4747 May 3, 2024
c231579
Address comments
lajohn4747 May 7, 2024
429c26a
Fixed typo in var name
lajohn4747 May 7, 2024
2de097e
Merge branch 'main' into issue_1935_column_integer_type_error
lajohn4747 May 7, 2024
92b07c2
Remove space between @ and patch
lajohn4747 May 7, 2024
2fb48b8
Add backward compat checks for sampling older models
lajohn4747 May 8, 2024
4cb62bd
Merge
lajohn4747 May 8, 2024
6f6a443
Fix merge conflict
lajohn4747 May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,10 +1025,15 @@ def _set_metadata_dict(self, metadata):
Python dictionary representing a ``MultiTableMetadata`` object.
"""
for table_name, table_dict in metadata.get('tables', {}).items():
self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict)
self.tables[str(table_name)] = SingleTableMetadata.load_from_dict(table_dict)

for relationship in metadata.get('relationships', []):
self.relationships.append(relationship)
type_safe_relationships = {
key: str(value)
if not isinstance(value, str)
else value for key, value in relationship.items()
}
self.relationships.append(type_safe_relationships)

@classmethod
def load_from_dict(cls, metadata_dict):
Expand Down
6 changes: 6 additions & 0 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,6 +1238,12 @@ def load_from_dict(cls, metadata_dict):
for key in instance._KEYS:
value = deepcopy(metadata_dict.get(key))
if value:
if key == 'columns':
value = {
str(key)
if not isinstance(key, str)
else key: col for key, col in value.items()
}
setattr(instance, f'{key}', value)

return instance
Expand Down
13 changes: 10 additions & 3 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,16 @@ def fit(self, data):
Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format
(before any transformations).
"""
type_safe_data = {
str(key) if not isinstance(key, str) else key: value for key, value in data.items()
}
total_rows = 0
total_columns = 0
for table in data.values():
for table, dataframe in type_safe_data.items():
dataframe.columns = dataframe.columns.astype(str)
type_safe_data[table] = dataframe

for table in type_safe_data.values():
total_rows += len(table)
total_columns += len(table.columns)

Expand All @@ -440,10 +447,10 @@ def fit(self, data):
self._synthesizer_id,
)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
_validate_foreign_keys_not_null(self.metadata, data)
_validate_foreign_keys_not_null(self.metadata, type_safe_data)
self._check_metadata_updated()
self._fitted = False
processed_data = self.preprocess(data)
processed_data = self.preprocess(type_safe_data)
self._print(text='\n', end='')
self.fit_processed_data(processed_data)

Expand Down
2 changes: 1 addition & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def fit(self, data):
len(data.columns),
self._synthesizer_id,
)

data.columns = data.columns.astype(str)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
self._check_metadata_updated()
self._fitted = False
Expand Down
2 changes: 1 addition & 1 deletion sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def unflatten_dict(flat):

else:
subdict = unflattened.setdefault(key, {})
if subkey.isdigit():
if subkey.isdigit() and key != 'univariates':
lajohn4747 marked this conversation as resolved.
Show resolved Hide resolved
subkey = int(subkey)

inner = subdict.setdefault(subkey, {})
Expand Down
45 changes: 45 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1738,3 +1738,48 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id):
' Total number of columns: 15\n'
' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
)


def test_fit_and_sample_numerical_col_names():
    """Test fitting and sampling when table and column names are integers.

    Both the data and the metadata dictionary are built with integer table and
    column names. After loading, the metadata and the sampled output are
    addressed with the string form of those names (``'0'``, ``'1'``).
    """
    # Setup
    num_rows = 50
    num_cols = 10
    num_tables = 2
    data = {
        table_idx: pd.DataFrame({
            col_idx: np.random.randint(0, 100, size=num_rows) for col_idx in range(num_cols)
        })
        for table_idx in range(num_tables)
    }

    # Column 1 holds unique key values in both tables; column 2 of table 1 is
    # the foreign key that references column 1 of table 0.
    data[0][1] = range(num_rows)
    data[1][1] = range(num_rows)
    data[1][2] = range(num_rows)

    metadata_dict = {
        'tables': {
            table_idx: {
                'columns': {col_idx: {'sdtype': 'numerical'} for col_idx in range(num_cols)}
            }
            for table_idx in range(num_tables)
        },
        'relationships': [
            {
                'parent_table_name': 0,
                'parent_primary_key': 1,
                'child_table_name': 1,
                'child_foreign_key': 2,
            }
        ],
    }
    metadata_dict['tables'][0]['columns'][1] = {'sdtype': 'id'}
    metadata_dict['tables'][1]['columns'][2] = {'sdtype': 'id'}
    metadata = MultiTableMetadata.load_from_dict(metadata_dict)
    metadata.set_primary_key('0', '1')

    # Run
    synth = HMASynthesizer(metadata)
    synth.fit(data)
    first_sample = synth.sample()
    second_sample = synth.sample()

    # Assert
    # Two consecutive samples are expected to differ, so the frame comparison
    # must fail.
    with pytest.raises(AssertionError):
        pd.testing.assert_frame_equal(first_sample['0'], second_sample['0'])
34 changes: 34 additions & 0 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,3 +855,37 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id):
' Total number of columns: 3\n'
' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
)


def test_fit_and_sample_numerical_col_names():
    """Test fitting and sampling when column names are integers.

    The data and the metadata dictionary both use integer column names. Every
    single-table synthesizer listed below must fit on them and produce two
    consecutive samples that differ from each other.
    """
    # Setup
    num_rows = 50
    num_cols = 10
    data = pd.DataFrame({
        col_idx: np.random.randint(0, 100, size=num_rows) for col_idx in range(num_cols)
    })
    metadata_dict = {
        'columns': {col_idx: {'sdtype': 'numerical'} for col_idx in range(num_cols)}
    }
    metadata = SingleTableMetadata.load_from_dict(metadata_dict)

    synthesizers = [
        CTGANSynthesizer,
        TVAESynthesizer,
        GaussianCopulaSynthesizer,
        CopulaGANSynthesizer,
    ]
    for synthesizer_class in synthesizers:
        # Run
        synth = synthesizer_class(metadata)
        synth.fit(data)
        sample_1 = synth.sample(10)
        sample_2 = synth.sample(10)

        # Assert
        # Checked inside the loop so every synthesizer is verified, not only
        # the last one.
        with pytest.raises(AssertionError):
            pd.testing.assert_frame_equal(sample_1, sample_2)
Loading