Skip to content

Commit

Permalink
Remove input error on null foreign keys from BaseMultiTableSynthesize…
Browse files Browse the repository at this point in the history
…r.fit (#2077)
  • Loading branch information
rwedge committed Aug 20, 2024
1 parent 57935f5 commit b07d2e1
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 16 deletions.
2 changes: 0 additions & 2 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from sdv import version
from sdv._utils import (
_validate_foreign_keys_not_null,
check_sdv_versions_and_warn,
check_synthesizer_version,
generate_synthesizer_id,
Expand Down Expand Up @@ -457,7 +456,6 @@ def fit(self, data):
})

check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
_validate_foreign_keys_not_null(self.metadata, data)
self._check_metadata_updated()
self._fitted = False
processed_data = self.preprocess(data)
Expand Down
13 changes: 2 additions & 11 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,7 @@ def test_metadata_updated_warning_detect(self):
assert len(record) == 1

def test_null_foreign_keys(self):
"""Test that the synthesizer crashes when there are null foreign keys."""
"""Test that the synthesizer does not crash when there are null foreign keys."""
# Setup
metadata = MultiTableMetadata()
metadata.add_table('parent_table')
Expand Down Expand Up @@ -1372,16 +1372,7 @@ def test_null_foreign_keys(self):
metadata.validate_data(data)

# Run and Assert
err_msg = re.escape(
'The data contains null values in foreign key columns. This feature is currently '
'unsupported. Please remove null values to fit the synthesizer.\n'
'\n'
'Affected columns:\n'
"Table 'child_table1', column(s) ['fk']\n"
"Table 'child_table2', column(s) ['fk1', 'fk2']\n"
)
with pytest.raises(SynthesizerInputError, match=err_msg):
synthesizer.fit(data)
synthesizer.fit(data)

def test_sampling_with_unknown_sdtype_numerical_column(self):
"""Test that if a numerical column is detected as unknown in the metadata,
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,8 +1067,7 @@ def test_fit_processed_data_raises_version_error(self):
instance._check_metadata_updated.assert_not_called()

@patch('sdv.multi_table.base.datetime')
@patch('sdv.multi_table.base._validate_foreign_keys_not_null')
def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog):
def test_fit(self, mock_datetime, caplog):
"""Test that it calls the appropriate methods."""
# Setup
mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
Expand All @@ -1088,7 +1087,6 @@ def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog):
BaseMultiTableSynthesizer.fit(instance, data)

# Assert
mock_validate_foreign_keys_not_null.assert_called_once_with(instance.metadata, data)
instance.preprocess.assert_called_once_with(data)
instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value)
instance._check_metadata_updated.assert_called_once()
Expand Down

0 comments on commit b07d2e1

Please sign in to comment.