Skip to content

Commit

Permalink
WIP trying to convert everything to string before processing
Browse files: browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed May 1, 2024
1 parent 060bae9 commit 9fa8fac
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 7 deletions.
1 change: 1 addition & 0 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,7 @@ def reverse_transform(self, data, reset_keys=False):
for column in self.metadata.columns.keys() - set(sampled_columns + self._keys)
if self._hyper_transformer.field_transformers.get(column)
]
print(missing_columns)
if missing_columns and num_rows:
anonymized_data = self._hyper_transformer.create_anonymized_columns(
num_rows=num_rows,
Expand Down
5 changes: 3 additions & 2 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,10 +1025,11 @@ def _set_metadata_dict(self, metadata):
Python dictionary representing a ``MultiTableMetadata`` object.
"""
for table_name, table_dict in metadata.get('tables', {}).items():
self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict)
self.tables[str(table_name)] = SingleTableMetadata.load_from_dict(table_dict)

for relationship in metadata.get('relationships', []):
self.relationships.append(relationship)
type_safe_relationships = {key: str(value) if not isinstance(value, str) else value for key, value in relationship.items()}
self.relationships.append(type_safe_relationships)

@classmethod
def load_from_dict(cls, metadata_dict):
Expand Down
3 changes: 2 additions & 1 deletion sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,8 @@ def load_from_dict(cls, metadata_dict):
for key in instance._KEYS:
value = deepcopy(metadata_dict.get(key))
if value:
setattr(instance, f'{key}', value)
type_safe_value = {str(key) if not isinstance(key, str) else key: value for key, value in value.items()}
setattr(instance, f'{key}', type_safe_value)

return instance

Expand Down
11 changes: 8 additions & 3 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,14 @@ def fit(self, data):
Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format
(before any transformations).
"""
type_safe_data = {str(key) if not isinstance(key, str) else key: value for key, value in data.items()}
total_rows = 0
total_columns = 0
for table in data.values():
for table, dataframe in type_safe_data.items():
dataframe.columns = dataframe.columns.astype(str)
type_safe_data[table] = dataframe

for table in type_safe_data.values():
total_rows += len(table)
total_columns += len(table.columns)

Expand All @@ -440,10 +445,10 @@ def fit(self, data):
self._synthesizer_id,
)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
_validate_foreign_keys_not_null(self.metadata, data)
_validate_foreign_keys_not_null(self.metadata, type_safe_data)
self._check_metadata_updated()
self._fitted = False
processed_data = self.preprocess(data)
processed_data = self.preprocess(type_safe_data)
self._print(text='\n', end='')
self.fit_processed_data(processed_data)

Expand Down
2 changes: 1 addition & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def fit(self, data):
len(data.columns),
self._synthesizer_id,
)

data.columns = data.columns.astype(str)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
self._check_metadata_updated()
self._fitted = False
Expand Down

0 comments on commit 9fa8fac

Please sign in to comment.