Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for disconnected tables #1979

Merged
merged 10 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,14 +513,6 @@ def _detect_relationships(self):
sdtype=original_foreign_key_sdtype)
continue

try:
self._validate_all_tables_connected(self._get_parent_map(), self._get_child_map())
except InvalidMetadataError as invalid_error:
warning_msg = (
f'Could not automatically add relationships for all tables. {str(invalid_error)}'
)
warnings.warn(warning_msg)

Comment on lines -516 to -523
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to get rid of this warning? While it's not a requirement anymore, it might be nice to still warn that the metadata detection wasn't able to connect all of the tables. @npatki, thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be an interesting one, it seems that checking all tables are connected is the way we confirm we got all the relationships but now this wouldn't be the case anymore as disjointed tables exist. Not sure how we then go about confirming all relationships are captured. @frances-h

def detect_table_from_dataframe(self, table_name, data):
"""Detect the metadata for a table from a dataframe.
Expand Down Expand Up @@ -739,14 +731,10 @@ def validate(self):
for relation in self.relationships:
self._append_relationships_errors(errors, self._validate_relationship, **relation)

parent_map = self._get_parent_map()
child_map = self._get_child_map()

self._append_relationships_errors(
errors, self._validate_child_map_circular_relationship, child_map)
self._append_relationships_errors(
errors, self._validate_all_tables_connected, parent_map, child_map)

if errors:
raise InvalidMetadataError(
'The metadata is not valid' + '\n'.join(str(e) for e in errors)
Expand Down
21 changes: 21 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,3 +1818,24 @@ def test_table_name_logging(caplog):
# Assert
for msg in caplog.messages:
assert 'table parent_data' in msg or 'table child_data' in msg


def test_disjointed_tables():
"""Test to see if synthesizer works with disjointed tables."""
# Setup
real_data, metadata = download_demo('multi_table', 'Bupa_v1')

# Delete Some Relationships to make it disjointed
remove_some_dict = metadata.to_dict()
half_list = remove_some_dict['relationships'][1::2]
remove_some_dict['relationships'] = half_list
disjoined_metadata = MultiTableMetadata.load_from_dict(remove_some_dict)

# Run
disjoin_synthesizer = HMASynthesizer(disjoined_metadata)
disjoin_synthesizer.fit(real_data)
disjoin_synthetic_data = disjoin_synthesizer.sample(1.0)

# Assert
for table in real_data:
assert list(real_data[table].columns) == list(disjoin_synthetic_data[table].columns)
70 changes: 30 additions & 40 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,14 +1203,42 @@ def test_validate_raises_errors(self):
"Please use 'set_primary_key' in order to set one."
"\nRelationship between tables ('sessions', 'transactions') is invalid. "
'The primary and foreign key columns are not the same type.'
"\nThe relationships in the dataset are disjointed. Table ['payments'] "
'is not connected to any of the other tables.'
)

# Run and Assert
with pytest.raises(InvalidMetadataError, match=error_msg):
instance.validate()

def test__validate_all_tables_connected_raises_errors(self):
"""Test the method ``_validate_all_tables_connected``.
Test that when a disjointed table is validated with `_validate_all_tables_connected`
Setup:
- Instance of ``MultiTableMetadata`` with all valid tables and
missing relationships.
"""
# Setup
instance = self.get_metadata()
instance.tables['users'].primary_key = None
instance.tables['transactions'].columns['session_id']['sdtype'] = 'datetime'
instance.tables['payments'].columns['date']['sdtype'] = 'id'
instance.tables['payments'].columns['date']['regex_format'] = '[A-z{'
instance.relationships.pop(-1)

# Run
error_msg = re.escape(
'The relationships in the dataset are disjointed. '
"Table ['payments'] is not connected to any of the other tables."
)

# Run and Assert
with pytest.raises(InvalidMetadataError, match=error_msg):
instance._validate_all_tables_connected(
lajohn4747 marked this conversation as resolved.
Show resolved Hide resolved
instance._get_parent_map(),
instance._get_child_map()
)

def test_validate_child_key_is_primary_key(self):
"""Test it crashes if the child key is a primary key."""
# Setup
Expand Down Expand Up @@ -2324,44 +2352,6 @@ def test__detect_relationships(self):
assert instance.relationships == expected_relationships
assert instance.tables['sessions'].columns['user_id']['sdtype'] == 'id'

@patch('sdv.metadata.multi_table.warnings')
def test__detect_relationships_disconnected_warning(self, warnings_mock):
"""Test that ``_detect_relationships`` warns about tables it could not connect."""
# Setup
parent_table = Mock()
parent_table.primary_key = 'id'
parent_table.columns = {
'id': {'sdtype': 'id'},
'user_name': {'sdtype': 'categorical'},
'transactions': {'sdtype': 'numerical'},
}

child_table = SingleTableMetadata()
child_table.primary_key = 'session_id'
child_table.columns = {
'user_id': {'sdtype': 'categorical'},
'session_id': {'sdtype': 'numerical'},
'timestamp': {'sdtype': 'datetime'},
}

instance = MultiTableMetadata()
instance.tables = {
'users': parent_table,
'sessions': child_table,
}

# Run
instance._detect_relationships()

# Assert
expected_warning = (
'Could not automatically add relationships for all tables. The relationships in '
"the dataset are disjointed. Tables ['users', 'sessions'] are not connected to "
'any of the other tables.'
)
warnings_mock.warn.assert_called_once_with(expected_warning)
assert instance.relationships == []

def test__detect_relationships_circular(self):
"""Test that relationships that invalidate the metadata are not added."""
# Setup
Expand Down
Loading