Skip to content

Commit

Permalink
Working version
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Apr 2, 2024
1 parent ff23136 commit 9662bef
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 42 deletions.
17 changes: 12 additions & 5 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
BaseMultiTableSynthesizer.__init__(self, metadata, locales=locales)
self._table_sizes = {}
self._max_child_rows = {}
self._min_child_rows = {}
self._augmented_tables = []
self._learned_relationships = 0
self.verbose = verbose
Expand Down Expand Up @@ -364,6 +365,7 @@ def _augment_table(self, table, tables, table_name):
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
table[num_rows_key] = table[num_rows_key].fillna(0)
self._max_child_rows[num_rows_key] = table[num_rows_key].max()
self._min_child_rows[num_rows_key] = table[num_rows_key].min()
tables[table_name] = table
self._learned_relationships += 1

Expand Down Expand Up @@ -499,7 +501,6 @@ def _find_parent_id(likelihoods, num_rows):
int:
The parent id for this row, chosen based on likelihoods.
"""
"""
mean = likelihoods.mean()
if (likelihoods == 0).all():
# All rows got 0 likelihood, fallback to num_rows
Expand All @@ -513,8 +514,7 @@ def _find_parent_id(likelihoods, num_rows):
# at least one row got a valid likelihood, so fill the
# rows that got a singular matrix error with the mean
likelihoods = likelihoods.fillna(mean)
"""
likelihoods = num_rows

total = likelihoods.sum()
if total == 0:
# Worst case scenario: we have no likelihoods
Expand All @@ -524,8 +524,14 @@ def _find_parent_id(likelihoods, num_rows):
else:
weights = likelihoods.to_numpy() / total

chosen_parent = np.random.choice(likelihoods.index.to_list(), p=weights)
candidates, candidate_weights = [], []
for parent, weight in zip(likelihoods.index.to_list(), weights):
if num_rows[parent] > 0:
candidates.append(parent)
candidate_weights.append(weight)

candidate_weights = np.array(candidate_weights) / np.sum(candidate_weights)
chosen_parent = np.random.choice(candidates, p=candidate_weights)
num_rows[chosen_parent] -= 1

return chosen_parent
Expand Down Expand Up @@ -600,7 +606,8 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f
num_rows_column = f'__{child_name}__{foreign_key}__num_rows'
likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)

return likelihoods.apply(self._find_parent_id, axis=1, num_rows=parent_table[num_rows_column])
return likelihoods.apply(self._find_parent_id, axis=1,
num_rows=parent_table[num_rows_column].copy())

def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name):
for foreign_key in self.metadata._get_foreign_keys(parent_name, child_name):
Expand Down
73 changes: 41 additions & 32 deletions sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def _sample_rows(self, synthesizer, num_rows=None):
num_rows = synthesizer._num_rows
return synthesizer._sample_batch(int(num_rows), keep_extra_columns=True)

def _get_num_rows_from_parent(self, parent_row, child_name, foreign_key):
"""Get the number of rows to sample for the child from the parent row."""
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
num_rows = 0
if num_rows_key in parent_row.keys():
num_rows = parent_row[num_rows_key]

return num_rows

def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num_rows=None):
"""Sample the child rows that reference the parent row.
Expand All @@ -88,8 +97,7 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num
"""
# A child table is created based on only one foreign key.
foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0]
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
num_rows = parent_row[num_rows_key]
num_rows = self._get_num_rows_from_parent(parent_row, child_name, foreign_key)
child_synthesizer = self._recreate_child_synthesizer(child_name, parent_name, parent_row)
sampled_rows = self._sample_rows(child_synthesizer, num_rows)

Expand All @@ -113,48 +121,28 @@ def _enforce_table_sizes(self, child_name, table_name, scale, sampled_data):
total_num_rows = int(self._table_sizes[child_name] * scale)
for foreign_key in self.metadata._get_foreign_keys(table_name, child_name):
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
sampled_data[table_name][num_rows_key] = sampled_data[table_name][num_rows_key].fillna(0).round().clip(0, self._max_child_rows[num_rows_key])
min_rows = self._min_child_rows[num_rows_key]
max_rows = self._max_child_rows[num_rows_key]
key_data = sampled_data[table_name][num_rows_key].fillna(0).round()
sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows)

while sum(sampled_data[table_name][num_rows_key]) != total_num_rows:
x = sampled_data[table_name][num_rows_key].argsort()
num_rows_column = sampled_data[table_name][num_rows_key].argsort()
if sum(sampled_data[table_name][num_rows_key]) < total_num_rows:
for i in x:
if sampled_data[table_name][num_rows_key][i] >= self._max_child_rows[num_rows_key]:
for i in num_rows_column:
if sampled_data[table_name][num_rows_key][i] >= max_rows:
break
sampled_data[table_name][num_rows_key][i] += 1
if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
break
else:
for i in x[::-1]:
if sampled_data[table_name][num_rows_key][i] <= 0:
for i in num_rows_column[::-1]:
if sampled_data[table_name][num_rows_key][i] <= min_rows:
break
sampled_data[table_name][num_rows_key][i] -= 1
if sum(sampled_data[table_name][num_rows_key]) == total_num_rows:
break

assert sum(sampled_data[table_name][num_rows_key]) == total_num_rows
assert sampled_data[table_name][num_rows_key].to_list() == sampled_data[table_name][num_rows_key].clip(0, self._max_child_rows[num_rows_key]).to_list()

"""
while sum(sampled_data[table_name][num_rows_key]) != total_num_rows:
# Select indices to increase or decrease the number of rows
idx = sampled_data[table_name][num_rows_key].argsort().values
if sum(sampled_data[table_name][num_rows_key]) < total_num_rows:
# Filter to indices where we can add 1 without exceeding the max
filtered_idx = idx[sampled_data[table_name][num_rows_key][idx] < self._max_child_rows[num_rows_key]]
# Figure out how many of the indices we need to use
delta = int(total_num_rows - sum(sampled_data[table_name][num_rows_key]))
# Add 1 to the first delta indices
sampled_data[table_name].iloc[num_rows_key,filtered_idx[:delta]] += 1
else:
# Filter to indices where we can subtract 1 without going below 0
filtered_idx = idx[sampled_data[table_name][num_rows_key][idx] > 0]
# Figure out how many of the indices we need to use
delta = int(sum(sampled_data[table_name][num_rows_key]) - total_num_rows)
# Subtract 1 from the last delta indices
sampled_data[table_name].iloc[num_rows_key,filtered_idx[-delta:]] -= 1
"""

def _sample_children(self, table_name, sampled_data, scale=1.0):
"""Recursively sample the children of a table.
Expand All @@ -179,7 +167,28 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
sampled_data=sampled_data,
)

