Skip to content

Commit

Permalink
Move function drop_unknown_references from poc to be directly under `utils` (#1969)
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo authored May 1, 2024
1 parent 2538814 commit b94cf94
Show file tree
Hide file tree
Showing 11 changed files with 604 additions and 501 deletions.
4 changes: 2 additions & 2 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,8 @@ def _validate_foreign_keys(self, data):

errors.append(
f"Error: foreign key column '{relation['child_foreign_key']}' contains "
f'unknown references: {message}. Please use the utility method'
" 'drop_unknown_references' to clean the data."
f'unknown references: {message}. Please use the method'
" 'drop_unknown_references' from sdv.utils to clean the data."
)

if errors:
Expand Down
2 changes: 1 addition & 1 deletion sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
except SynthesizerInputError:
warnings.warn(
'The data contains null values in foreign key columns. '
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc'
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils'
' to drop these rows before using ``get_random_subset``.'
)

Expand Down
2 changes: 1 addition & 1 deletion sdv/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Utils module."""

from sdv.utils.poc import drop_unknown_references
from sdv.utils.utils import drop_unknown_references

__all__ = (
'drop_unknown_references',
Expand Down
76 changes: 13 additions & 63 deletions sdv/utils/poc.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,22 @@
"""Utility functions."""
import sys
from copy import deepcopy
"""POC functions to use HMASynthesizer successfully."""
import warnings

import pandas as pd

from sdv._utils import _validate_foreign_keys_not_null
from sdv.errors import InvalidDataError, SynthesizerInputError
from sdv.errors import InvalidDataError
from sdv.metadata.errors import InvalidMetadataError
from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS
from sdv.multi_table.utils import (
_drop_rows, _get_total_estimated_columns, _print_simplified_schema_summary,
_print_subsample_summary, _simplify_data, _simplify_metadata, _subsample_data)


def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True):
"""Drop rows with unknown foreign keys.
_get_total_estimated_columns, _print_simplified_schema_summary, _print_subsample_summary,
_simplify_data, _simplify_metadata, _subsample_data)
from sdv.utils.utils import drop_unknown_references as utils_drop_unknown_references

Args:
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
Metadata of the datasets.
drop_missing_values (bool):
Boolean describing whether or not to also drop foreign keys with missing values
If True, drop rows with missing values in the foreign keys.
Defaults to True.
verbose (bool):
If True, print information about the rows that are dropped.
Defaults to True.

Returns:
dict:
Dictionary with the dataframes ensuring referential integrity.
"""
success_message = 'Success! All foreign keys have referential integrity.'
table_names = sorted(metadata.tables)
summary_table = pd.DataFrame({
'Table Name': table_names,
'# Rows (Original)': [len(data[table]) for table in table_names],
'# Invalid Rows': [0] * len(table_names),
'# Rows (New)': [len(data[table]) for table in table_names]
})
metadata.validate()
try:
metadata.validate_data(data)
if drop_missing_values:
_validate_foreign_keys_not_null(metadata, data)

if verbose:
sys.stdout.write(
'\n'.join([success_message, '', summary_table.to_string(index=False)])
)

return data
except (InvalidDataError, SynthesizerInputError):
result = deepcopy(data)
_drop_rows(result, metadata, drop_missing_values)
if verbose:
summary_table['# Invalid Rows'] = [
len(data[table]) - len(result[table]) for table in table_names
]
summary_table['# Rows (New)'] = [len(result[table]) for table in table_names]
sys.stdout.write('\n'.join([
success_message, '', summary_table.to_string(index=False)
]))

return result
def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True):
    """Deprecated alias for ``sdv.utils.drop_unknown_references``.

    Emits a ``FutureWarning`` pointing callers at the new location, then
    delegates to the canonical implementation.

    Args:
        data (dict):
            Dictionary that maps each table name (string) to the data for that
            table (pandas.DataFrame).
        metadata (MultiTableMetadata):
            Metadata of the datasets.
        drop_missing_values (bool):
            If True, also drop rows with missing values in the foreign keys.
            Defaults to True.
        verbose (bool):
            If True, print information about the rows that are dropped.
            Defaults to True.

    Returns:
        dict:
            Dictionary with the dataframes ensuring referential integrity.
    """
    warnings.warn(
        # Keep the trailing space before the implicit-concatenation break so the
        # message reads "...sdv.utils module instead of...", not "moduleinstead".
        "Please access the 'drop_unknown_references' function directly from the sdv.utils "
        'module instead of sdv.utils.poc.', FutureWarning
    )
    # Forward every argument: the pre-move function accepted
    # drop_missing_values/verbose, so existing callers of sdv.utils.poc
    # must keep working through this shim.
    return utils_drop_unknown_references(data, metadata, drop_missing_values, verbose)


def simplify_schema(data, metadata, verbose=True):
Expand Down
65 changes: 65 additions & 0 deletions sdv/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Utils module."""
import sys
from copy import deepcopy

