diff --git a/sdv/utils/poc.py b/sdv/utils/poc.py index 682895bad..360fcc164 100644 --- a/sdv/utils/poc.py +++ b/sdv/utils/poc.py @@ -16,14 +16,14 @@ from sdv.utils.utils import drop_unknown_references as utils_drop_unknown_references -def drop_unknown_references(data, metadata): +def drop_unknown_references(data, metadata, drop_missing_values=False, verbose=True): """Wrap the drop_unknown_references function from the utils module.""" warnings.warn( "Please access the 'drop_unknown_references' function directly from the sdv.utils module" 'instead of sdv.utils.poc.', FutureWarning, ) - return utils_drop_unknown_references(data, metadata) + return utils_drop_unknown_references(data, metadata, drop_missing_values, verbose) def simplify_schema(data, metadata, verbose=True): diff --git a/sdv/utils/utils.py b/sdv/utils/utils.py index f6e5db7c0..2c4b6b6ae 100644 --- a/sdv/utils/utils.py +++ b/sdv/utils/utils.py @@ -10,7 +10,7 @@ from sdv.multi_table.utils import _drop_rows -def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True): +def drop_unknown_references(data, metadata, drop_missing_values=False, verbose=True): """Drop rows with unknown foreign keys. Args: @@ -22,7 +22,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr drop_missing_values (bool): Boolean describing whether or not to also drop foreign keys with missing values If True, drop rows with missing values in the foreign keys. - Defaults to True. + Defaults to False. verbose (bool): If True, print information about the rows that are dropped. Defaults to True. diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py index 5139ab0c1..0405703a8 100644 --- a/tests/integration/utils/test_utils.py +++ b/tests/integration/utils/test_utils.py @@ -110,7 +110,7 @@ def test_drop_unknown_references_drop_missing_values(metadata, data, capsys): data['child'].loc[4, 'parent_id'] = np.nan # Run - cleaned_data = drop_unknown_references(data, metadata) + cleaned_data = drop_unknown_references(data, metadata, drop_missing_values=True) metadata.validate_data(cleaned_data) captured = capsys.readouterr() diff --git a/tests/unit/utils/test_poc.py b/tests/unit/utils/test_poc.py index bbd9723c3..c8873e2c9 100644 --- a/tests/unit/utils/test_poc.py +++ b/tests/unit/utils/test_poc.py @@ -17,6 +17,8 @@ def test_drop_unknown_references(mock_drop_unknown_references): # Setup data = Mock() metadata = Mock() + drop_missing_values = Mock() + verbose = Mock() expected_message = re.escape( "Please access the 'drop_unknown_references' function directly from the sdv.utils module" 'instead of sdv.utils.poc.' @@ -24,10 +26,12 @@ def test_drop_unknown_references(mock_drop_unknown_references): # Run with pytest.warns(FutureWarning, match=expected_message): - drop_unknown_references(data, metadata) + drop_unknown_references(data, metadata, drop_missing_values, verbose) # Assert - mock_drop_unknown_references.assert_called_once_with(data, metadata) + mock_drop_unknown_references.assert_called_once_with( + data, metadata, drop_missing_values, verbose + ) @patch('sdv.utils.poc._get_total_estimated_columns') diff --git a/tests/unit/utils/test_utils.py b/tests/unit/utils/test_utils.py index 7aad95693..a3a2d810c 100644 --- a/tests/unit/utils/test_utils.py +++ b/tests/unit/utils/test_utils.py @@ -87,7 +87,7 @@ def _drop_rows(data, metadata, drop_missing_values): } metadata.validate.assert_called_once() metadata.validate_data.assert_called_once_with(data) - mock_drop_rows.assert_called_once_with(result, metadata, True) + mock_drop_rows.assert_called_once_with(result, metadata, False) for table_name, table in result.items(): pd.testing.assert_frame_equal(table, expected_result[table_name]) @@ -189,7 +189,7 @@ def test_drop_unknown_references_with_nan(mock_validate_foreign_keys, mock_get_r mock_get_rows_to_drop.return_value = defaultdict(set, {'child': {4}, 'grandchild': {0, 3, 4}}) # Run - result = drop_unknown_references(data, metadata, verbose=False) + result = drop_unknown_references(data, metadata, drop_missing_values=True, verbose=False) # Assert metadata.validate.assert_called_once()