Commit

Address feedback
fealho committed Apr 8, 2024
1 parent 654c44b commit 66bedbb
Showing 4 changed files with 58 additions and 17 deletions.
8 changes: 5 additions & 3 deletions sdv/multi_table/hma.py
@@ -113,8 +113,7 @@ def _estimate_columns_traversal(cls, metadata, table_name,

             columns_per_table[table_name] += \
                 cls._get_num_extended_columns(
-                    metadata, child_name, table_name, columns_per_table, distributions
-                )
+                    metadata, child_name, table_name, columns_per_table, distributions)
 
         visited.add(table_name)

@@ -563,6 +562,7 @@ def _find_parent_id(likelihoods, num_rows):
             # and all num_rows are 0, so we fallback to uniform
             length = len(likelihoods)
             weights = np.ones(length) / length
+
         else:
             weights = likelihoods.to_numpy() / total
 
@@ -572,7 +572,9 @@ def _find_parent_id(likelihoods, num_rows):
                 candidates.append(parent)
                 candidate_weights.append(weight)
 
-        candidate_weights = np.nan_to_num(candidate_weights, nan=1e-6)
+        if sum(candidate_weights) == 0:
+            return np.random.choice(candidates)
+
         candidate_weights = np.array(candidate_weights) / np.sum(candidate_weights)
         chosen_parent = np.random.choice(candidates, p=candidate_weights)
         num_rows[chosen_parent] -= 1
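
For readers skimming the hunk above: the new logic turns the candidate likelihoods into sampling weights and falls back to a uniform choice when every weight is zero. The snippet below is a small, self-contained sketch of that selection step; the helper name and inputs are illustrative only, not SDV's API.

import numpy as np


def choose_parent(candidates, candidate_weights):
    """Pick one candidate parent, weighting by likelihood when possible.

    Hypothetical helper mirroring the hunk above: when every weight is zero,
    fall back to a uniform choice instead of building invalid probabilities.
    """
    candidate_weights = np.asarray(candidate_weights, dtype=float)
    if candidate_weights.sum() == 0:
        # No informative likelihoods: every candidate parent is equally plausible.
        return np.random.choice(candidates)

    # Normalize so the weights form a valid probability distribution for np.random.choice.
    probabilities = candidate_weights / candidate_weights.sum()
    return np.random.choice(candidates, p=probabilities)


# choose_parent(['p1', 'p2', 'p3'], [0.0, 0.0, 0.0])  -> uniform pick
# choose_parent(['p1', 'p2', 'p3'], [0.2, 0.0, 0.8])  -> 'p3' about 80% of the time
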
Expand Down
2 changes: 1 addition & 1 deletion sdv/multi_table/utils.py
@@ -39,7 +39,7 @@ def _get_n_order_descendants(relationships, parent_table, order):
     descendants = {}
     order_1_descendants = _get_relationships_for_parent(relationships, parent_table)
     descendants['order_1'] = [rel['child_table_name'] for rel in order_1_descendants]
-    for i in range(2, order+1):
+    for i in range(2, order + 1):
         descendants[f'order_{i}'] = []
         prov_descendants = []
         for child_table in descendants[f'order_{i-1}']:
53 changes: 46 additions & 7 deletions sdv/sampling/hierarchical_sampler.py
@@ -117,29 +117,68 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num
             sampled_data[child_name] = pd.concat(
                 [previous, sampled_rows]).reset_index(drop=True)
 
-    def _enforce_table_sizes(self, child_name, table_name, scale, sampled_data):
+    def _enforce_table_size(self, child_name, table_name, scale, sampled_data):
"""Ensure the child table has the same size as in the real data times the scale factor.
This is accomplished by adjusting the number of rows to sample for each parent row.
If the sum of the values of the `__num_rows` column in the parent table is greater than
the real data table size * scale, the values are decreased. If the sum is lower, the
values are increased.
The values are changed with the following algorithm:
1. Sort the `__num_rows` column.
2. If the sum of the values is lower than the target, add 1 to the values from the lowest
to the highest until the sum is reached, while respecting the maximum values obsverved
in the real data when possible.
3. If the sum of the values is higher than the target, subtract 1 from the values from the
highest to the lowest until the sum is reached, while respecting the minimum values
observed in the real data when possible.
Args:
child_name (str):
The name of the child table.
table_name (str):
The name of the parent table.
scale (float):
The scale factor to apply to the table size.
sampled_data (dict):
A dictionary mapping table names to sampled data (pd.DataFrame).
"""
         total_num_rows = int(self._table_sizes[child_name] * scale)
         for foreign_key in self.metadata._get_foreign_keys(table_name, child_name):
             num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
-            min_rows = self._min_child_rows[num_rows_key]
+            min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key]
             max_rows = self._max_child_rows[num_rows_key]
             key_data = sampled_data[table_name][num_rows_key].fillna(0).round()
             sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows)
 
             while sum(sampled_data[table_name][num_rows_key]) != total_num_rows:
                 num_rows_column = sampled_data[table_name][num_rows_key].argsort()
 
                 if sum(sampled_data[table_name][num_rows_key]) < total_num_rows:
                     for i in num_rows_column:
-                        if sampled_data[table_name][num_rows_key][i] >= max_rows:
+                        # If the number of rows is already at the maximum, skip
+                        # The exception is when the smallest value is already at the maximum,
+                        # in which case we ignore the boundary
+                        if sampled_data[table_name][num_rows_key][i] >= max_rows and \
+                                num_rows_column.iloc[0] < max_rows:
                             break
-                        sampled_data[table_name][num_rows_key][i] += 1
 
+                        sampled_data[table_name].loc[i, num_rows_key] += 1
+                        if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
+                            break
 
                 else:
                     for i in num_rows_column[::-1]:
-                        if sampled_data[table_name][num_rows_key][i] <= min_rows:
+                        # If the number of rows is already at the minimum, skip
+                        # The exception is when the highest value is already at the minimum,
+                        # in which case we ignore the boundary
+                        if sampled_data[table_name][num_rows_key][i] <= min_rows and \
+                                num_rows_column.iloc[-1] > min_rows:
                             break
-                        sampled_data[table_name][num_rows_key][i] -= 1
 
+                        sampled_data[table_name].loc[i, num_rows_key] -= 1
+                        if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
+                            break
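
The docstring above describes the balancing algorithm in prose. As a rough standalone sketch of the same idea (assumed helper name `balance_num_rows`, plain pandas, not the SDV implementation): clip each parent's sampled count to the observed bounds, then repeatedly add or subtract 1, walking the sorted values, until the total matches the scaled target.

import pandas as pd


def balance_num_rows(num_rows, target_total, min_rows, max_rows):
    """Nudge per-parent child counts until they sum to ``target_total``.

    Illustrative helper only (not the SDV code): ``num_rows`` is a pandas
    Series of sampled ``__num_rows`` values; ``min_rows`` and ``max_rows``
    are the bounds observed in the real data.
    """
    # Round and clip the raw sampled values to the observed range first.
    num_rows = num_rows.fillna(0).round().clip(min_rows, max_rows)

    while num_rows.sum() != target_total:
        order = num_rows.sort_values().index
        if num_rows.sum() < target_total:
            # Too few rows: add 1 starting from the smallest values, staying
            # under max_rows unless every value is already at the maximum.
            for idx in order:
                if num_rows.loc[idx] >= max_rows and (num_rows < max_rows).any():
                    continue
                num_rows.loc[idx] += 1
                if num_rows.sum() == target_total:
                    break
        else:
            # Too many rows: subtract 1 starting from the largest values, staying
            # above min_rows unless every value is already at the minimum.
            for idx in order[::-1]:
                if num_rows.loc[idx] <= min_rows and (num_rows > min_rows).any():
                    continue
                num_rows.loc[idx] -= 1
                if num_rows.sum() == target_total:
                    break

    return num_rows


# Example: three parents sampled 3 children each, but the scaled child table
# should only have 7 rows, so the counts are nudged down to [3, 2, 2].
print(balance_num_rows(pd.Series([3.0, 3.0, 3.0]), 7, min_rows=1, max_rows=4).to_list())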

@@ -157,7 +196,7 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
                 A dictionary mapping table names to sampled tables (pd.DataFrame).
         """
         for child_name in self.metadata._get_child_map()[table_name]:
-            self._enforce_table_sizes(child_name, table_name, scale, sampled_data)
+            self._enforce_table_size(child_name, table_name, scale, sampled_data)
             if child_name not in sampled_data:  # Sample based on only 1 parent
                 for _, row in sampled_data[table_name].iterrows():
                     self._add_child_rows(
12 changes: 6 additions & 6 deletions tests/unit/sampling/test_hierarchical_sampler.py
@@ -595,7 +595,7 @@ def _sample_children_dummy(table_name, sampled_data, scale):
         ])
         instance._finalize.assert_called_once_with(expected_sample)
 
-    def test___enforce_table_sizes_too_many_rows(self):
+    def test___enforce_table_size_too_many_rows(self):
         """Test it enforces the sampled data to have the same size as the real data.
 
         If the sampled data has more rows than the real data, _num_rows is decreased.
@@ -614,7 +614,7 @@ def test___enforce_table_sizes_too_many_rows(self):
         instance._table_sizes = {'child': 4}
 
         # Run
-        BaseHierarchicalSampler._enforce_table_sizes(
+        BaseHierarchicalSampler._enforce_table_size(
             instance,
             'child',
             'parent',
@@ -625,7 +625,7 @@ def test___enforce_table_sizes_too_many_rows(self):
         # Assert
         assert data['parent']['__child__fk__num_rows'].to_list() == [1, 1, 2]
 
-    def test___enforce_table_sizes_not_enough_rows(self):
+    def test___enforce_table_size_not_enough_rows(self):
         """Test it enforces the sampled data to have the same size as the real data.
 
         If the sampled data has fewer rows than the real data, _num_rows is increased.
@@ -644,7 +644,7 @@ def test___enforce_table_sizes_not_enough_rows(self):
         instance._table_sizes = {'child': 4}
 
         # Run
-        BaseHierarchicalSampler._enforce_table_sizes(
+        BaseHierarchicalSampler._enforce_table_size(
             instance,
             'child',
             'parent',
@@ -655,7 +655,7 @@ def test___enforce_table_sizes_not_enough_rows(self):
         # Assert
         assert data['parent']['__child__fk__num_rows'].to_list() == [2, 1, 1]
 
-    def test___enforce_table_sizes_clipping(self):
+    def test___enforce_table_size_clipping(self):
         """Test it enforces the sampled data to have the same size as the real data.
 
         When the sampled num_rows is outside the min and max range, it should be clipped.
@@ -674,7 +674,7 @@ def test___enforce_table_sizes_clipping(self):
         instance._table_sizes = {'child': 8}
 
         # Run
-        BaseHierarchicalSampler._enforce_table_sizes(
+        BaseHierarchicalSampler._enforce_table_size(
             instance,
             'child',
             'parent',