Set numerical distributions to truncnorm for extension columns (#2068)
gsheni authored Jun 20, 2024
1 parent a04dc68 commit e06ceca
Showing 6 changed files with 164 additions and 22 deletions.
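In short: HMA extends each parent table with "extension" columns (flattened parameters of the model fitted to each parent row's children), and this commit pins those columns to the 'truncnorm' distribution so that sampled parameters stay inside the range observed during fitting. As a rough illustration of why a truncated normal suits such bounded columns, here is a minimal scipy sketch (illustrative only; SDV itself routes 'truncnorm' through the copulas library):

# Illustrative only, not part of this commit: a truncated normal fitted to
# observed values cannot sample outside [low, high], which is why it is a
# safer default for bounded extension columns than an unbounded fit.
import numpy as np
from scipy.stats import truncnorm

observed = np.array([0.2, 0.4, 0.45, 0.5, 0.55, 0.6, 0.8])
low, high = observed.min(), observed.max()
loc, scale = observed.mean(), observed.std()

# scipy expresses the truncation bounds in standard-deviation units.
a, b = (low - loc) / scale, (high - loc) / scale
samples = truncnorm.rvs(a, b, loc=loc, scale=scale, size=1000)
assert low <= samples.min() and samples.max() <= high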
40 changes: 29 additions & 11 deletions sdv/multi_table/hma.py
@@ -1,6 +1,7 @@
"""Hierarchical Modeling Algorithms."""

import logging
from collections import defaultdict
from copy import deepcopy

import numpy as np
@@ -15,6 +16,7 @@

LOGGER = logging.getLogger(__name__)
MAX_NUMBER_OF_COLUMNS = 1000
DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'


class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
@@ -159,6 +161,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True):
self._augmented_tables = []
self._learned_relationships = 0
self._default_parameters = {}
self._parent_extended_columns = defaultdict(list)
self.verbose = verbose
BaseHierarchicalSampler.__init__(
self, self.metadata, self._table_synthesizers, self._table_sizes
@@ -215,8 +218,8 @@ def _get_distributions(self):
distributions = {}
for table in self.metadata.tables:
parameters = self.get_table_parameters(table)
- sythesizer_parameter = parameters.get('synthesizer_parameters', {})
- distributions[table] = sythesizer_parameter.get('default_distribution', None)
+ synthesizer_parameters = parameters.get('synthesizer_parameters', {})
+ distributions[table] = synthesizer_parameters.get('default_distribution', None)

return distributions
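For reference, the dictionary consumed above has roughly this shape (a sketch; the 'synthesizer_name' key and the values are assumptions, only the 'synthesizer_parameters' lookup is taken from the code):

# Assumed shape of get_table_parameters() output (illustrative values).
parameters = {
    'synthesizer_name': 'GaussianCopulaSynthesizer',
    'synthesizer_parameters': {'default_distribution': 'beta'},
}
synthesizer_parameters = parameters.get('synthesizer_parameters', {})
default_distribution = synthesizer_parameters.get('default_distribution', None)  # 'beta'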

@@ -268,6 +271,13 @@ def preprocess(self, data):

return processed_data

def _set_extended_columns_distributions(self, synthesizer, table_name, valid_columns):
numerical_distributions = {}
for extended_column in self._parent_extended_columns[table_name]:
if extended_column in valid_columns:
numerical_distributions[extended_column] = DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION
synthesizer._set_numerical_distributions(numerical_distributions)
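A self-contained sketch of what the new helper does, using hypothetical extension-column names (the '__child__parent_id__num_rows' form follows the prefix convention visible in _get_extension below; the covariance name is made up):

# Standalone rendering of _set_extended_columns_distributions: every
# tracked extension column that survived preprocessing is mapped to
# 'truncnorm' before being handed to the table's synthesizer.
from collections import defaultdict

DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'
parent_extended_columns = defaultdict(list)
parent_extended_columns['parent'] = [
    '__child__parent_id__num_rows',          # naming per _get_extension
    '__child__parent_id__covariance__0__0',  # illustrative parameter name
]

def pick_distributions(table_name, valid_columns):
    return {
        column: DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION
        for column in parent_extended_columns[table_name]
        if column in valid_columns
    }

print(pick_distributions('parent', {'__child__parent_id__num_rows'}))
# {'__child__parent_id__num_rows': 'truncnorm'}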

def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc):
"""Generate the extension columns for this child table.
@@ -298,17 +308,21 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc
index = []
scale_columns = None
pbar_args = self._get_pbar_args(desc=progress_bar_desc)

for foreign_key_value in tqdm(foreign_key_values, **pbar_args):
child_rows = child_table.loc[[foreign_key_value]]
child_rows = child_rows[child_rows.columns.difference(foreign_key_columns)]

try:
if child_rows.empty:
row = pd.Series({'num_rows': len(child_rows)})
row.index = f'__{child_name}__{foreign_key}__' + row.index
else:
synthesizer = self._synthesizer(
- table_meta, **self._table_parameters[child_name]
+ table_meta,
+ **self._table_parameters[child_name],
)
self._set_extended_columns_distributions(
synthesizer, child_name, child_rows.columns
)
synthesizer.fit_processed_data(child_rows.reset_index(drop=True))
row = synthesizer._get_parameters()
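The flattening above gives each parent key a single row whose index entries are prefixed with '__<child_name>__<foreign_key>__'. A small sketch of that convention ('num_rows' is real; the other parameter name is illustrative):

# Sketch of the extension-row naming used in _get_extension.
import pandas as pd

child_name, foreign_key = 'child', 'parent_id'
row = pd.Series({'num_rows': 3, 'univariates__col_B__loc': 4.9})
row.index = f'__{child_name}__{foreign_key}__' + row.index
print(list(row.index))
# ['__child__parent_id__num_rows', '__child__parent_id__univariates__col_B__loc']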
@@ -361,8 +375,8 @@ def _augment_table(self, table, tables, table_name):
"""
self._table_sizes[table_name] = len(table)
LOGGER.info('Computing extensions for table %s', table_name)
- child_map = self.metadata._get_child_map()[table_name]
- for child_name in child_map:
+ children = self.metadata._get_child_map()[table_name]
+ for child_name in children:
if child_name not in self._augmented_tables:
child_table = self._augment_table(tables[child_name], tables, child_name)
else:
@@ -386,15 +400,17 @@ def _augment_table(self, table, tables, table_name):
enforce_min_max_values=True
)
self.extended_columns[child_name][column].fit(extension, column)

table = table.merge(extension, how='left', right_index=True, left_index=True)
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
table[num_rows_key] = table[num_rows_key].fillna(0)
self._max_child_rows[num_rows_key] = table[num_rows_key].max()
self._min_child_rows[num_rows_key] = table[num_rows_key].min()

if len(extension.columns) > 0:
self._parent_extended_columns[table_name].extend(list(extension.columns))

tables[table_name] = table
self._learned_relationships += 1

self._augmented_tables.append(table_name)
self._clear_nans(table)

@@ -436,7 +452,6 @@ def _pop_foreign_keys(self, table_data, table_name):
keys = {}
for fk in foreign_keys:
keys[fk] = table_data.pop(fk).to_numpy()

return keys

def _model_tables(self, augmented_data):
@@ -462,6 +477,9 @@
)

if not table.empty:
self._set_extended_columns_distributions(
self._table_synthesizers[table_name], table_name, table.columns
)
self._table_synthesizers[table_name].fit_processed_data(table)
table_parameters = self._table_synthesizers[table_name]._get_parameters()
self._default_parameters[table_name] = {
@@ -470,8 +488,8 @@
if 'univariates' in parameter
}

- for name, values in keys.items():
- table[name] = values
+ for fk_column_name, fk_values in keys.items():
+ table[fk_column_name] = fk_values
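The method pops foreign-key columns before fitting and restores them afterwards; a minimal sketch of that pattern with hypothetical data (mirrors _pop_foreign_keys and the loop above):

# Pop key columns, fit on what remains, then put the keys back.
import pandas as pd

table = pd.DataFrame({'parent_id': [0, 1], 'value': [1.5, 2.5]})
keys = {fk: table.pop(fk).to_numpy() for fk in ('parent_id',)}

# ... the table synthesizer is fitted on `table`, now free of keys ...

for fk_column_name, fk_values in keys.items():
    table[fk_column_name] = fk_values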

def _extract_parameters(self, parent_row, table_name, foreign_key):
"""Get the params from a generated parent row.
5 changes: 2 additions & 3 deletions sdv/sampling/hierarchical_sampler.py
@@ -191,7 +191,7 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
Args:
table_name (string):
- Name of the table to sample children for.
+ Name of the table (parent) to sample children for.
sampled_data (dict):
A dictionary mapping table names to sampled tables (pd.DataFrame).
"""
@@ -254,7 +254,7 @@ def _sample(self, scale=1.0):
"""Sample the entire dataset.
Returns a dictionary with all the tables of the dataset. The amount of rows sampled will
- depend from table to table. This is because the children tables are created modelling the
+ depend from table to table. This is because the children tables are created modeling the
relation that they have with their parent tables, so its behavior may change from one
table to another.
@@ -305,5 +305,4 @@
sampled_data[child_name], sampled_data[parent_name], child_name, parent_name
)
added_relationships.add((parent_name, child_name))

return self._finalize(sampled_data)
2 changes: 1 addition & 1 deletion sdv/sampling/independent_sampler.py
@@ -126,7 +126,7 @@ def _sample(self, scale=1.0):
"""Sample the entire dataset.
Returns a dictionary with all the tables of the dataset. The amount of rows sampled will
- depend from table to table. This is because the children tables are created modelling the
+ depend from table to table. This is because the children tables are created modeling the
relation that they have with their parent tables, so its behavior may change from one
table to another.
3 changes: 0 additions & 3 deletions sdv/single_table/base.py
@@ -634,7 +634,6 @@ def _sample_rows(
"""
if self._model and not self._random_state_set:
self._set_random_state(FIXED_RNG_SEED)

need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns
if self._model and need_sample:
if conditions is None:
@@ -644,7 +643,6 @@
raw_sampled = self._sample(num_rows, transformed_conditions)
except NotImplementedError:
raw_sampled = self._sample(num_rows)

sampled = self._data_processor.reverse_transform(raw_sampled)
if keep_extra_columns:
input_columns = self._data_processor._hyper_transformer._input_columns
@@ -655,7 +653,6 @@

if previous_rows is not None:
sampled = pd.concat([previous_rows, sampled], ignore_index=True)

sampled = self._data_processor.filter_valid(sampled)

if conditions is not None:
11 changes: 7 additions & 4 deletions sdv/single_table/copulas.py
@@ -110,15 +110,19 @@ def __init__(
locales=locales,
)
validate_numerical_distributions(numerical_distributions, self.metadata.columns)
- self.numerical_distributions = numerical_distributions or {}
- self.default_distribution = default_distribution or 'beta'
-
+ self.default_distribution = default_distribution or 'beta'
self._default_distribution = self.get_distribution_class(self.default_distribution)

+ self._set_numerical_distributions(numerical_distributions)
+ self._num_rows = None
+
+ def _set_numerical_distributions(self, numerical_distributions):
+ self.numerical_distributions = numerical_distributions or {}
self._numerical_distributions = {
field: self.get_distribution_class(distribution)
for field, distribution in self.numerical_distributions.items()
}
- self._num_rows = None

def _fit(self, processed_data):
"""Fit the model to the table.
@@ -138,7 +142,6 @@ def _fit(self, processed_data):
numerical_distributions[column] = self._numerical_distributions.get(
column, self._default_distribution
)

self._model = multivariate.GaussianMultivariate(distribution=numerical_distributions)

with warnings.catch_warnings():
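For context, a usage sketch of the public per-column override that the extracted _set_numerical_distributions() now also serves; the metadata and column names here are hypothetical, not from this commit:

# Hypothetical single-table usage of the same per-column override that
# HMA now applies to extension columns after __init__.
from sdv.metadata import SingleTableMetadata
from sdv.single_table import GaussianCopulaSynthesizer

metadata = SingleTableMetadata.load_from_dict({
    'columns': {'amount': {'sdtype': 'numerical'}},
})
synthesizer = GaussianCopulaSynthesizer(
    metadata,
    default_distribution='beta',                      # fallback for unlisted columns
    numerical_distributions={'amount': 'truncnorm'},  # per-column override
)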
125 changes: 125 additions & 0 deletions tests/integration/multi_table/test_hma.py
@@ -1,6 +1,7 @@
import datetime
import importlib.metadata
import logging
import math
import re
import warnings

@@ -1382,6 +1383,130 @@ def test_null_foreign_keys(self):
synthesizer.fit(data)


@pytest.mark.parametrize('num_rows', [(10), (1000)])
def test_hma_0_1_child(num_rows):
parent_table = pd.DataFrame(
data={
'id': list(range(num_rows)),
'col_A': list(np.random.choice(['A', 'B', 'C', 'D', 'E'], size=num_rows)),
}
)
child_table_data = {'parent_id': [], 'col_B': [], 'col_C': []}

for i in range(num_rows):
num_children = np.random.choice([0, 1, 10, 15], p=[0.4, 0.5, 0.05, 0.05])
if num_children == 0:
continue
child_table_data['parent_id'].extend([i] * num_children)
child_table_data['col_B'].extend([
round(i, 2) for i in np.random.uniform(low=0, high=10, size=num_children)
])
child_table_data['col_C'].extend(
list(np.random.choice(['A', 'B', 'C', 'D', 'E'], size=num_children))
)

data = {'parent': parent_table, 'child': pd.DataFrame(data=child_table_data)}
metadata = MultiTableMetadata.load_from_dict({
'tables': {
'parent': {
'primary_key': 'id',
'columns': {'id': {'sdtype': 'id'}, 'col_A': {'sdtype': 'categorical'}},
},
'child': {
'columns': {
'parent_id': {'sdtype': 'id'},
'col_B': {'sdtype': 'numerical'},
'col_C': {'sdtype': 'categorical'},
}
},
},
'relationships': [
{
'parent_table_name': 'parent',
'child_table_name': 'child',
'parent_primary_key': 'id',
'child_foreign_key': 'parent_id',
}
],
})
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
synthesizer.fit(data)
synthetic_data = synthesizer.sample(scale=1)
synthetic_child_df = synthetic_data['child']
data_col_max = synthetic_child_df['col_B'].max()
expected_constant_length = math.floor(len(synthetic_child_df) * 0.70)
actual_constants = synthetic_child_df[synthetic_child_df['col_B'] == data_col_max]
assert len(actual_constants) <= expected_constant_length
assert synthetic_child_df['col_B'].max() <= data['child']['col_B'].max()
assert synthetic_child_df['col_B'].min() >= data['child']['col_B'].min()


def test_hma_0_1_grandparent():
grandparent = pd.DataFrame({'grandparent_id': [50, 51, 52]})
parent = pd.DataFrame({
'parent_id': [0, 1, 2, 3],
'data': [1.5, 2.5, 5.9, 10.6],
'grandparent_id': [50, 50, 50, 52],
})
child = pd.DataFrame({
'child_id': [10, 11, 12],
'parent_id': [0, 1, 2],
'data': [1.8, 0.7, 2.5],
})
data = {'parent': parent, 'child': child, 'grandparent': grandparent}
metadata_dict = {
'tables': {
'grandparent': {
'primary_key': 'grandparent_id',
'columns': {
'grandparent_id': {'sdtype': 'id'},
},
},
'parent': {
'primary_key': 'parent_id',
'columns': {
'parent_id': {'sdtype': 'id'},
'data': {'sdtype': 'numerical'},
'grandparent_id': {'sdtype': 'id'},
},
},
'child': {
'primary_key': 'child_id',
'columns': {
'child_id': {'sdtype': 'id'},
'parent_id': {'sdtype': 'id'},
'data': {'sdtype': 'numerical'},
},
},
},
'relationships': [
{
'parent_table_name': 'grandparent',
'parent_primary_key': 'grandparent_id',
'child_table_name': 'parent',
'child_foreign_key': 'grandparent_id',
},
{
'parent_table_name': 'parent',
'parent_primary_key': 'parent_id',
'child_table_name': 'child',
'child_foreign_key': 'parent_id',
},
],
}
metadata = MultiTableMetadata().load_from_dict(metadata_dict)
metadata.validate()
metadata.validate_data(data)
synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
synthesizer.fit(data)
synthetic_data = synthesizer.sample()
child_df = synthetic_data['child']
data_col_max = child_df['data'].max()
data_col_min = child_df['data'].min()
assert child_df[child_df['data'] == data_col_max].shape[0] == 2
assert child_df[child_df['data'] == data_col_min].shape[0] == 1


parametrization = [
('update_column', {'table_name': 'departure', 'column_name': 'city', 'sdtype': 'categorical'}),
('set_primary_key', {'table_name': 'arrival', 'column_name': 'id_flight'}),
