Skip to content

Commit

Permalink
Add integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Dec 13, 2024
1 parent 15ed1a1 commit 932ac68
Showing 1 changed file with 258 additions and 1 deletion.
259 changes: 258 additions & 1 deletion tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,7 @@ def test_large_integer_ids_overflow(self):
}

# Run
synthesizer = HMASynthesizer(metadata, verbose=False)
synthesizer = HMASynthesizer(metadata)
synthesizer.fit(data)
with warnings.catch_warnings(record=True) as captured_warnings:
synthetic_data = synthesizer.sample()
Expand Down Expand Up @@ -1635,6 +1635,263 @@ def test_large_integer_ids_overflow(self):
'please check your input data and metadata settings.'
)

def test_large_integer_ids_overflow_three_tables(self):
"""Test that it overflows.
When the real data primary key can fit in int64, ie has less than 19 digits,
but the regex_format specifies data that can't fit in int64, ie over 20 digits,
the synthetic data will raise an overflow warning and it will stay as object dtype.
This should raise two warnings, one for each parent table with ids that overflow.
"""
# Setup
table_0 = pd.DataFrame({
'col_0': [1, 2, 3],
})
table_1 = pd.DataFrame({
'col_1': [9999999999999999990, 9999999999999999991, 9999999999999999992], # len 19
'col_3': [7, 8, 9],
'col_2': [4, 5, 6],
'col_0': [1, 2, 2],
})
table_2 = pd.DataFrame({
'col_A': [9999999999999999990, 9999999999999999990, 9999999999999999991], # len 19
'col_B': ['d', 'e', 'f'],
'col_C': ['g', 'h', 'i'],
})
metadata = Metadata.load_from_dict({
'tables': {
'table_0': {
'columns': {
'col_0': {'sdtype': 'id', 'regex_format': '[1-9]{20}'},
},
'primary_key': 'col_0',
},
'table_1': {
'columns': {
'col_1': {'sdtype': 'id', 'regex_format': '[1-9]{20}'},
'col_2': {'sdtype': 'numerical'},
'col_3': {'sdtype': 'numerical'},
'col_0': {'sdtype': 'id', 'regex_format': '[1-9]{20}'},
},
'primary_key': 'col_1',
},
'table_2': {
'columns': {
'col_A': {'sdtype': 'id', 'regex_format': '[1-9]{20}'},
'col_B': {'sdtype': 'categorical'},
'col_C': {'sdtype': 'categorical'},
},
},
},
'relationships': [
{
'parent_table_name': 'table_1',
'child_table_name': 'table_2',
'parent_primary_key': 'col_1',
'child_foreign_key': 'col_A',
},
{
'parent_table_name': 'table_0',
'child_table_name': 'table_1',
'parent_primary_key': 'col_0',
'child_foreign_key': 'col_0',
},
],
})
data = {
'table_0': table_0,
'table_1': table_1,
'table_2': table_2,
}

# Run
synthesizer = HMASynthesizer(metadata)
synthesizer.fit(data)
with warnings.catch_warnings(record=True) as captured_warnings:
synthetic_data = synthesizer.sample()

# Assert
# Check that IDs match the regex pattern
for table_name, table in synthetic_data.items():
for col in table.columns:
if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
values = table[col].astype(str)
assert all(len(str(v)) == 20 for v in values), (
f'ID length mismatch in {table_name}.{col}'
)
assert all(v.isdigit() for v in values), (
f'Non-digit characters in {table_name}.{col}'
)

# Check relationships are preserved
child_fks = set(synthetic_data['table_2']['col_A'])
parent_pks = set(synthetic_data['table_1']['col_1'])
assert child_fks.issubset(parent_pks), 'Foreign key constraint violated'

# Check that a warning is raised
assert len(captured_warnings) == 2
assert str(captured_warnings[0].message) == (
"The real data in 'table_0' and column 'col_0' was stored as 'int64' but the "
'synthetic data overflowed when casting back to this type. If this is a problem, '
'please check your input data and metadata settings.'
)
assert str(captured_warnings[1].message) == (
"The real data in 'table_1' and column 'col_1' was stored as 'uint64' but the "
'synthetic data overflowed when casting back to this type. If this is a problem, '
'please check your input data and metadata settings.'
)

