Skip to content

Commit

Permalink
Fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Apr 7, 2024
1 parent a6082fc commit 654c44b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 5 deletions.
6 changes: 3 additions & 3 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ def _find_parent_id(likelihoods, num_rows):
candidates.append(parent)
candidate_weights.append(weight)

candidate_weights = np.nan_to_num(candidate_weights, nan=1e-6)
candidate_weights = np.array(candidate_weights) / np.sum(candidate_weights)
chosen_parent = np.random.choice(candidates, p=candidate_weights)
num_rows[chosen_parent] -= 1
Expand Down Expand Up @@ -645,11 +646,10 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f
# Create a copy of the parent table with the primary key as index to calculate likelihoods
primary_key = self.metadata.tables[parent_name].primary_key
parent_table = parent_table.set_index(primary_key)
num_rows_column = f'__{child_name}__{foreign_key}__num_rows'
num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy()
likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)

return likelihoods.apply(self._find_parent_id, axis=1,
num_rows=parent_table[num_rows_column].copy())
return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)

def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name):
for foreign_key in self.metadata._get_foreign_keys(parent_name, child_name):
Expand Down
2 changes: 1 addition & 1 deletion sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
A dictionary mapping table names to sampled tables (pd.DataFrame).
"""
for child_name in self.metadata._get_child_map()[table_name]:
self._enforce_table_sizes(child_name, table_name, scale, sampled_data)
if child_name not in sampled_data: # Sample based on only 1 parent
self._enforce_table_sizes(child_name, table_name, scale, sampled_data)
for _, row in sampled_data[table_name].iterrows():
self._add_child_rows(
child_name=child_name,
Expand Down
17 changes: 17 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
from faker import Faker
from rdt.transformers import FloatFormatter
from sdmetrics.reports.multi_table import DiagnosticReport

from sdv import version
from sdv.datasets.demo import download_demo
Expand Down Expand Up @@ -1599,3 +1600,19 @@ def test_save_and_load_with_downgraded_version(tmp_path):
)
with pytest.raises(VersionError, match=error_msg):
HMASynthesizer.load(synthesizer_path)


def test_hma_relationship_validity():
"""Test the quality of the HMA synthesizer GH#1834."""
# Setup
data, metadata = download_demo('multi_table', 'Dunur_v1')
synthesizer = HMASynthesizer(metadata)
report = DiagnosticReport()

# Run
synthesizer.fit(data)
sample = synthesizer.sample()
report.generate(data, sample, metadata.to_dict(), verbose=False)

# Assert
assert report.get_details('Relationship Validity')['Score'].mean() == 1.0
92 changes: 91 additions & 1 deletion tests/unit/sampling/test_hierarchical_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from unittest.mock import Mock, call
from unittest.mock import MagicMock, Mock, call

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -594,3 +594,93 @@ def _sample_children_dummy(table_name, sampled_data, scale):
)
])
instance._finalize.assert_called_once_with(expected_sample)

def test___enforce_table_sizes_too_many_rows(self):
"""Test it enforces the sampled data to have the same size as the real data.
If the sampled data has more rows than the real data, _num_rows is decreased.
"""
# Setup
instance = MagicMock()
data = {
'parent': pd.DataFrame({
'fk': ['a', 'b', 'c'],
'__child__fk__num_rows': [1, 2, 3]
})
}
instance.metadata._get_foreign_keys.return_value = ['fk']
instance._min_child_rows = {'__child__fk__num_rows': 1}
instance._max_child_rows = {'__child__fk__num_rows': 3}
instance._table_sizes = {'child': 4}

# Run
BaseHierarchicalSampler._enforce_table_sizes(
instance,
'child',
'parent',
1.0,
data
)

# Assert
assert data['parent']['__child__fk__num_rows'].to_list() == [1, 1, 2]

def test___enforce_table_sizes_not_enough_rows(self):
"""Test it enforces the sampled data to have the same size as the real data.
If the sampled data has less rows than the real data, _num_rows is increased.
"""
# Setup
instance = MagicMock()
data = {
'parent': pd.DataFrame({
'fk': ['a', 'b', 'c'],
'__child__fk__num_rows': [1, 1, 1]
})
}
instance.metadata._get_foreign_keys.return_value = ['fk']
instance._min_child_rows = {'__child__fk__num_rows': 1}
instance._max_child_rows = {'__child__fk__num_rows': 3}
instance._table_sizes = {'child': 4}

# Run
BaseHierarchicalSampler._enforce_table_sizes(
instance,
'child',
'parent',
1.0,
data
)

# Assert
assert data['parent']['__child__fk__num_rows'].to_list() == [2, 1, 1]

def test___enforce_table_sizes_clipping(self):
"""Test it enforces the sampled data to have the same size as the real data.
When the sampled num_rows is outside the min and max range, it should be clipped.
"""
# Setup
instance = MagicMock()
data = {
'parent': pd.DataFrame({
'fk': ['a', 'b', 'c'],
'__child__fk__num_rows': [1, 2, 5]
})
}
instance.metadata._get_foreign_keys.return_value = ['fk']
instance._min_child_rows = {'__child__fk__num_rows': 2}
instance._max_child_rows = {'__child__fk__num_rows': 4}
instance._table_sizes = {'child': 8}

# Run
BaseHierarchicalSampler._enforce_table_sizes(
instance,
'child',
'parent',
1.0,
data
)

# Assert
assert data['parent']['__child__fk__num_rows'].to_list() == [2, 2, 4]

0 comments on commit 654c44b

Please sign in to comment.