import pandas as pd

from sdv._utils import _validate_foreign_keys_not_null
from sdv.errors import InvalidDataError, SynthesizerInputError
from sdv.multi_table.utils import _drop_rows


def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=True):
    """Drop rows with unknown foreign keys.

    Args:
        data (dict):
            Dictionary that maps each table name (string) to the data for that
            table (pandas.DataFrame).
        metadata (MultiTableMetadata):
            Metadata of the datasets.
        drop_missing_values (bool):
            Boolean describing whether or not to also drop foreign keys with missing values.
            If True, drop rows with missing values in the foreign keys.
            Defaults to True.
        verbose (bool):
            If True, print information about the rows that are dropped.
            Defaults to True.

    Returns:
        dict:
            Dictionary with the dataframes ensuring referential integrity.
    """
    success_message = 'Success! All foreign keys have referential integrity.'
    table_names = sorted(metadata.tables)
    # Build the summary up front assuming nothing is invalid; the invalid/new
    # counts are overwritten below only when rows actually get dropped.
    summary_table = pd.DataFrame({
        'Table Name': table_names,
        '# Rows (Original)': [len(data[table]) for table in table_names],
        '# Invalid Rows': [0] * len(table_names),
        '# Rows (New)': [len(data[table]) for table in table_names]
    })
    metadata.validate()
    try:
        metadata.validate_data(data)
        # Nulls in foreign keys only count as invalid when the caller asked
        # for them to be dropped.
        if drop_missing_values:
            _validate_foreign_keys_not_null(metadata, data)

        if verbose:
            sys.stdout.write(
                '\n'.join([success_message, '', summary_table.to_string(index=False)])
            )

        # Data is already clean: return it unchanged (no copy is made).
        return data
    except (InvalidDataError, SynthesizerInputError):
        # Validation failed: drop the offending rows on a deep copy so the
        # caller's input dataframes are left untouched.
        result = deepcopy(data)
        _drop_rows(result, metadata, drop_missing_values)
        if verbose:
            summary_table['# Invalid Rows'] = [
                len(data[table]) - len(result[table]) for table in table_names
            ]
            summary_table['# Rows (New)'] = [len(result[table]) for table in table_names]
            sys.stdout.write('\n'.join([
                success_message, '', summary_table.to_string(index=False)
            ]))

        return result
97 changes: 2 additions & 95 deletions tests/integration/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import pytest

from sdv.datasets.demo import download_demo
from sdv.errors import InvalidDataError
from sdv.metadata import MultiTableMetadata
from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS, HMASynthesizer
from sdv.multi_table.utils import _get_total_estimated_columns
from sdv.utils.poc import drop_unknown_references, get_random_subset, simplify_schema
from sdv.utils.poc import get_random_subset, simplify_schema


@pytest.fixture
Expand Down Expand Up @@ -65,98 +64,6 @@ def data():
}


