Skip to content

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Apr 9, 2024
1 parent 562daa7 commit 8b79999
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 54 deletions.
11 changes: 6 additions & 5 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,10 +574,11 @@ def _find_parent_id(likelihoods, num_rows):

# All available candidates were assigned 0 likelihood of being the parent id
if sum(candidate_weights) == 0:
return np.random.choice(candidates)
chosen_parent = np.random.choice(candidates)
else:
candidate_weights = np.array(candidate_weights) / np.sum(candidate_weights)
chosen_parent = np.random.choice(candidates, p=candidate_weights)

candidate_weights = np.array(candidate_weights) / np.sum(candidate_weights)
chosen_parent = np.random.choice(candidates, p=candidate_weights)
num_rows[chosen_parent] -= 1

return chosen_parent
Expand Down Expand Up @@ -649,9 +650,9 @@ def _find_parent_ids(self, child_table, parent_table, child_name, parent_name, f
# Create a copy of the parent table with the primary key as index to calculate likelihoods
primary_key = self.metadata.tables[parent_name].primary_key
parent_table = parent_table.set_index(primary_key)
num_rows = round(parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy())
likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)
num_rows = parent_table[f'__{child_name}__{foreign_key}__num_rows'].copy()

likelihoods = self._get_likelihoods(child_table, parent_table, child_name, foreign_key)
return likelihoods.apply(self._find_parent_id, axis=1, num_rows=num_rows)

def _add_foreign_key_columns(self, child_table, parent_table, child_name, parent_name):
Expand Down
2 changes: 1 addition & 1 deletion sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_n_order_descendants(relationships, parent_table, order):
descendants = {}
order_1_descendants = _get_relationships_for_parent(relationships, parent_table)
descendants['order_1'] = [rel['child_table_name'] for rel in order_1_descendants]
for i in range(2, order + 1):
for i in range(2, order+1):
descendants[f'order_{i}'] = []
prov_descendants = []
for child_table in descendants[f'order_{i-1}']:
Expand Down
23 changes: 7 additions & 16 deletions sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,7 @@ def _sample_rows(self, synthesizer, num_rows=None):
if num_rows is None:
num_rows = synthesizer._num_rows

return synthesizer._sample_batch(num_rows, keep_extra_columns=True)

def _get_num_rows_from_parent(self, parent_row, child_name, foreign_key):
"""Get the number of rows to sample for the child from the parent row."""
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
num_rows = 0
if num_rows_key in parent_row.keys():
num_rows = parent_row[num_rows_key]

return round(num_rows)
return synthesizer._sample_batch(round(num_rows), keep_extra_columns=True)

def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num_rows=None):
"""Sample the child rows that reference the parent row.
Expand All @@ -96,10 +87,10 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num
Number of rows to sample. If None, infers number of child rows to sample
from the parent row. Defaults to None.
"""
# A child table is created based on only one foreign key
# A child table is created based on only one foreign key.
foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0]

num_rows = self._get_num_rows_from_parent(parent_row, child_name, foreign_key)
if num_rows is None:
num_rows = parent_row[f'__{child_name}__{foreign_key}__num_rows']
child_synthesizer = self._recreate_child_synthesizer(child_name, parent_name, parent_row)
sampled_rows = self._sample_rows(child_synthesizer, num_rows)

Expand Down Expand Up @@ -150,10 +141,10 @@ def _enforce_table_size(self, child_name, table_name, scale, sampled_data):
total_num_rows = round(self._table_sizes[child_name] * scale)
for foreign_key in self.metadata._get_foreign_keys(table_name, child_name):
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
key_data = sampled_data[table_name][num_rows_key].fillna(0).round()
min_rows = getattr(self, '_min_child_rows', {num_rows_key: 0})[num_rows_key]
max_rows = self._max_child_rows[num_rows_key]
sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows)
key_data = sampled_data[table_name][num_rows_key].fillna(0).round()
sampled_data[table_name][num_rows_key] = key_data.clip(min_rows, max_rows).astype(int)

while sum(sampled_data[table_name][num_rows_key]) != total_num_rows:
num_rows_column = sampled_data[table_name][num_rows_key].argsort()
Expand Down Expand Up @@ -206,7 +197,7 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
child_name=child_name,
parent_name=table_name,
parent_row=row,
sampled_data=sampled_data,
sampled_data=sampled_data
)