self._sample_children(table_name=child_name, sampled_data=sampled_data)
if child_name not in sampled_data: # No child rows sampled, force row creation
foreign_key = self.metadata._get_foreign_keys(table_name, child_name)[0]
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
if num_rows_key in sampled_data[table_name].columns:
max_num_child_index = sampled_data[table_name][num_rows_key].idxmax()
parent_row = sampled_data[table_name].iloc[max_num_child_index]
else:
parent_row = sampled_data[table_name].sample().iloc[0]

self._add_child_rows(
child_name=child_name,
parent_name=table_name,
parent_row=parent_row,
sampled_data=sampled_data,
num_rows=1
)

self._sample_children(
table_name=child_name,
sampled_data=sampled_data,
scale=scale
)

def _finalize(self, sampled_data):
"""Remove extra columns from sampled tables and apply finishing touches.
Expand Down
11 changes: 6 additions & 5 deletions tests/unit/sampling/test_hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def test__sample_children(self):
``_sample_table`` does not sample the root parents of a graph, only the children.
"""
# Setup
def sample_children(table_name, sampled_data):
def sample_children(table_name, sampled_data, scale):
sampled_data['transactions'] = pd.DataFrame({
'transaction_id': [1, 2, 3],
'session_id': ['a', 'a', 'b']
Expand Down Expand Up @@ -281,7 +281,7 @@ def test__sample_children_no_rows_sampled(self):
value and force a child to be created from that row.
"""
# Setup
def sample_children(table_name, sampled_data):
def sample_children(table_name, sampled_data, scale):
sampled_data['transactions'] = pd.DataFrame({
'transaction_id': [1, 2],
'session_id': ['a', 'a']
Expand Down Expand Up @@ -354,7 +354,7 @@ def test__sample_children_no_rows_sampled_no_num_rows(self):
a child to be created from that row.
"""
# Setup
def sample_children(table_name, sampled_data):
def sample_children(table_name, sampled_data, scale):
sampled_data['transactions'] = pd.DataFrame({
'transaction_id': [1, 2],
'session_id': ['a', 'a']
Expand Down Expand Up @@ -514,7 +514,7 @@ def test__sample(self):
'transaction_amount': [100, 1000, 200]
})

def _sample_children_dummy(table_name, sampled_data):
def _sample_children_dummy(table_name, sampled_data, scale):
sampled_data['sessions'] = sessions
sampled_data['transactions'] = transactions

Expand Down Expand Up @@ -576,7 +576,8 @@ def _sample_children_dummy(table_name, sampled_data):
assert result == instance._finalize.return_value
instance._sample_children.assert_called_once_with(
table_name='users',
sampled_data=expected_sample
sampled_data=expected_sample,
scale=1.0
)
instance._add_foreign_key_columns.assert_has_calls([
call(
Expand Down

0 comments on commit 9662bef

Please sign in to comment.