Skip to content

Commit

Permalink
Debug
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed May 3, 2024
1 parent 79b6853 commit 099dad9
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 19 deletions.
12 changes: 11 additions & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True,
enforce_min_max_values=self.enforce_min_max_values,
locales=self.locales,
)
self._original_columns = pd.Index([])
self._fitted = False
self._random_state_set = False
self._update_default_transformers()
Expand Down Expand Up @@ -362,6 +363,11 @@ def get_info(self):
return info

def _preprocess(self, data):
    """Validate the data, fit the data processor on it and return the transformed data.

    Args:
        data:
            The raw table to process. It is passed unchanged to ``self.validate``,
            ``self._data_processor.fit`` and ``self._data_processor.transform``,
            in that order.

    Returns:
        Whatever ``self._data_processor.transform(data)`` returns.
    """
    # NOTE(review): an earlier draft tried to remember integer column names here
    # (``self._original_columns = data.columns``) before casting them to ``str``.
    # ``fit`` already casts ``data.columns`` to ``str`` before calling this method,
    # so that check could never match; it was removed instead of being kept as
    # commented-out code. From what is visible, nothing else assigns
    # ``self._original_columns`` -- confirm the restore logic in ``fit``/``sample``.
    self.validate(data)
    self._data_processor.fit(data)
    return self._data_processor.transform(data)
Expand Down Expand Up @@ -448,14 +454,15 @@ def fit(self, data):
len(data.columns),
self._synthesizer_id,
)
data.columns = data.columns.astype(str)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
self._check_metadata_updated()
self._fitted = False
self._data_processor.reset_sampling()
self._random_state_set = False
processed_data = self._preprocess(data)
self.fit_processed_data(processed_data)
if not self._original_columns.empty:
data.columns = self._original_columns

def save(self, filepath):
"""Save this model instance to the given path using cloudpickle.
Expand Down Expand Up @@ -884,6 +891,9 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
show_progress_bar=show_progress_bar
)

if not self._original_columns.empty:
sampled_data.columns = self._original_columns

SYNTHESIZER_LOGGER.info(
'\nSample:\n'
' Timestamp: %s\n'
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,9 @@ def test_prepare_for_fitting(self):
'degree_perc': FloatFormatter
}
for column_name, transformer_class in expected_transformers.items():
print(f'Transformer Class: {transformer_class}, {column_name}={field_transformers[column_name]}')
if transformer_class is not None:
print("Before Check")
assert isinstance(field_transformers[column_name], transformer_class)
else:
assert field_transformers[column_name] is None
Expand Down
5 changes: 4 additions & 1 deletion tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,7 +1741,7 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id):


def test_fit_and_sample_numerical_col_names():
"""Test fit and sampling when column names are integers"""
"""Test fitting/sampling when column names are integers"""
# Setup data
num_rows = 50
num_cols = 10
Expand Down Expand Up @@ -1780,6 +1780,9 @@ def test_fit_and_sample_numerical_col_names():
synth.fit(data)
first_sample = synth.sample()
second_sample = synth.sample()
assert first_sample.columns.tolist() == data.columns.tolist()
assert second_sample.columns.tolist() == data.columns.tolist()

# Assert
with pytest.raises(AssertionError):
pd.testing.assert_frame_equal(first_sample['0'], second_sample['0'])
37 changes: 20 additions & 17 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,8 +857,16 @@ def test_synthesizer_logger(mock_datetime, mock_generate_id):
)


def test_fit_and_sample_numerical_col_names():
"""Test fit and sampling when column names are integers"""
# Parametrization entries for the single-table synthesizers under test; each
# entry's id is the synthesizer class's own name.
SYNTHESIZERS_CLASSES = [
    pytest.param(synthesizer, id=synthesizer.__name__)
    for synthesizer in (
        CTGANSynthesizer,
        TVAESynthesizer,
        GaussianCopulaSynthesizer,
        CopulaGANSynthesizer,
    )
]

@pytest.mark.parametrize('synthesizer_class', SYNTHESIZERS_CLASSES)
def test_fit_and_sample_numerical_col_names(synthesizer_class):
"""Test fitting/sampling when column names are integers"""
# Setup
num_rows = 50
num_cols = 10
Expand All @@ -873,19 +881,14 @@ def test_fit_and_sample_numerical_col_names():
metadata = SingleTableMetadata.load_from_dict(metadata_dict)

# Run
synth = synthesizer_class(metadata)
synth.fit(data)
sample_1 = synth.sample(10)
sample_2 = synth.sample(10)

synthesizers = [
CTGANSynthesizer,
TVAESynthesizer,
GaussianCopulaSynthesizer,
CopulaGANSynthesizer
]
for synthesizer_class in synthesizers:
synth = synthesizer_class(metadata)
synth.fit(data)
sample_1 = synth.sample(10)
sample_2 = synth.sample(10)

# Assert
with pytest.raises(AssertionError):
pd.testing.assert_frame_equal(sample_1, sample_2)
assert sample_1.columns.tolist() == data.columns.tolist()
assert sample_2.columns.tolist() == data.columns.tolist()

# Assert
with pytest.raises(AssertionError):
pd.testing.assert_frame_equal(sample_1, sample_2)

0 comments on commit 099dad9

Please sign in to comment.