diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py
index 4486463ce..85673503f 100644
--- a/sdv/multi_table/hma.py
+++ b/sdv/multi_table/hma.py
@@ -37,6 +37,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
         BaseMultiTableSynthesizer.__init__(self, metadata, locales=locales)
         self._table_sizes = {}
         self._max_child_rows = {}
+        self._min_child_rows = {}
         self._augmented_tables = []
         self._learned_relationships = 0
         self.verbose = verbose
@@ -364,6 +365,7 @@ def _augment_table(self, table, tables, table_name):
                 num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
                 table[num_rows_key] = table[num_rows_key].fillna(0)
                 self._max_child_rows[num_rows_key] = table[num_rows_key].max()
+                self._min_child_rows[num_rows_key] = table[num_rows_key].min()

         tables[table_name] = table
         self._learned_relationships += 1
@@ -499,7 +501,6 @@ def _find_parent_id(likelihoods, num_rows):
             int:
                 The parent id for this row, chosen based on likelihoods.
         """
-        """
         mean = likelihoods.mean()
         if (likelihoods == 0).all():
             # All rows got 0 likelihood, fallback to num_rows
@@ -513,8 +514,7 @@ def _find_parent_id(likelihoods, num_rows):
             # at least one row got a valid likelihood, so fill the
             # rows that got a singular matrix error with the mean
             likelihoods = likelihoods.fillna(mean)
-        """
-        likelihoods = num_rows
+
         total = likelihoods.sum()
         if total == 0:
             # Worse case scenario: we have no likelihoods
@@ -524,8 +524,14 @@ def _find_parent_id(likelihoods, num_rows):
         else:
             weights = likelihoods.to_numpy() / total

-        chosen_parent = np.random.choice(likelihoods.index.to_list(), p=weights)
+        candidates, candidate_weights = [], []
+        for parent, weight in zip(likelihoods.index.to_list(), weights):
+            if num_rows[parent] > 0:
+                candidates.append(parent)
+                candidate_weights.append(weight)
+        candidate_weights = np.array(candidate_weights) / np.sum(candidate_weights)
+        chosen_parent = np.random.choice(candidates, p=candidate_weights)
         num_rows[chosen_parent] -= 1
         return chosen_parent
@@ -600,7 +606,8 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f
         num_rows_column = f'__{child_name}__{foreign_key}__num_rows'

         likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)
-        return likelihoods.apply(self._find_parent_id, axis=1, num_rows=parent_table[num_rows_column])
+        return likelihoods.apply(self._find_parent_id, axis=1,
+                                 num_rows=parent_table[num_rows_column].copy())

     def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name):
         for foreign_key in self.metadata._get_foreign_keys(parent_name, child_name):
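Reviewer note: the substance of the hma.py change is twofold. First, it deletes the leftover debug quoting that replaced the learned likelihoods with `num_rows`, restoring likelihood-weighted parent matching. Second, `_find_parent_id` now draws only among parents that still have child-row budget (`num_rows > 0`), renormalizing the weights over those candidates, and `_find_parent_ids` passes a copy of the budget column so the in-place decrements do not leak back into `parent_table`. A minimal self-contained sketch of that selection step; the function and variable names here are illustrative, not SDV's API:

import numpy as np
import pandas as pd


def choose_parent(likelihoods, capacity):
    """Pick a parent id weighted by likelihood, restricted to parents with capacity left.

    Assumes at least one parent with remaining capacity has nonzero weight;
    like the patch, there is no guard for the all-zero-weight edge case.
    """
    total = likelihoods.sum()
    if total == 0:
        # No usable likelihoods at all: fall back to uniform weights.
        weights = np.ones(len(likelihoods)) / len(likelihoods)
    else:
        weights = likelihoods.to_numpy() / total

    # Drop parents whose child-row budget is exhausted, then renormalize.
    candidates, candidate_weights = [], []
    for parent, weight in zip(likelihoods.index.to_list(), weights):
        if capacity[parent] > 0:
            candidates.append(parent)
            candidate_weights.append(weight)

    candidate_weights = np.array(candidate_weights) / np.sum(candidate_weights)
    chosen = np.random.choice(candidates, p=candidate_weights)
    capacity[chosen] -= 1  # Spend one unit of the chosen parent's budget.
    return chosen


likelihoods = pd.Series([0.2, 0.5, 0.3], index=['p1', 'p2', 'p3'])
capacity = pd.Series([1, 0, 2], index=['p1', 'p2', 'p3'])
print(choose_parent(likelihoods, capacity))  # Never 'p2': its budget is exhausted.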
diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py
index 7fad51109..7867b93b5 100644
--- a/sdv/sampling/hierarchical_sampler.py
+++ b/sdv/sampling/hierarchical_sampler.py
@@ -70,6 +70,15 @@ def _sample_rows(self, synthesizer, num_rows=None):
             num_rows = synthesizer._num_rows

         return synthesizer._sample_batch(int(num_rows), keep_extra_columns=True)

+    def _get_num_rows_from_parent(self, parent_row, child_name, foreign_key):
+        """Get the number of rows to sample for the child from the parent row."""
+        num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
+        num_rows = 0
+        if num_rows_key in parent_row.keys():
+            num_rows = parent_row[num_rows_key]
+
+        return num_rows
+
     def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num_rows=None):
         """Sample the child rows that reference the parent row.
@@ -88,8 +97,7 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num
         """
         # A child table is created based on only one foreign key.
         foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0]
-        num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
-        num_rows = parent_row[num_rows_key]
+        num_rows = self._get_num_rows_from_parent(parent_row, child_name, foreign_key)

         child_synthesizer = self._recreate_child_synthesizer(child_name, parent_name, parent_row)
         sampled_rows = self._sample_rows(child_synthesizer, num_rows)
@@ -113,48 +121,28 @@ def _enforce_table_sizes(self, child_name, table_name, scale, sampled_data):
         total_num_rows = int(self._table_sizes[child_name] * scale)
         for foreign_key in self.metadata._get_foreign_keys(table_name, child_name):
             num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
-            sampled_data[table_name][num_rows_key] = sampled_data[table_name][num_rows_key].fillna(0).round().clip(0, self._max_child_rows[num_rows_key])
+            min_rows = self._min_child_rows[num_rows_key]
+            max_rows = self._max_child_rows[num_rows_key]
+            key_data = sampled_data[table_name][num_rows_key].fillna(0).round()
+            sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows)
             while sum(sampled_data[table_name][num_rows_key]) != total_num_rows:
-                x = sampled_data[table_name][num_rows_key].argsort()
+                num_rows_column = sampled_data[table_name][num_rows_key].argsort()
                 if sum(sampled_data[table_name][num_rows_key]) < total_num_rows:
-                    for i in x:
-                        if sampled_data[table_name][num_rows_key][i] >= self._max_child_rows[num_rows_key]:
+                    for i in num_rows_column:
+                        if sampled_data[table_name][num_rows_key][i] >= max_rows:
                             break
                         sampled_data[table_name][num_rows_key][i] += 1
                         if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
                             break
                 else:
-                    for i in x[::-1]:
-                        if sampled_data[table_name][num_rows_key][i] <= 0:
+                    for i in num_rows_column[::-1]:
+                        if sampled_data[table_name][num_rows_key][i] <= min_rows:
                             break
                         sampled_data[table_name][num_rows_key][i] -= 1
                         if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
                             break
-            assert sum(sampled_data[table_name][num_rows_key]) == total_num_rows
-            assert sampled_data[table_name][num_rows_key].to_list() == sampled_data[table_name][num_rows_key].clip(0, self._max_child_rows[num_rows_key]).to_list()
-
-            """
-            while sum(sampled_data[table_name][num_rows_key]) != total_num_rows:
-                # Select indices to increase or decrease the number of rows
-                idx = sampled_data[table_name][num_rows_key].argsort().values
-                if sum(sampled_data[table_name][num_rows_key]) < total_num_rows:
-                    # Filter to indices where we can add 1 without exceeding the max
-                    filtered_idx = idx[sampled_data[table_name][num_rows_key][idx] < self._max_child_rows[num_rows_key]]
-                    # Figure out how many of the indices we need to use
-                    delta = int(total_num_rows - sum(sampled_data[table_name][num_rows_key]))
-                    # Add 1 to the first delta indices
-                    sampled_data[table_name].iloc[num_rows_key,filtered_idx[:delta]] += 1
-                else:
-                    # Filter to indices where we can subtract 1 without going below 0
-                    filtered_idx = idx[sampled_data[table_name][num_rows_key][idx] > 0]
-                    # Figure out how many of the indices we need to use
-                    delta = int(sum(sampled_data[table_name][num_rows_key]) - total_num_rows)
-                    # Subtract 1 from the last delta indices
-                    sampled_data[table_name].iloc[num_rows_key,filtered_idx[-delta:]] -= 1
-            """
-
     def _sample_children(self, table_name, sampled_data, scale=1.0):
         """Recursively sample the children of a table.
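Reviewer note: `_enforce_table_sizes` previously clipped the per-parent child counts to [0, max]; it now clips to the [min, max] range recorded while fitting, then redistributes single rows, growing the smallest counts first (or shrinking the largest first) until the column sums to the scaled target. The removed asserts and the quoted-out vectorized variant were debug leftovers. A self-contained sketch of the redistribution loop, with illustrative names:

import pandas as pd


def enforce_total(counts, total, min_rows, max_rows):
    """Clip counts to [min_rows, max_rows], then nudge them by 1 until they sum to total.

    Assumes total is achievable, i.e. n * min_rows <= total <= n * max_rows;
    otherwise the while loop would never terminate.
    """
    counts = counts.fillna(0).round().clip(min_rows, max_rows)
    while counts.sum() != total:
        order = counts.argsort()  # Integer positions, smallest count first.
        if counts.sum() < total:
            for i in order:  # Grow the smallest counts first.
                if counts.iloc[i] >= max_rows:
                    break  # The smallest is already at max, so everything is.
                counts.iloc[i] += 1
                if counts.sum() == total:
                    break
        else:
            for i in order[::-1]:  # Shrink the largest counts first.
                if counts.iloc[i] <= min_rows:
                    break  # The largest is already at min, so everything is.
                counts.iloc[i] -= 1
                if counts.sum() == total:
                    break
    return counts


counts = pd.Series([5.0, 1.0, None, 3.0])
print(enforce_total(counts, total=12, min_rows=1, max_rows=6).to_list())
# [5.0, 2.0, 2.0, 3.0]: the NaN became 1, then the two smallest counts grew by 1.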
@@ -179,7 +167,28 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
                         sampled_data=sampled_data,
                     )

-            self._sample_children(table_name=child_name, sampled_data=sampled_data)
+            if child_name not in sampled_data:  # No child rows sampled, force row creation
+                foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0]
+                num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
+                if num_rows_key in sampled_data[table_name].columns:
+                    max_num_child_index = sampled_data[table_name][num_rows_key].idxmax()
+                    parent_row = sampled_data[table_name].iloc[max_num_child_index]
+                else:
+                    parent_row = sampled_data[table_name].sample().iloc[0]
+
+                self._add_child_rows(
+                    child_name=child_name,
+                    parent_name=table_name,
+                    parent_row=parent_row,
+                    sampled_data=sampled_data,
+                    num_rows=1
+                )
+
+            self._sample_children(
+                table_name=child_name,
+                sampled_data=sampled_data,
+                scale=scale
+            )

     def _finalize(self, sampled_data):
         """Remove extra columns from sampled tables and apply finishing touches.
diff --git a/tests/unit/sampling/test_hierarchical_sampler.py b/tests/unit/sampling/test_hierarchical_sampler.py
index c9c256608..ed1ed95f1 100644
--- a/tests/unit/sampling/test_hierarchical_sampler.py
+++ b/tests/unit/sampling/test_hierarchical_sampler.py
@@ -199,7 +199,7 @@ def test__sample_children(self):
         ``_sample_table`` does not sample the root parents of a graph, only the children.
         """
         # Setup
-        def sample_children(table_name, sampled_data):
+        def sample_children(table_name, sampled_data, scale):
             sampled_data['transactions'] = pd.DataFrame({
                 'transaction_id': [1, 2, 3],
                 'session_id': ['a', 'a', 'b']
@@ -281,7 +281,7 @@ def test__sample_children_no_rows_sampled(self):
         value and force a child to be created from that row.
         """
         # Setup
-        def sample_children(table_name, sampled_data):
+        def sample_children(table_name, sampled_data, scale):
             sampled_data['transactions'] = pd.DataFrame({
                 'transaction_id': [1, 2],
                 'session_id': ['a', 'a']
@@ -354,7 +354,7 @@ def test__sample_children_no_rows_sampled_no_num_rows(self):
         a child to be created from that row.
         """
         # Setup
-        def sample_children(table_name, sampled_data):
+        def sample_children(table_name, sampled_data, scale):
             sampled_data['transactions'] = pd.DataFrame({
                 'transaction_id': [1, 2],
                 'session_id': ['a', 'a']
@@ -514,7 +514,7 @@ def test__sample(self):
             'transaction_amount': [100, 1000, 200]
         })

-        def _sample_children_dummy(table_name, sampled_data):
+        def _sample_children_dummy(table_name, sampled_data, scale):
             sampled_data['sessions'] = sessions
             sampled_data['transactions'] = transactions
@@ -576,7 +576,8 @@ def _sample_children_dummy(table_name, sampled_data):
         assert result == instance._finalize.return_value
         instance._sample_children.assert_called_once_with(
             table_name='users',
-            sampled_data=expected_sample
+            sampled_data=expected_sample,
+            scale=1.0
         )
         instance._add_foreign_key_columns.assert_has_calls([
             call(
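Reviewer note: the `_sample_children` change guarantees a child table is never left empty: if no rows were sampled for `child_name`, the sampler forces one child from the parent predicted to have the most children (falling back to a random parent when the `__num_rows` column is absent), and the recursion now propagates `scale`, which is why the test doubles above grow a third `scale` parameter. A minimal sketch of the fallback choice, assuming each table in `sampled_data` is a DataFrame; the helper name is illustrative:

import pandas as pd


def pick_fallback_parent(parent_df, num_rows_key):
    """Pick the parent row to force a single child from."""
    if num_rows_key in parent_df.columns:
        # Prefer the parent predicted to have the most children.
        return parent_df.loc[parent_df[num_rows_key].idxmax()]

    # No learned count column: any parent row will do.
    return parent_df.sample().iloc[0]


parents = pd.DataFrame({'id': ['a', 'b'], '__child__fk__num_rows': [0.0, 2.0]})
print(pick_fallback_parent(parents, '__child__fk__num_rows')['id'])  # 'b'

Note the sketch pairs `idxmax()` with `.loc`; the patch pairs it with `.iloc`, which is equivalent only while the parent table keeps its default RangeIndex.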