if child_name not in sampled_data: # No child rows sampled, force row creation
Expand Down
22 changes: 11 additions & 11 deletions tests/integration/utils/test_poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,30 +171,30 @@ def test_simplify_schema(capsys):
# Assert
expected_message_before = re.compile(
r'PerformanceAlert: Using the HMASynthesizer on this metadata schema is not recommended\.'
r' To model this data, HMA will generate a large number of columns\. \(203427 columns\)\s+'
r' To model this data, HMA will generate a large number of columns\. \(173818 columns\)\s+'
r'Table Name\s*#\s*Columns in Metadata\s*Est # Columns\s*'
r'match_stats\s*25\s*25\s*'
r'matches\s*45\s*446\s*'
r'players\s*13\s*414\s*'
r'teams\s*101\s*202542\s*'
r'match_stats\s*24\s*24\s*'
r'matches\s*39\s*412\s*'
r'players\s*5\s*378\s*'
r'teams\s*1\s*173004\s*'
r"We recommend simplifying your metadata schema using 'sdv.utils.poc.simplify_schema'\.\s*"
r'If this is not possible, contact us at info@sdv.dev for enterprise solutions\.'
)
expected_message_after = re.compile(
r'Success! The schema has been simplified\.\s+'
r'Table Name\s*#\s*Columns \(Before\)\s*#\s*Columns \(After\)\s*'
r'match_stats\s*29\s*4\s*'
r'matches\s*48\s*21\s*'
r'players\s*14\s*0\s*'
r'teams\s*102\s*102'
r'match_stats\s*28\s*4\s*'
r'matches\s*42\s*21\s*'
r'players\s*6\s*0\s*'
r'teams\s*2\s*2'
)
assert expected_message_before.match(captured_before_simplification.out.strip())
assert expected_message_after.match(captured_after_simplification.out.strip())
metadata_simplify.validate()
metadata_simplify.validate_data(data_simplify)
num_estimated_column_after_simplification = _get_total_estimated_columns(metadata_simplify)
assert num_estimated_column_before_simplification == 203427
assert num_estimated_column_after_simplification == 617
assert num_estimated_column_before_simplification == 173818
assert num_estimated_column_after_simplification == 517


def test_simpliy_nothing_to_simplify():
Expand Down
25 changes: 4 additions & 21 deletions tests/unit/sampling/test_hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,28 +78,10 @@ def test__sample_rows_missing_num_rows(self):
keep_extra_columns=True
)

def test__get_num_rows_from_parent(self):
"""Test that the number of child rows is extracted from the parent row."""
# Setup
parent_row = pd.Series({
'__sessions__user_id__num_rows': 10,
})
instance = Mock()
instance._max_child_rows = {'__sessions__user_id__num_rows': 10}

# Run
result = BaseHierarchicalSampler._get_num_rows_from_parent(
instance, parent_row, 'sessions', 'user_id')

# Assert
expected_result = 10.0
assert result == expected_result

def test__add_child_rows(self):
"""Test adding child rows when sampled data is empty."""
# Setup
instance = Mock()
instance._get_num_rows_from_parent.return_value = 10
child_synthesizer_mock = Mock()
instance._recreate_child_synthesizer.return_value = child_synthesizer_mock

Expand All @@ -121,7 +103,8 @@ def test__add_child_rows(self):
})
parent_row = pd.DataFrame({
'user_id': [1, 2, 3],
'name': ['John', 'Doe', 'Johanna']
'name': ['John', 'Doe', 'Johanna'],
'__sessions__user_id__num_rows': [10, 10, 10]
})
sampled_data = {}

Expand All @@ -146,7 +129,6 @@ def test__add_child_rows_with_sampled_data(self):
"""
# Setup
instance = Mock()
instance._get_num_rows_from_parent.return_value = 10
child_synthesizer_mock = Mock()
instance._recreate_child_synthesizer.return_value = child_synthesizer_mock

Expand All @@ -169,7 +151,8 @@ def test__add_child_rows_with_sampled_data(self):
})
parent_row = pd.DataFrame({
'user_id': [1, 2, 3],
'name': ['John', 'Doe', 'Johanna']
'name': ['John', 'Doe', 'Johanna'],
'__sessions__user_id__num_rows': [10, 10, 10]
})
sampled_data = {
'sessions': pd.DataFrame({
Expand Down

0 comments on commit 8b79999

Please sign in to comment.