From cf5154e2a978987fc01c6a93aea9d246c0086b76 Mon Sep 17 00:00:00 2001 From: Felipe Date: Wed, 3 Apr 2024 09:35:41 -0700 Subject: [PATCH] Fix bug --- sdv/multi_table/hma.py | 6 +- sdv/sampling/hierarchical_sampler.py | 2 +- tests/integration/multi_table/test_hma.py | 17 ++++ .../sampling/test_hierarchical_sampler.py | 92 ++++++++++++++++++- 4 files changed, 112 insertions(+), 5 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 43ceb9600..7b5839144 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -535,6 +535,7 @@ def _find_parent_id(likelihoods, num_rows): candidates.append(parent) candidate_weights.append(weight) + candidate_weights = np.nan_to_num(candidate_weights, nan=1e-6) candidate_weights = np.array(candidate_weights) / np.sum(candidate_weights) chosen_parent = np.random.choice(candidates, p=candidate_weights) num_rows[chosen_parent] -= 1 @@ -608,11 +609,10 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f # Create a copy of the parent table with the primary key as index to calculate likelihoods primary_key = self.metadata.tables[parent_name].primary_key parent_table = parent_table.set_index(primary_key) - num_rows_column = f'__{child_name}__{foreign_key}__num_rows' + num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy() likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key) - return likelihoods.apply(self._find_parent_id, axis=1, - num_rows=parent_table[num_rows_column].copy()) + return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows) def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name): for foreign_key in self.metadata._get_foreign_keys(parent_name, child_name): diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 7867b93b5..6a51abeff 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -157,8 +157,8 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): A dictionary mapping table names to sampled tables (pd.DataFrame). """ for child_name in self.metadata._get_child_map()[table_name]: + self._enforce_table_sizes(child_name, table_name, scale, sampled_data) if child_name not in sampled_data: # Sample based on only 1 parent - self._enforce_table_sizes(child_name, table_name, scale, sampled_data) for _, row in sampled_data[table_name].iterrows(): self._add_child_rows( child_name=child_name, diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 4ea365a09..3d47dba1b 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -8,6 +8,7 @@ import pytest from faker import Faker from rdt.transformers import FloatFormatter +from sdmetrics.reports.multi_table import DiagnosticReport from sdv import version from sdv.datasets.demo import download_demo @@ -1599,3 +1600,19 @@ def test_save_and_load_with_downgraded_version(tmp_path): ) with pytest.raises(VersionError, match=error_msg): HMASynthesizer.load(synthesizer_path) + + +def test_hma_relationship_validity(): + """Test the quality of the HMA synthesizer GH#1834.""" + # Setup + data, metadata = download_demo('multi_table', 'Dunur_v1') + synthesizer = HMASynthesizer(metadata) + report = DiagnosticReport() + + # Run + synthesizer.fit(data) + sample = synthesizer.sample() + report.generate(data, sample, metadata.to_dict(), verbose=False) + + # Assert + assert report.get_details('Relationship Validity')['Score'].mean() == 1.0 diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index ed1ed95f1..c1b5f348b 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ b/tests/unit/sampling/test_hierarchical_sampler.py @@ -1,5 +1,5 @@ from collections import defaultdict -from unittest.mock import Mock, call +from unittest.mock import MagicMock, Mock, call import numpy as np import pandas as pd @@ -594,3 +594,93 @@ def _sample_children_dummy(table_name, sampled_data, scale): ) ]) instance._finalize.assert_called_once_with(expected_sample) + + def test___enforce_table_sizes_too_many_rows(self): + """Test it enforces the sampled data to have the same size as the real data. + + If the sampled data has more rows than the real data, _num_rows is decreased. + """ + # Setup + instance = MagicMock() + data = { + 'parent': pd.DataFrame({ + 'fk': ['a', 'b', 'c'], + '__child__fk__num_rows': [1, 2, 3] + }) + } + instance.metadata._get_foreign_keys.return_value = ['fk'] + instance._min_child_rows = {'__child__fk__num_rows': 1} + instance._max_child_rows = {'__child__fk__num_rows': 3} + instance._table_sizes = {'child': 4} + + # Run + BaseHierarchicalSampler._enforce_table_sizes( + instance, + 'child', + 'parent', + 1.0, + data + ) + + # Assert + assert data['parent']['__child__fk__num_rows'].to_list() == [1, 1, 2] + + def test___enforce_table_sizes_not_enough_rows(self): + """Test it enforces the sampled data to have the same size as the real data. + + If the sampled data has less rows than the real data, _num_rows is increased. + """ + # Setup + instance = MagicMock() + data = { + 'parent': pd.DataFrame({ + 'fk': ['a', 'b', 'c'], + '__child__fk__num_rows': [1, 1, 1] + }) + } + instance.metadata._get_foreign_keys.return_value = ['fk'] + instance._min_child_rows = {'__child__fk__num_rows': 1} + instance._max_child_rows = {'__child__fk__num_rows': 3} + instance._table_sizes = {'child': 4} + + # Run + BaseHierarchicalSampler._enforce_table_sizes( + instance, + 'child', + 'parent', + 1.0, + data + ) + + # Assert + assert data['parent']['__child__fk__num_rows'].to_list() == [2, 1, 1] + + def test___enforce_table_sizes_clipping(self): + """Test it enforces the sampled data to have the same size as the real data. + + When the sampled num_rows is outside the min and max range, it should be clipped. + """ + # Setup + instance = MagicMock() + data = { + 'parent': pd.DataFrame({ + 'fk': ['a', 'b', 'c'], + '__child__fk__num_rows': [1, 2, 5] + }) + } + instance.metadata._get_foreign_keys.return_value = ['fk'] + instance._min_child_rows = {'__child__fk__num_rows': 2} + instance._max_child_rows = {'__child__fk__num_rows': 4} + instance._table_sizes = {'child': 8} + + # Run + BaseHierarchicalSampler._enforce_table_sizes( + instance, + 'child', + 'parent', + 1.0, + data + ) + + # Assert + assert data['parent']['__child__fk__num_rows'].to_list() == [2, 2, 4]