Convert integer column names to strings to allow for default column names #1976

Merged: 31 commits from issue_1935_column_integer_type_error into main on May 8, 2024.
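For context, here is a minimal sketch of the behavior this change enables, modeled on the integration tests added in this PR (it assumes sdv, pandas, and numpy are installed; the variable names are illustrative). A DataFrame whose column names are the default integers pandas assigns can now be fit and sampled, and the sampled output keeps the original integer names.

import numpy as np
import pandas as pd

from sdv.metadata import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer

# Columns are the default integer names 0..4, as when a CSV has no header row.
data = pd.DataFrame({i: np.random.randint(0, 100, size=50) for i in range(5)})

metadata = SingleTableMetadata.load_from_dict(
    {'columns': {i: {'sdtype': 'numerical'} for i in range(5)}}
)

synthesizer = GaussianCopulaSynthesizer(metadata)
synthesizer.fit(data)
sampled = synthesizer.sample(num_rows=10)

# The sampled table keeps the caller's original integer column names.
assert sampled.columns.tolist() == data.columns.tolist()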
Changes shown below are from 24 of the 31 commits.
Commits
9fa8fac
WIP trying to convert everything to string before processing
lajohn4747 May 1, 2024
00ebb3a
Fix test
lajohn4747 May 1, 2024
68f8761
Remove all prints
lajohn4747 May 1, 2024
4d89c58
Fix lint
lajohn4747 May 1, 2024
d881306
Fix lint
lajohn4747 May 1, 2024
79b6853
Add tests
lajohn4747 May 1, 2024
e3a6cd2
Merge branch 'main' into issue_1935_column_integer_type_error
lajohn4747 May 1, 2024
67df9eb
Update Test
lajohn4747 May 3, 2024
6f8c497
Merge
lajohn4747 May 3, 2024
c64da83
Remove string changes
lajohn4747 May 3, 2024
8d617bd
remove old code
lajohn4747 May 3, 2024
8982a70
Remove str conversions
lajohn4747 May 3, 2024
099dad9
Debug
lajohn4747 May 3, 2024
70cbcd8
Merge branch 'main' into fix
lajohn4747 May 3, 2024
9ec0b15
Fixed processing
lajohn4747 May 3, 2024
f821a95
Fix tests
lajohn4747 May 3, 2024
9f521f8
Remove lint
lajohn4747 May 3, 2024
971f685
Fix
lajohn4747 May 3, 2024
058f55b
Merge branch 'fix' into issue_1935_column_integer_type_error
lajohn4747 May 3, 2024
9abb32b
Fix merge
lajohn4747 May 3, 2024
3281bad
Remove table transformation
lajohn4747 May 3, 2024
c498089
Add back lint
lajohn4747 May 3, 2024
11fe415
Unit tests
lajohn4747 May 3, 2024
c7c9599
Remove metadata conversion
lajohn4747 May 3, 2024
c231579
Address comments
lajohn4747 May 7, 2024
429c26a
Fixed typo in var name
lajohn4747 May 7, 2024
2de097e
Merge branch 'main' into issue_1935_column_integer_type_error
lajohn4747 May 7, 2024
92b07c2
Remove space between @ and patch
lajohn4747 May 7, 2024
2fb48b8
Add backward compat checks for sampling older models
lajohn4747 May 8, 2024
4cb62bd
Merge
lajohn4747 May 8, 2024
6f6a443
Fix merge conflict
lajohn4747 May 8, 2024
sdv/_utils.py (2 additions, 0 deletions)

@@ -214,6 +214,8 @@ def _validate_foreign_keys_not_null(metadata, data):
     invalid_tables = defaultdict(list)
     for table_name, table_data in data.items():
         for foreign_key in metadata._get_all_foreign_keys(table_name):
+            if foreign_key not in table_data and int(foreign_key) in table_data:
+                foreign_key = int(foreign_key)
             if table_data[foreign_key].isna().any():
                 invalid_tables[table_name].append(foreign_key)
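The guard added above covers the case where the metadata stores a foreign key name as a string while the raw table still uses the original integer column name. A rough, self-contained illustration of that mismatch (hypothetical data, not taken from the PR):

import pandas as pd

table_data = pd.DataFrame({1: [10, 20, None]})  # raw table keeps the integer column name
foreign_key = '1'                               # metadata keys are stored as strings

# Fall back to the integer form of the key when the string form is missing.
if foreign_key not in table_data and int(foreign_key) in table_data:
    foreign_key = int(foreign_key)

print(table_data[foreign_key].isna().any())  # True: the null foreign key is still detected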
sdv/metadata/multi_table.py (6 additions, 1 deletion)

@@ -1028,7 +1028,12 @@ def _set_metadata_dict(self, metadata):
             self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict)

         for relationship in metadata.get('relationships', []):
-            self.relationships.append(relationship)
+            type_safe_relationships = {
+                key: str(value)
+                if not isinstance(value, str)
+                else value for key, value in relationship.items()
+            }
+            self.relationships.append(type_safe_relationships)

     @classmethod
     def load_from_dict(cls, metadata_dict):
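Taken on its own, the conversion applied to each relationship entry behaves as in this standalone sketch (same comprehension as the diff; the example values are hypothetical):

relationship = {
    'parent_table_name': 'accounts',
    'parent_primary_key': 1,
    'child_table_name': 'branches',
    'child_foreign_key': 1,
}

# Integer column names are cast to strings; existing strings pass through untouched.
type_safe_relationship = {
    key: str(value) if not isinstance(value, str) else value
    for key, value in relationship.items()
}

print(type_safe_relationship)
# {'parent_table_name': 'accounts', 'parent_primary_key': '1',
#  'child_table_name': 'branches', 'child_foreign_key': '1'}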
sdv/metadata/single_table.py (6 additions, 0 deletions)

@@ -1238,6 +1238,12 @@ def load_from_dict(cls, metadata_dict):
         for key in instance._KEYS:
             value = deepcopy(metadata_dict.get(key))
             if value:
+                if key == 'columns':
+                    value = {
+                        str(key)
+                        if not isinstance(key, str)
+                        else key: col for key, col in value.items()
+                    }
                 setattr(instance, f'{key}', value)

         return instance
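The visible effect, mirroring the unit test added later in this PR, is that integer column keys in a metadata dict come back as strings after loading. A short sketch (assuming sdv is installed):

from sdv.metadata import SingleTableMetadata

metadata = SingleTableMetadata.load_from_dict({
    'columns': {0: {'sdtype': 'numerical'}, 1: {'sdtype': 'numerical'}},
    'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
})

# Integer keys from the dict are stored as strings on the instance.
print(metadata.columns)  # {'0': {'sdtype': 'numerical'}, '1': {'sdtype': 'numerical'}}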
sdv/multi_table/base.py (13 additions, 0 deletions)

@@ -99,6 +99,7 @@ def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
         self.extended_columns = defaultdict(dict)
         self._table_synthesizers = {}
         self._table_parameters = defaultdict(dict)
+        self._original_table_columns = {}
         if synthesizer_kwargs is not None:
             warn_message = (
                 'The `synthesizer_kwargs` parameter is deprecated as of SDV 1.2.0 and does not '
@@ -336,6 +337,11 @@ def preprocess(self, data):
             dict:
                 A dictionary with the preprocessed data.
         """
+        for table, dataframe in data.items():
+            self._original_table_columns[table] = dataframe.columns
+            dataframe.columns = dataframe.columns.astype(str)
+            data[table] = dataframe
+
         self.validate(data)
         if self._fitted:
             warnings.warn(
@@ -350,6 +356,9 @@
             self._assign_table_transformers(synthesizer, table_name, table_data)
             processed_data[table_name] = synthesizer._preprocess(table_data)

+        for table, dataframe in data.items():
+            dataframe.columns = self._original_table_columns[table]
+
         return processed_data

     def _model_tables(self, augmented_data):
@@ -480,6 +489,10 @@ def sample(self, scale=1.0):
             total_rows += len(table)
             total_columns += len(table.columns)

+        for table in sampled_data:
+            if table in self._original_table_columns:
+                sampled_data[table].columns = self._original_table_columns[table]
+
         SYNTHESIZER_LOGGER.info(
             '\nSample:\n'
             ' Timestamp: %s\n'
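Conceptually, preprocess() and sample() now bracket the internal work with a save/restore of each table's column Index. A minimal pandas-only sketch of that round trip (no SDV calls, just the column bookkeeping):

import pandas as pd

dataframe = pd.DataFrame({0: [1, 2], 1: [3, 4]})

original_columns = dataframe.columns               # remember the integer names
dataframe.columns = dataframe.columns.astype(str)  # work internally with '0', '1'
# ... per-table preprocessing / modeling happens here with string names ...
dataframe.columns = original_columns               # restore the caller's names

print(dataframe.columns.tolist())  # [0, 1]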
sdv/single_table/base.py (17 additions, 2 deletions)

@@ -102,6 +102,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
             enforce_min_max_values=self.enforce_min_max_values,
             locales=self.locales,
         )
+        self._original_columns = pd.Index([])
         self._fitted = False
         self._random_state_set = False
         self._update_default_transformers()
@@ -383,7 +384,18 @@ def preprocess(self, data):
                 "please refit the model using 'fit' or 'fit_processed_data'."
             )

-        return self._preprocess(data)
+        for column in data.columns:
+            if isinstance(column, int):
+                self._original_columns = data.columns
+                data.columns = data.columns.astype(str)
+                break
+
+        preprocess_data = self._preprocess(data)
+
+        if not self._original_columns.empty:
+            data.columns = self._original_columns
+
+        return preprocess_data

     def _fit(self, processed_data):
         """Fit the model to the table.
@@ -454,7 +466,7 @@ def fit(self, data):
         self._fitted = False
         self._data_processor.reset_sampling()
         self._random_state_set = False
-        processed_data = self._preprocess(data)
+        processed_data = self.preprocess(data)
         self.fit_processed_data(processed_data)

     def save(self, filepath):
@@ -884,6 +896,9 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
             show_progress_bar=show_progress_bar
         )

+        if not self._original_columns.empty:
+            sampled_data.columns = self._original_columns
+
         SYNTHESIZER_LOGGER.info(
             '\nSample:\n'
             ' Timestamp: %s\n'
sdv/single_table/utils.py (1 addition, 1 deletion)

@@ -303,7 +303,7 @@ def unflatten_dict(flat):

         else:
             subdict = unflattened.setdefault(key, {})
-            if subkey.isdigit():
+            if subkey.isdigit() and key != 'univariates':
                 subkey = int(subkey)

             inner = subdict.setdefault(subkey, {})
tests/integration/multi_table/test_hma.py (49 additions, 0 deletions)

@@ -1738,3 +1738,52 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id):
         ' Total number of columns: 15\n'
         ' Synthesizer id: HMASynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
     )
+
+
+def test_fit_and_sample_numerical_col_names():
+    """Test fitting/sampling when column names are integers"""
+    # Setup data
+    num_rows = 50
+    num_cols = 10
+    num_tables = 2
+    data = {}
+    for i in range(num_tables):
+        values = {j: np.random.randint(0, 100, size=num_rows) for j in range(num_cols)}
+        data[str(i)] = pd.DataFrame(values)
+
+    primary_key = pd.DataFrame({1: range(num_rows)})
+    primary_key_2 = pd.DataFrame({2: range(num_rows)})
+    data['0'][1] = primary_key
+    data['1'][1] = primary_key
+    data['1'][2] = primary_key_2
+    metadata = MultiTableMetadata()
+    metadata_dict = {'tables': {}}
+    for table_idx in range(num_tables):
+        metadata_dict['tables'][str(table_idx)] = {'columns': {}}
+        for i in range(num_cols):
+            metadata_dict['tables'][str(table_idx)]['columns'][i] = {'sdtype': 'numerical'}
+    metadata_dict['tables']['0']['columns'][1] = {'sdtype': 'id'}
+    metadata_dict['tables']['1']['columns'][2] = {'sdtype': 'id'}
+    metadata_dict['relationships'] = [
+        {
+            'parent_table_name': '0',
+            'parent_primary_key': 1,
+            'child_table_name': '1',
+            'child_foreign_key': 2
+        }
+    ]
+    metadata = MultiTableMetadata.load_from_dict(metadata_dict)
+    metadata.set_primary_key('0', '1')
+
+    # Run
+    synth = HMASynthesizer(metadata)
+    synth.fit(data)
+    first_sample = synth.sample()
+    second_sample = synth.sample()
+    assert first_sample['0'].columns.tolist() == data['0'].columns.tolist()
+    assert first_sample['1'].columns.tolist() == data['1'].columns.tolist()
+    assert second_sample['0'].columns.tolist() == data['0'].columns.tolist()
+    assert second_sample['1'].columns.tolist() == data['1'].columns.tolist()
+    # Assert
+    with pytest.raises(AssertionError):
+        pd.testing.assert_frame_equal(first_sample['0'], second_sample['0'])
tests/integration/single_table/test_base.py (38 additions, 0 deletions)

@@ -855,3 +855,41 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id):
         ' Total number of columns: 3\n'
         ' Synthesizer id: GaussianCopulaSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5\n'
     )
+
+
+SYNTHESIZERS_CLASSES = [
+    pytest.param(CTGANSynthesizer, id='CTGANSynthesizer'),
+    pytest.param(TVAESynthesizer, id='TVAESynthesizer'),
+    pytest.param(GaussianCopulaSynthesizer, id='GaussianCopulaSynthesizer'),
+    pytest.param(CopulaGANSynthesizer, id='CopulaGANSynthesizer'),
+]
+
+
+@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES)
+def test_fit_and_sample_numerical_col_names(synthesizer_class):
+    """Test fitting/sampling when column names are integers"""
+    # Setup
+    num_rows = 50
+    num_cols = 10
+    values = {
+        i: np.random.randint(0, 100, size=num_rows) for i in range(num_cols)
+    }
+    data = pd.DataFrame(values)
+    metadata = SingleTableMetadata()
+    metadata_dict = {'columns': {}}
+    for i in range(num_cols):
+        metadata_dict['columns'][i] = {'sdtype': 'numerical'}
+    metadata = SingleTableMetadata.load_from_dict(metadata_dict)
+
+    # Run
+    synth = synthesizer_class(metadata)
+    synth.fit(data)
+    sample_1 = synth.sample(10)
+    sample_2 = synth.sample(10)
+
+    assert sample_1.columns.tolist() == data.columns.tolist()
+    assert sample_2.columns.tolist() == data.columns.tolist()
+
+    # Assert
+    with pytest.raises(AssertionError):
+        pd.testing.assert_frame_equal(sample_1, sample_2)
tests/unit/metadata/test_multi_table.py (79 additions, 0 deletions)

@@ -1720,6 +1720,85 @@ def test_load_from_dict(self, mock_singletablemetadata):
             }
         ]

+    @patch('sdv.metadata.multi_table.SingleTableMetadata')
+    def test_load_from_dict_integer(self, mock_singletablemetadata):
+        """Test that ``load_from_dict`` returns a instance of ``MultiTableMetadata``.
+
+        Test that when calling the ``load_from_dict`` method a new instance with the passed
+        python ``dict`` details should be created. Make sure that integers passed in are
+        turned into strings to ensure metadata is properly typed.
+
+        Setup:
+            - A dict representing a ``MultiTableMetadata``.
+
+        Mock:
+            - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table``
+
+        Output:
+            - ``instance`` that contains ``instance.tables`` and ``instance.relationships``.
+
+        Side Effects:
+            - ``SingleTableMetadata.load_from_dict`` has been called.
+        """
+        # Setup
+        multitable_metadata = {
+            'tables': {
+                'accounts': {
+                    1: {'sdtype': 'numerical'},
+                    2: {'sdtype': 'numerical'},
+                    'amount': {'sdtype': 'numerical'},
+                    'start_date': {'sdtype': 'datetime'},
+                    'owner': {'sdtype': 'id'},
+                },
+                'branches': {
+                    1: {'sdtype': 'numerical'},
+                    'name': {'sdtype': 'id'},
+                }
+            },
+            'relationships': [
+                {
+                    'parent_table_name': 'accounts',
+                    'parent_primary_key': 1,
+                    'child_table_name': 'branches',
+                    'child_foreign_key': 1,
+                }
+            ]
+        }
+
+        single_table_accounts = {
+            '1': {'sdtype': 'numerical'},
+            '2': {'sdtype': 'numerical'},
+            'amount': {'sdtype': 'numerical'},
+            'start_date': {'sdtype': 'datetime'},
+            'owner': {'sdtype': 'id'},
+        }
+        single_table_branches = {
+            '1': {'sdtype': 'numerical'},
+            'name': {'sdtype': 'id'},
+        }
+        mock_singletablemetadata.load_from_dict.side_effect = [
+            single_table_accounts,
+            single_table_branches
+        ]
+
+        # Run
+        instance = MultiTableMetadata.load_from_dict(multitable_metadata)
+
+        # Assert
+        assert instance.tables == {
+            'accounts': single_table_accounts,
+            'branches': single_table_branches
+        }
+
+        assert instance.relationships == [
+            {
+                'parent_table_name': 'accounts',
+                'parent_primary_key': '1',
+                'child_table_name': 'branches',
+                'child_foreign_key': '1',
+            }
+        ]
+
     @patch('sdv.metadata.multi_table.json')
     def test___repr__(self, mock_json):
         """Test that the ``__repr__`` method.
tests/unit/metadata/test_single_table.py (28 additions, 0 deletions)

@@ -2695,6 +2695,34 @@ def test_load_from_dict(self):
         assert instance.sequence_index is None
         assert instance._version == 'SINGLE_TABLE_V1'

+    def test_load_from_dict_integer(self):
+        """Test that ``load_from_dict`` returns a instance with the ``dict`` updated objects.
+
+        If the metadata dict contains columns with integers for certain reasons
+        (e.g. due to missing column names from CSV) make sure they are correctly typed
+        to strings to ensure metadata is parsed properly.
+        """
+        # Setup
+        my_metadata = {
+            'columns': {1: 'value'},
+            'primary_key': 'pk',
+            'alternate_keys': [],
+            'sequence_key': None,
+            'sequence_index': None,
+            'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1'
+        }
+
+        # Run
+        instance = SingleTableMetadata.load_from_dict(my_metadata)
+
+        # Assert
+        assert instance.columns == {'1': 'value'}
+        assert instance.primary_key == 'pk'
+        assert instance.sequence_key is None
+        assert instance.alternate_keys == []
+        assert instance.sequence_index is None
+        assert instance._version == 'SINGLE_TABLE_V1'
+
     @patch('sdv.metadata.utils.Path')
     def test_load_from_json_path_does_not_exist(self, mock_path):
         """Test the ``load_from_json`` method.