def test_ids_that_dont_fit_in_int64(self):
"""Test it when both real and synthetic data don't fit in int64."""
# Setup
table_1 = pd.DataFrame({
'col_1': [99999999999999999990, 99999999999999999991, 99999999999999999992], # len 20
'col_3': [7, 8, 9],
'col_2': [4, 5, 6],
})
table_2 = pd.DataFrame({
'col_A': [99999999999999999990, 99999999999999999990, 99999999999999999991], # len 20
'col_B': ['d', 'e', 'f'],
'col_C': ['g', 'h', 'i'],
})
metadata = Metadata.load_from_dict({
'tables': {
'table_1': {
'columns': {
'col_1': {'sdtype': 'id', 'regex_format': '[1-9]{20}'},
'col_2': {'sdtype': 'numerical'},
'col_3': {'sdtype': 'numerical'},
},
'primary_key': 'col_1',
},
'table_2': {
'columns': {
'col_A': {'sdtype': 'id', 'regex_format': '[1-9]{20}'},
'col_B': {'sdtype': 'categorical'},
'col_C': {'sdtype': 'categorical'},
},
},
},
'relationships': [
{
'parent_table_name': 'table_1',
'child_table_name': 'table_2',
'parent_primary_key': 'col_1',
'child_foreign_key': 'col_A',
}
],
})
data = {
'table_1': table_1,
'table_2': table_2,
}
synthesizer = HMASynthesizer(metadata)
synthesizer.fit(data)
with warnings.catch_warnings(record=True) as captured_warnings:
synthetic_data = synthesizer.sample()

# Assert
# Check that IDs match the regex pattern
for table_name, table in synthetic_data.items():
for col in table.columns:
if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
values = table[col].astype(str)
assert all(len(str(v)) == 20 for v in values), (
f'ID length mismatch in {table_name}.{col}'
)
assert all(v.isdigit() for v in values), (
f'Non-digit characters in {table_name}.{col}'
)

# Check relationships are preserved
child_fks = set(synthetic_data['table_2']['col_A'])
parent_pks = set(synthetic_data['table_1']['col_1'])
assert child_fks.issubset(parent_pks), 'Foreign key constraint violated'

# No warnings should be raised
assert len(captured_warnings) == 0

# Check that the diagnostic report is 1.0
report = DiagnosticReport()
report.generate(data, synthetic_data, metadata.to_dict(), verbose=False)
assert report.get_score() == 1.0

def test_large_real_ids_small_synthetic_ids(self):
"""Test it when real data has more digits than synthetic data."""
# Setup
table_1 = pd.DataFrame({
'col_1': [99999999999999999990, 99999999999999999991, 99999999999999999992], # len 20
'col_3': [7, 8, 9],
'col_2': [4, 5, 6],
})
table_2 = pd.DataFrame({
'col_A': [99999999999999999990, 99999999999999999990, 99999999999999999991], # len 20
'col_B': ['d', 'e', 'f'],
'col_C': ['g', 'h', 'i'],
})
metadata = Metadata.load_from_dict({
'tables': {
'table_1': {
'columns': {
'col_1': {'sdtype': 'id', 'regex_format': '[1-9]{1}'},
'col_2': {'sdtype': 'numerical'},
'col_3': {'sdtype': 'numerical'},
},
'primary_key': 'col_1',
},
'table_2': {
'columns': {
'col_A': {'sdtype': 'id', 'regex_format': '[1-9]{1}'},
'col_B': {'sdtype': 'categorical'},
'col_C': {'sdtype': 'categorical'},
},
},
},
'relationships': [
{
'parent_table_name': 'table_1',
'child_table_name': 'table_2',
'parent_primary_key': 'col_1',
'child_foreign_key': 'col_A',
}
],
})
data = {
'table_1': table_1,
'table_2': table_2,
}
synthesizer = HMASynthesizer(metadata)
synthesizer.fit(data)
with warnings.catch_warnings(record=True) as captured_warnings:
synthetic_data = synthesizer.sample()

# Assert
# Check that IDs match the regex pattern
for table_name, table in synthetic_data.items():
for col in table.columns:
if metadata.tables[table_name].columns[col].get('sdtype') == 'id':
values = table[col].astype(str)
assert all(len(str(v)) == 1 for v in values), (
f'ID length mismatch in {table_name}.{col}'
)
assert all(v.isdigit() for v in values), (
f'Non-digit characters in {table_name}.{col}'
)

# Check relationships are preserved
child_fks = set(synthetic_data['table_2']['col_A'])
parent_pks = set(synthetic_data['table_1']['col_1'])
assert child_fks.issubset(parent_pks), 'Foreign key constraint violated'

# No warnings should be raised
assert len(captured_warnings) == 0

# Check that the diagnostic report is 1.0
report = DiagnosticReport()
report.generate(data, synthetic_data, metadata.to_dict(), verbose=False)
assert report.get_score() == 1.0


@pytest.mark.parametrize('num_rows', [(10), (1000)])
def test_hma_0_1_child(num_rows):
Expand Down

0 comments on commit 932ac68

Please sign in to comment.