Skip to content

Commit

Permalink
Fix test
Browse files (browse the repository at this point in the history)
  • Loading branch information
fealho committed Dec 16, 2024
1 parent 932ac68 commit ee17238
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 3 deletions.
10 changes: 7 additions & 3 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,13 +1649,13 @@ def test_large_integer_ids_overflow_three_tables(self):
'col_0': [1, 2, 3],
})
table_1 = pd.DataFrame({
- 'col_1': [9999999999999999990, 9999999999999999991, 9999999999999999992], # len 19
+ 'col_1': [1, 2, 3],
'col_3': [7, 8, 9],
'col_2': [4, 5, 6],
'col_0': [1, 2, 2],
})
table_2 = pd.DataFrame({
- 'col_A': [9999999999999999990, 9999999999999999990, 9999999999999999991], # len 19
+ 'col_A': [1, 2, 3],
'col_B': ['d', 'e', 'f'],
'col_C': ['g', 'h', 'i'],
})
Expand Down Expand Up @@ -1725,6 +1725,10 @@ def test_large_integer_ids_overflow_three_tables(self):
)

# Check relationships are preserved
child_fks = set(synthetic_data['table_1']['col_0'])
parent_pks = set(synthetic_data['table_0']['col_0'])
assert child_fks.issubset(parent_pks), 'Foreign key constraint violated'

child_fks = set(synthetic_data['table_2']['col_A'])
parent_pks = set(synthetic_data['table_1']['col_1'])
assert child_fks.issubset(parent_pks), 'Foreign key constraint violated'
Expand All @@ -1737,7 +1741,7 @@ def test_large_integer_ids_overflow_three_tables(self):
'please check your input data and metadata settings.'
)
assert str(captured_warnings[1].message) == (
- "The real data in 'table_1' and column 'col_1' was stored as 'uint64' but the "
+ "The real data in 'table_1' and column 'col_1' was stored as 'int64' but the "
'synthetic data overflowed when casting back to this type. If this is a problem, '
'please check your input data and metadata settings.'
)
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/data_processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,6 +2357,31 @@ def test_reverse_transform(self):
})
pd.testing.assert_frame_equal(reverse_transformed, expected_output)

def test_reverse_transform_overflow(self):
    """Test that ``reverse_transform`` warns when values overflow the stored dtype.

    Setup:
        - A ``DataProcessor`` whose ``_dtypes`` records ``'col'`` as ``'int64'``.
        - A mocked hyper transformer whose ``reverse_transform_subset`` returns
          integers around 1e20, which do not fit in ``int64``.

    Side Effects:
        - A ``UserWarning`` carrying the overflow message for ``'table_name'``
          and ``'col'`` is emitted.
    """
    # Local import: the module-level import block is not visible from here.
    import re

    # Setup
    data = pd.DataFrame({
        'col': [99999999999999999990, 99999999999999999991, 99999999999999999992]
    })
    dp = DataProcessor(SingleTableMetadata())
    dp._dtypes = {'col': 'int64'}
    dp.metadata = Mock()
    dp.metadata.columns = {'col': None}
    dp.fitted = True
    dp._hyper_transformer = Mock()
    dp._hyper_transformer.reverse_transform_subset.return_value = data
    dp._hyper_transformer._output_columns = ['col']
    dp.table_name = 'table_name'

    # Run
    warn_msg = (
        "The real data in 'table_name' and column 'col' was stored as 'int64' but the "
        'synthetic data overflowed when casting back to this type. If this is a problem, '
        'please check your input data and metadata settings.'
    )
    # ``match`` is treated as a regular expression; escape the message so that
    # its '.' characters are matched literally instead of as wildcards.
    with pytest.warns(UserWarning, match=re.escape(warn_msg)):
        dp.reverse_transform(data)

@patch('sdv.data_processing.data_processor.LOGGER')
def test_reverse_transform_hyper_transformer_errors(self, log_mock):
"""Test the ``reverse_transform`` method.
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/sampling/test_independent_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,32 @@ def test__finalize_id_being_string(self, mock_logger):
)
assert mock_logger.info.call_args_list == [call(message_users), call(message_sessions)]

@patch('sdv.sampling.independent_sampler.LOGGER')
def test__finalize_overflow(self, mock_logger):
    """Test that ``_finalize`` reports an overflow through ``LOGGER.debug``.

    The sampled ``id`` values (around 1e20) do not fit in the ``int64`` dtype
    recorded by the table's data processor, so exactly one debug message
    describing the overflow is expected.
    """
    # Setup
    overflowing_ids = [99999999999999999990, 99999999999999999991, 99999999999999999992]
    sampled_data = {'table': pd.DataFrame({'id': overflowing_ids})}

    table_synthesizer = Mock()
    table_synthesizer._data_processor._dtypes = {'id': 'int64'}
    instance = Mock()
    instance._table_synthesizers = {'table': table_synthesizer}

    # Run
    BaseIndependentSampler._finalize(instance, sampled_data)

    # Assert
    expected_message = (
        "The real data in 'table' and column 'id' was stored as "
        "'int64' but the synthetic data overflowed when casting back to "
        'this type. If this is a problem, please check your input data '
        'and metadata settings.'
    )
    expected_calls = [call(expected_message)]
    assert mock_logger.debug.call_args_list == expected_calls

def test__sample(self):
"""Test that the ``_sample_table`` is called for root tables."""
# Setup
Expand Down

0 comments on commit ee17238

Please sign in to comment.