Skip to content

Commit

Permalink
Add print message
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Nov 7, 2023
1 parent 55668dc commit fe6e4b2
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
25 changes: 25 additions & 0 deletions sdv/single_table/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,31 @@ def _estimate_num_columns(self, data):

return num_generated_columns

def _preprocess(self, data):
dict_generated_columns = self._estimate_num_columns(data)
if sum(dict_generated_columns.values()) > 1000:
header = {'Original Column Name ': 'Est # of Columns (CTGAN)'}
dict_generated_columns = {**header, **dict_generated_columns}
longest_column_name = len(max(dict_generated_columns, key=len))
cap = '<' + str(longest_column_name)
lines_to_print = []
for column, num_generated_columns in dict_generated_columns.items():
lines_to_print.append(f'{column:{cap}} {num_generated_columns}')

generated_columns_str = '\n'.join(lines_to_print)
print( # noqa: T001
'PerformanceAlert: Using the CTGANSynthesizer on this data is not recommended. '
'To model this data, CTGAN will generate a large number of columns.'
'\n\n'
f'{generated_columns_str}'
'\n\n'
'We recommend preprocessing discrete columns that can have many values, '
"using 'update_transformers'. Or you may drop columns that are not necessary "
'to model. (Exit this script using ctrl-C)'
)

return super()._preprocess(data)

def _fit(self, processed_data):
"""Fit the model to the table.
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,7 +2014,7 @@ def test_detect_table_from_csv(self, load_csv_mock, single_table_mock, log_mock)
should be created and call the ``detect_from_csv`` method.
Setup:
- Mock the ``SingleTableMetadata`` class and the print function.
- Mock the ``SingleTableMetadata`` class and the logger.
Assert:
- Table should be added to ``self.tables``.
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/single_table/test_ctgan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import re
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd

from sdv.metadata.single_table import SingleTableMetadata
from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer

Expand Down Expand Up @@ -128,6 +132,37 @@ def test_get_parameters(self):
'cuda': True,
}

def test_preprocessing_many_categories(self, capfd):
"""Test a message is printed during preprocess when a column has many categories."""
# Setup
metadata = SingleTableMetadata()
metadata.add_column('name_longer_than_Original_Column_Name', sdtype='numerical')
metadata.add_column('categorical', sdtype='categorical')
data = pd.DataFrame({
'name_longer_than_Original_Column_Name': np.random.rand(1_001),
'categorical': [f'cat_{i}' for i in range(1_001)],
})
instance = CTGANSynthesizer(metadata)

# Run
instance.auto_assign_transformers(data)
instance.preprocess(data)

# Assert
out, err = capfd.readouterr()
assert out == re.escape(
'PerformanceAlert: Using the CTGANSynthesizer on this data is not recommended. '
'To model this data, CTGAN will generate a large number of columns.'
''
'Original Column Name Est # of Columns (CTGAN)'
'name_longer_than_Original_Column_Name 11'
'categorical 1001'
''
'We recommend preprocessing discrete columns that can have many values, '
"using 'update_transformers'. Or you may drop columns that are not necessary "
'to model. (Exit this script using ctrl-C)'
)

@patch('sdv.single_table.ctgan.CTGAN')
@patch('sdv.single_table.ctgan.detect_discrete_columns')
def test__fit(self, mock_detect_discrete_columns, mock_ctgan):
Expand Down

0 comments on commit fe6e4b2

Please sign in to comment.