From 66bedbbb356a75ca2686e6e3f0e39472a72d027b Mon Sep 17 00:00:00 2001 From: Felipe Date: Mon, 8 Apr 2024 10:33:33 -0500 Subject: [PATCH] Address feedback --- sdv/multi_table/hma.py | 8 +-- sdv/multi_table/utils.py | 2 +- sdv/sampling/hierarchical_sampler.py | 53 ++++++++++++++++--- .../sampling/test_hierarchical_sampler.py | 12 ++--- 4 files changed, 58 insertions(+), 17 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 8fc299b98..b62c0d3dc 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -113,8 +113,7 @@ def _estimate_columns_traversal(cls, metadata, table_name, columns_per_table[table_name] += \ cls._get_num_extended_columns( - metadata, child_name, table_name, columns_per_table, distributions - ) + metadata, child_name, table_name, columns_per_table, distributions) visited.add(table_name) @@ -563,6 +562,7 @@ def _find_parent_id(likelihoods, num_rows): # and all num_rows are 0, so we fallback to uniform length = len(likelihoods) weights = np.ones(length) / length + else: weights = likelihoods.to_numpy() / total @@ -572,7 +572,9 @@ def _find_parent_id(likelihoods, num_rows): candidates.append(parent) candidate_weights.append(weight) - candidate_weights = np.nan_to_num(candidate_weights, nan=1e-6) + if sum(candidate_weights) == 0: + return np.random.choice(candidates) + candidate_weights = np.array(candidate_weights) / np.sum(candidate_weights) chosen_parent = np.random.choice(candidates, p=candidate_weights) num_rows[chosen_parent] -= 1 diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index 2c372e6c6..852f2bda9 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -39,7 +39,7 @@ def _get_n_order_descendants(relationships, parent_table, order): descendants = {} order_1_descendants = _get_relationships_for_parent(relationships, parent_table) descendants['order_1'] = [rel['child_table_name'] for rel in order_1_descendants] - for i in range(2, order+1): + for i in range(2, order + 1): 
descendants[f'order_{i}'] = [] prov_descendants = [] for child_table in descendants[f'order_{i-1}']: diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 6a51abeff..71e749491 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -117,29 +117,68 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num sampled_data[child_name] = pd.concat( [previous, sampled_rows]).reset_index(drop=True) - def _enforce_table_sizes(self, child_name, table_name, scale, sampled_data): + def _enforce_table_size(self, child_name, table_name, scale, sampled_data): + """Ensure the child table has the same size as in the real data times the scale factor. + + This is accomplished by adjusting the number of rows to sample for each parent row. + If the sum of the values of the `__num_rows` column in the parent table is greater than + the real data table size * scale, the values are decreased. If the sum is lower, the + values are increased. + + The values are changed with the following algorithm: + + 1. Sort the `__num_rows` column. + 2. If the sum of the values is lower than the target, add 1 to the values from the lowest + to the highest until the sum is reached, while respecting the maximum values observed + in the real data when possible. + 3. If the sum of the values is higher than the target, subtract 1 from the values from the + highest to the lowest until the sum is reached, while respecting the minimum values + observed in the real data when possible. + + Args: + child_name (str): + The name of the child table. + table_name (str): + The name of the parent table. + scale (float): + The scale factor to apply to the table size. + sampled_data (dict): + A dictionary mapping table names to sampled data (pd.DataFrame). 
+ """ total_num_rows = int(self._table_sizes[child_name] * scale) for foreign_key in self.metadata._get_foreign_keys(table_name, child_name): num_rows_key = f'__{child_name}__{foreign_key}__num_rows' - min_rows = self._min_child_rows[num_rows_key] + min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key] max_rows = self._max_child_rows[num_rows_key] key_data = sampled_data[table_name][num_rows_key].fillna(0).round() sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows) while sum(sampled_data[table_name][num_rows_key]) != total_num_rows: num_rows_column = sampled_data[table_name][num_rows_key].argsort() + if sum(sampled_data[table_name][num_rows_key]) < total_num_rows: for i in num_rows_column: - if sampled_data[table_name][num_rows_key][i] >= max_rows: + # If the number of rows is already at the maximum, skip + # The exception is when the smallest value is already at the maximum, + # in which case we ignore the boundary + if sampled_data[table_name][num_rows_key][i] >= max_rows and \ + num_rows_column.iloc[0] < max_rows: break - sampled_data[table_name][num_rows_key][i] += 1 + + sampled_data[table_name].loc[i, num_rows_key] += 1 if sum(sampled_data[table_name][num_rows_key]) == total_num_rows: break + else: for i in num_rows_column[::-1]: - if sampled_data[table_name][num_rows_key][i] <= min_rows: + # If the number of rows is already at the minimum, skip + # The exception is when the highest value is already at the minimum, + # in which case we ignore the boundary + if sampled_data[table_name][num_rows_key][i] <= min_rows and \ + num_rows_column.iloc[-1] > min_rows: break - sampled_data[table_name][num_rows_key][i] -= 1 + + sampled_data[table_name].loc[i, num_rows_key] -= 1 if sum(sampled_data[table_name][num_rows_key]) == total_num_rows: break @@ -157,7 +196,7 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): A dictionary mapping table names to sampled tables (pd.DataFrame). 
""" for child_name in self.metadata._get_child_map()[table_name]: - self._enforce_table_sizes(child_name, table_name, scale, sampled_data) + self._enforce_table_size(child_name, table_name, scale, sampled_data) if child_name not in sampled_data: # Sample based on only 1 parent for _, row in sampled_data[table_name].iterrows(): self._add_child_rows( diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py index c1b5f348b..c331ee96d 100644 --- a/tests/unit/sampling/test_hierarchical_sampler.py +++ b/tests/unit/sampling/test_hierarchical_sampler.py @@ -595,7 +595,7 @@ def _sample_children_dummy(table_name, sampled_data, scale): ]) instance._finalize.assert_called_once_with(expected_sample) - def test___enforce_table_sizes_too_many_rows(self): + def test___enforce_table_size_too_many_rows(self): """Test it enforces the sampled data to have the same size as the real data. If the sampled data has more rows than the real data, _num_rows is decreased. @@ -614,7 +614,7 @@ def test___enforce_table_sizes_too_many_rows(self): instance._table_sizes = {'child': 4} # Run - BaseHierarchicalSampler._enforce_table_sizes( + BaseHierarchicalSampler._enforce_table_size( instance, 'child', 'parent', @@ -625,7 +625,7 @@ def test___enforce_table_sizes_too_many_rows(self): # Assert assert data['parent']['__child__fk__num_rows'].to_list() == [1, 1, 2] - def test___enforce_table_sizes_not_enough_rows(self): + def test___enforce_table_size_not_enough_rows(self): """Test it enforces the sampled data to have the same size as the real data. If the sampled data has less rows than the real data, _num_rows is increased. 
@@ -644,7 +644,7 @@ def test___enforce_table_sizes_not_enough_rows(self): instance._table_sizes = {'child': 4} # Run - BaseHierarchicalSampler._enforce_table_sizes( + BaseHierarchicalSampler._enforce_table_size( instance, 'child', 'parent', @@ -655,7 +655,7 @@ def test___enforce_table_sizes_not_enough_rows(self): # Assert assert data['parent']['__child__fk__num_rows'].to_list() == [2, 1, 1] - def test___enforce_table_sizes_clipping(self): + def test___enforce_table_size_clipping(self): """Test it enforces the sampled data to have the same size as the real data. When the sampled num_rows is outside the min and max range, it should be clipped. @@ -674,7 +674,7 @@ def test___enforce_table_sizes_clipping(self): instance._table_sizes = {'child': 8} # Run - BaseHierarchicalSampler._enforce_table_sizes( + BaseHierarchicalSampler._enforce_table_size( instance, 'child', 'parent',