def test_drop_unknown_references(metadata, data, capsys):
    """Test ``drop_unknown_references`` end to end."""
    # Run
    # NOTE(review): this expected message uses the old wording
    # ("utility method 'drop_unknown_references'"); the validator in
    # sdv/metadata/multi_table.py now says "method 'drop_unknown_references'
    # from sdv.utils" — confirm this test targets the matching version.
    expected_message = re.escape(
        'The provided data does not match the metadata:\n'
        'Relationships:\n'
        "Error: foreign key column 'parent_id' contains unknown references: (5)"
        ". Please use the utility method 'drop_unknown_references' to clean the data."
    )
    with pytest.raises(InvalidDataError, match=expected_message):
        metadata.validate_data(data)

    # Cleaning must leave the result valid under the same metadata.
    cleaned_data = drop_unknown_references(data, metadata)
    metadata.validate_data(cleaned_data)
    captured = capsys.readouterr()

    # Assert
    pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent'])
    # Only the row with the unknown parent_id (value 5) is dropped.
    pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4])
    assert len(cleaned_data['child']) == 4
    expected_output = (
        'Success! All foreign keys have referential integrity.\n\n'
        'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n'
        '     child                 5              1            4\n'
        '    parent                 5              0            5'
    )
    assert captured.out.strip() == expected_output
assert captured.out.strip() == expected_output


def test_drop_unknown_references_valid_data(metadata, data, capsys):
    """Test ``drop_unknown_references`` when data has referential integrity."""
    # Setup: repoint the dangling foreign key to an existing parent so the
    # input data is already valid before the call.
    data = deepcopy(data)
    data['child'].loc[4, 'parent_id'] = 2

    # Run
    result = drop_unknown_references(data, metadata)
    captured = capsys.readouterr()

    # Assert: nothing is dropped and the summary reports zero invalid rows.
    pd.testing.assert_frame_equal(result['parent'], data['parent'])
    pd.testing.assert_frame_equal(result['child'], data['child'])
    expected_message = (
        'Success! All foreign keys have referential integrity.\n\n'
        'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n'
        '     child                 5              0            5\n'
        '    parent                 5              0            5'
    )
    assert captured.out.strip() == expected_message


def test_drop_unknown_references_drop_missing_values(metadata, data, capsys):
    """Test ``drop_unknown_references`` when there is missing values in the foreign keys."""
    # Setup: a NaN foreign key counts as invalid under the default
    # drop_missing_values=True behavior.
    data = deepcopy(data)
    data['child'].loc[4, 'parent_id'] = np.nan

    # Run
    cleaned_data = drop_unknown_references(data, metadata)
    metadata.validate_data(cleaned_data)
    captured = capsys.readouterr()

    # Assert: the NaN-keyed row is dropped and reported as invalid.
    pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent'])
    pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4])
    assert len(cleaned_data['child']) == 4
    expected_output = (
        'Success! All foreign keys have referential integrity.\n\n'
        'Table Name # Rows (Original) # Invalid Rows # Rows (New)\n'
        '     child                 5              1            4\n'
        '    parent                 5              0            5'
    )
    assert captured.out.strip() == expected_output


def test_drop_unknown_references_not_drop_missing_values(metadata, data):
    """Test ``drop_unknown_references`` when the missing values in the foreign keys are kept."""
    # Setup: NaN the key of a row that would otherwise be valid.
    data['child'].loc[3, 'parent_id'] = np.nan

    # Run: with drop_missing_values=False only rows with *unknown* (non-null)
    # references are dropped; NaN keys are kept.
    cleaned_data = drop_unknown_references(
        data, metadata, drop_missing_values=False, verbose=False
    )

    # Assert: the NaN row survives, the unknown-reference row does not.
    pd.testing.assert_frame_equal(cleaned_data['parent'], data['parent'])
    pd.testing.assert_frame_equal(cleaned_data['child'], data['child'].iloc[:4])
    assert pd.isna(cleaned_data['child']['parent_id']).any()
    assert len(cleaned_data['child']) == 4


def test_simplify_schema(capsys):
"""Test ``simplify_schema`` end to end."""
# Setup
Expand Down Expand Up @@ -337,7 +244,7 @@ def test_get_random_subset_with_missing_values(metadata, data):
data['child'].loc[4, 'parent_id'] = np.nan
expected_warning = re.escape(
'The data contains null values in foreign key columns. '
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils.poc'
'We recommend using ``drop_unknown_foreign_keys`` method from sdv.utils'
' to drop these rows before using ``get_random_subset``.'
)

Expand Down
Loading

0 comments on commit b94cf94

Please sign in to comment.