diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 006f6636e..5565428f9 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -14,7 +14,6 @@ from sdv import version from sdv._utils import ( - _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, generate_synthesizer_id, @@ -457,7 +456,6 @@ def fit(self, data): }) check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt) - _validate_foreign_keys_not_null(self.metadata, data) self._check_metadata_updated() self._fitted = False processed_data = self.preprocess(data) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index a774f51c4..8482c5ed4 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1316,7 +1316,7 @@ def test_metadata_updated_warning_detect(self): assert len(record) == 1 def test_null_foreign_keys(self): - """Test that the synthesizer crashes when there are null foreign keys.""" + """Test that the synthesizer does not crash when there are null foreign keys.""" # Setup metadata = MultiTableMetadata() metadata.add_table('parent_table') @@ -1372,16 +1372,7 @@ def test_null_foreign_keys(self): metadata.validate_data(data) # Run and Assert - err_msg = re.escape( - 'The data contains null values in foreign key columns. This feature is currently ' - 'unsupported. Please remove null values to fit the synthesizer.\n' - '\n' - 'Affected columns:\n' - "Table 'child_table1', column(s) ['fk']\n" - "Table 'child_table2', column(s) ['fk1', 'fk2']\n" - ) - with pytest.raises(SynthesizerInputError, match=err_msg): - synthesizer.fit(data) + synthesizer.fit(data) def test_sampling_with_unknown_sdtype_numerical_column(self): """Test that if a numerical column is detected as unknown in the metadata, diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 63e363f1b..f732fe827 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -1067,8 +1067,7 @@ def test_fit_processed_data_raises_version_error(self): instance._check_metadata_updated.assert_not_called() @patch('sdv.multi_table.base.datetime') - @patch('sdv.multi_table.base._validate_foreign_keys_not_null') - def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): + def test_fit(self, mock_datetime, caplog): """Test that it calls the appropriate methods.""" # Setup mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' @@ -1088,7 +1087,6 @@ def test_fit(self, mock_validate_foreign_keys_not_null, mock_datetime, caplog): BaseMultiTableSynthesizer.fit(instance, data) # Assert - mock_validate_foreign_keys_not_null.assert_called_once_with(instance.metadata, data) instance.preprocess.assert_called_once_with(data) instance.fit_processed_data.assert_called_once_with(instance.preprocess.return_value) instance._check_metadata_updated.assert_called_once()