From 8de3f32687b43bd6c692ff5e5ae9cde20d6cee71 Mon Sep 17 00:00:00 2001 From: rwedge Date: Tue, 18 Jun 2024 10:47:40 -0400 Subject: [PATCH] remove no-null fk validation check; update tests --- sdv/multi_table/base.py | 2 -- tests/integration/multi_table/test_hma.py | 13 ++----------- tests/unit/multi_table/test_base.py | 4 +--- 3 files changed, 3 insertions(+), 16 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index c789ee3c5..978ac8c0a 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -14,7 +14,6 @@ from sdv import version from sdv._utils import ( - _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id, @@ -448,7 +447,6 @@ def fit(self, data): }) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - _validate_foreign_keys_not_null(self.metadata, data) self._check_metadata_updated() self._fitted = False processed_data = self.preprocess(data) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index a82049200..7f80f3ab3 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1314,7 +1314,7 @@ def test_metadata_updated_warning_detect(self): assert len(record) == 1 def test_null_foreign_keys(self): - """Test that the synthesizer crashes when there are null foreign keys.""" + """Test that the synthesizer does not crash when there are null foreign keys.""" # Setup metadata = MultiTableMetadata() metadata.add_table('parent_table') @@ -1370,16 +1370,7 @@ def test_null_foreign_keys(self): metadata.validate_data(data) # Run and Assert - err_msg = re.escape( - 'The data contains null values in foreign key columns. This feature is currently ' - 'unsupported. Please remove null values to fit the synthesizer.\n' - '\n' - 'Affected columns:\n' - "Table 'child_table1', column(s) ['fk']\n" - "Table 'child_table2', column(s) ['fk1', 'fk2']\n" - ) - with pytest.raises(SynthesizerInputError, match=err_msg): - synthesizer.fit(data) + synthesizer.fit(data) parametrization = [ diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 3c493606f..83a5ccca3 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -981,8 +981,7 @@ def test_fit_processed_data_raises_version_error(self): instance._check_metadata_updated.assert_not_called() @patch('sdv.multi_table.base.datetime') - @patch('sdv.multi_table.base._validate_foreign_keys_not_null') - def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): + def test_fit(self, mock_datetime, caplog): """Test that it calls the appropriate methods.""" # Setup mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' @@ -1002,7 +1001,6 @@ def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): BaseMultiTableSynthesizer.fit(instance, data) # Assert - mock_validate_foreign_keys_not_null.assert_called_once_with(instance.metadata, data) instance.preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) instance._check_metadata_updated.assert_called_once()