From e06cecac462712d46b0215bafc8d03c4bc40a13e Mon Sep 17 00:00:00 2001 From: Gaurav Sheni Date: Thu, 20 Jun 2024 10:01:47 -0400 Subject: [PATCH] Set numerical distributions to truncnorm for extension columns (#2068) --- sdv/multi_table/hma.py | 40 +++++-- sdv/sampling/hierarchical_sampler.py | 5 +- sdv/sampling/independent_sampler.py | 2 +- sdv/single_table/base.py | 3 - sdv/single_table/copulas.py | 11 +- tests/integration/multi_table/test_hma.py | 125 ++++++++++++++++++++++ 6 files changed, 164 insertions(+), 22 deletions(-) diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index d7a4712db..722b82ffa 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -1,6 +1,7 @@ """Hierarchical Modeling Algorithms.""" import logging +from collections import defaultdict from copy import deepcopy import numpy as np @@ -15,6 +16,7 @@ LOGGER = logging.getLogger(__name__) MAX_NUMBER_OF_COLUMNS = 1000 +DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm' class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer): @@ -159,6 +161,7 @@ def __init__(self, metadata, locales=['en_US'], verbose=True): self._augmented_tables = [] self._learned_relationships = 0 self._default_parameters = {} + self._parent_extended_columns = defaultdict(list) self.verbose = verbose BaseHierarchicalSampler.__init__( self, self.metadata, self._table_synthesizers, self._table_sizes @@ -215,8 +218,8 @@ def _get_distributions(self): distributions = {} for table in self.metadata.tables: parameters = self.get_table_parameters(table) - sythesizer_parameter = parameters.get('synthesizer_parameters', {}) - distributions[table] = sythesizer_parameter.get('default_distribution', None) + synthesizer_parameters = parameters.get('synthesizer_parameters', {}) + distributions[table] = synthesizer_parameters.get('default_distribution', None) return distributions @@ -268,6 +271,13 @@ def preprocess(self, data): return processed_data + def _set_extended_columns_distributions(self, synthesizer, table_name, valid_columns): + numerical_distributions = {} + for extended_column in self._parent_extended_columns[table_name]: + if extended_column in valid_columns: + numerical_distributions[extended_column] = DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION + synthesizer._set_numerical_distributions(numerical_distributions) + def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc): """Generate the extension columns for this child table. 
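For context: the extension columns are the flattened copula parameters (num_rows plus the fitted univariate parameters, stored under names like __child__fk__...) that HMA attaches to each parent row, and this patch pins all of them to truncnorm so their sampled values stay within the range observed during fitting. Below is a minimal standalone sketch of what _set_extended_columns_distributions does, written against the _set_numerical_distributions hook this patch adds to GaussianCopulaSynthesizer; the module-level function name is a stand-in for the method, not part of the patch:

    DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION = 'truncnorm'

    def set_extended_columns_distributions(synthesizer, extended_columns, valid_columns):
        # Pin every extension column that survived preprocessing to truncnorm;
        # all remaining columns keep the synthesizer's default distribution.
        numerical_distributions = {
            column: DEFAULT_EXTENDED_COLUMNS_DISTRIBUTION
            for column in extended_columns
            if column in valid_columns
        }
        synthesizer._set_numerical_distributions(numerical_distributions)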
@@ -298,17 +308,21 @@ def _get_extension(self, child_name, child_table, foreign_key, progress_bar_desc index = [] scale_columns = None pbar_args = self._get_pbar_args(desc=progress_bar_desc) + for foreign_key_value in tqdm(foreign_key_values, **pbar_args): child_rows = child_table.loc[[foreign_key_value]] child_rows = child_rows[child_rows.columns.difference(foreign_key_columns)] - try: if child_rows.empty: row = pd.Series({'num_rows': len(child_rows)}) row.index = f'__{child_name}__{foreign_key}__' + row.index else: synthesizer = self._synthesizer( - table_meta, **self._table_parameters[child_name] + table_meta, + **self._table_parameters[child_name], + ) + self._set_extended_columns_distributions( + synthesizer, child_name, child_rows.columns ) synthesizer.fit_processed_data(child_rows.reset_index(drop=True)) row = synthesizer._get_parameters() @@ -361,8 +375,8 @@ def _augment_table(self, table, tables, table_name): """ self._table_sizes[table_name] = len(table) LOGGER.info('Computing extensions for table %s', table_name) - child_map = self.metadata._get_child_map()[table_name] - for child_name in child_map: + children = self.metadata._get_child_map()[table_name] + for child_name in children: if child_name not in self._augmented_tables: child_table = self._augment_table(tables[child_name], tables, child_name) else: @@ -386,15 +400,17 @@ def _augment_table(self, table, tables, table_name): enforce_min_max_values=True ) self.extended_columns[child_name][column].fit(extension, column) - table = table.merge(extension, how='left', right_index=True, left_index=True) num_rows_key = f'__{child_name}__{foreign_key}__num_rows' table[num_rows_key] = table[num_rows_key].fillna(0) self._max_child_rows[num_rows_key] = table[num_rows_key].max() self._min_child_rows[num_rows_key] = table[num_rows_key].min() + + if len(extension.columns) > 0: + self._parent_extended_columns[table_name].extend(list(extension.columns)) + tables[table_name] = table self._learned_relationships += 1 - self._augmented_tables.append(table_name) self._clear_nans(table) @@ -436,7 +452,6 @@ def _pop_foreign_keys(self, table_data, table_name): keys = {} for fk in foreign_keys: keys[fk] = table_data.pop(fk).to_numpy() - return keys def _model_tables(self, augmented_data): @@ -462,6 +477,9 @@ def _model_tables(self, augmented_data): ) if not table.empty: + self._set_extended_columns_distributions( + self._table_synthesizers[table_name], table_name, table.columns + ) self._table_synthesizers[table_name].fit_processed_data(table) table_parameters = self._table_synthesizers[table_name]._get_parameters() self._default_parameters[table_name] = { @@ -470,8 +488,8 @@ def _model_tables(self, augmented_data): if 'univariates' in parameter } - for name, values in keys.items(): - table[name] = values + for fk_column_name, fk_values in keys.items(): + table[fk_column_name] = fk_values def _extract_parameters(self, parent_row, table_name, foreign_key): """Get the params from a generated parent row. diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 7ba175904..9662bcb99 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -191,7 +191,7 @@ def _sample_children(self, table_name, sampled_data, scale=1.0): Args: table_name (string): - Name of the table to sample children for. + Name of the table (parent) to sample children for. sampled_data (dict): A dictionary mapping table names to sampled tables (pd.DataFrame). 
""" @@ -254,7 +254,7 @@ def _sample(self, scale=1.0): """Sample the entire dataset. Returns a dictionary with all the tables of the dataset. The amount of rows sampled will - depend from table to table. This is because the children tables are created modelling the + depend from table to table. This is because the children tables are created modeling the relation that they have with their parent tables, so its behavior may change from one table to another. @@ -305,5 +305,4 @@ def _sample(self, scale=1.0): sampled_data[child_name], sampled_data[parent_name], child_name, parent_name ) added_relationships.add((parent_name, child_name)) - return self._finalize(sampled_data) diff --git a/sdv/sampling/independent_sampler.py b/sdv/sampling/independent_sampler.py index 22d0359b4..46bdc6d41 100644 --- a/sdv/sampling/independent_sampler.py +++ b/sdv/sampling/independent_sampler.py @@ -126,7 +126,7 @@ def _sample(self, scale=1.0): """Sample the entire dataset. Returns a dictionary with all the tables of the dataset. The amount of rows sampled will - depend from table to table. This is because the children tables are created modelling the + depend from table to table. This is because the children tables are created modeling the relation that they have with their parent tables, so its behavior may change from one table to another. diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 135e37f3b..bcc075ea5 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -634,7 +634,6 @@ def _sample_rows( """ if self._model and not self._random_state_set: self._set_random_state(FIXED_RNG_SEED) - need_sample = self._data_processor.get_sdtypes(primary_keys=False) or keep_extra_columns if self._model and need_sample: if conditions is None: @@ -644,7 +643,6 @@ def _sample_rows( raw_sampled = self._sample(num_rows, transformed_conditions) except NotImplementedError: raw_sampled = self._sample(num_rows) - sampled = self._data_processor.reverse_transform(raw_sampled) if keep_extra_columns: input_columns = self._data_processor._hyper_transformer._input_columns @@ -655,7 +653,6 @@ def _sample_rows( if previous_rows is not None: sampled = pd.concat([previous_rows, sampled], ignore_index=True) - sampled = self._data_processor.filter_valid(sampled) if conditions is not None: diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index 43aca493a..26167a946 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -110,15 +110,19 @@ def __init__( locales=locales, ) validate_numerical_distributions(numerical_distributions, self.metadata.columns) - self.numerical_distributions = numerical_distributions or {} - self.default_distribution = default_distribution or 'beta' + self.default_distribution = default_distribution or 'beta' self._default_distribution = self.get_distribution_class(self.default_distribution) + + self._set_numerical_distributions(numerical_distributions) + self._num_rows = None + + def _set_numerical_distributions(self, numerical_distributions): + self.numerical_distributions = numerical_distributions or {} self._numerical_distributions = { field: self.get_distribution_class(distribution) for field, distribution in self.numerical_distributions.items() } - self._num_rows = None def _fit(self, processed_data): """Fit the model to the table. 
@@ -138,7 +142,6 @@ def _fit(self, processed_data):
             numerical_distributions[column] = self._numerical_distributions.get(
                 column, self._default_distribution
             )
-
         self._model = multivariate.GaussianMultivariate(distribution=numerical_distributions)
 
         with warnings.catch_warnings():
diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py
index a82049200..8b337cf68 100644
--- a/tests/integration/multi_table/test_hma.py
+++ b/tests/integration/multi_table/test_hma.py
@@ -1,6 +1,7 @@
 import datetime
 import importlib.metadata
 import logging
+import math
 import re
 import warnings
 
@@ -1382,6 +1383,130 @@ def test_null_foreign_keys(self):
             synthesizer.fit(data)
 
 
+@pytest.mark.parametrize('num_rows', [10, 1000])
+def test_hma_0_1_child(num_rows):
+    parent_table = pd.DataFrame(
+        data={
+            'id': list(range(num_rows)),
+            'col_A': list(np.random.choice(['A', 'B', 'C', 'D', 'E'], size=num_rows)),
+        }
+    )
+    child_table_data = {'parent_id': [], 'col_B': [], 'col_C': []}
+
+    for i in range(num_rows):
+        num_children = np.random.choice([0, 1, 10, 15], p=[0.4, 0.5, 0.05, 0.05])
+        if num_children == 0:
+            continue
+        child_table_data['parent_id'].extend([i] * num_children)
+        child_table_data['col_B'].extend([
+            round(value, 2) for value in np.random.uniform(low=0, high=10, size=num_children)
+        ])
+        child_table_data['col_C'].extend(
+            list(np.random.choice(['A', 'B', 'C', 'D', 'E'], size=num_children))
+        )
+
+    data = {'parent': parent_table, 'child': pd.DataFrame(data=child_table_data)}
+    metadata = MultiTableMetadata.load_from_dict({
+        'tables': {
+            'parent': {
+                'primary_key': 'id',
+                'columns': {'id': {'sdtype': 'id'}, 'col_A': {'sdtype': 'categorical'}},
+            },
+            'child': {
+                'columns': {
+                    'parent_id': {'sdtype': 'id'},
+                    'col_B': {'sdtype': 'numerical'},
+                    'col_C': {'sdtype': 'categorical'},
+                }
+            },
+        },
+        'relationships': [
+            {
+                'parent_table_name': 'parent',
+                'child_table_name': 'child',
+                'parent_primary_key': 'id',
+                'child_foreign_key': 'parent_id',
+            }
+        ],
+    })
+    synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
+    synthesizer.fit(data)
+    synthetic_data = synthesizer.sample(scale=1)
+    synthetic_child_df = synthetic_data['child']
+    real_child_df = data['child']
+    data_col_max = synthetic_child_df['col_B'].max()
+    # At most 70% of the sampled rows may pile up on the upper bound.
+    expected_constant_length = math.floor(len(synthetic_child_df) * 0.70)
+    actual_constants = synthetic_child_df[synthetic_child_df['col_B'] == data_col_max]
+    assert len(actual_constants) <= expected_constant_length
+    # truncnorm keeps sampled values inside the bounds seen in the real data.
+    assert synthetic_child_df['col_B'].max() <= real_child_df['col_B'].max()
+    assert synthetic_child_df['col_B'].min() >= real_child_df['col_B'].min()
+
+
+def test_hma_0_1_grandparent():
+    grandparent = pd.DataFrame({'grandparent_id': [50, 51, 52]})
+    parent = pd.DataFrame({
+        'parent_id': [0, 1, 2, 3],
+        'data': [1.5, 2.5, 5.9, 10.6],
+        'grandparent_id': [50, 50, 50, 52],
+    })
+    child = pd.DataFrame({
+        'child_id': [10, 11, 12],
+        'parent_id': [0, 1, 2],
+        'data': [1.8, 0.7, 2.5],
+    })
+    data = {'parent': parent, 'child': child, 'grandparent': grandparent}
+    metadata_dict = {
+        'tables': {
+            'grandparent': {
+                'primary_key': 'grandparent_id',
+                'columns': {
+                    'grandparent_id': {'sdtype': 'id'},
+                },
+            },
+            'parent': {
+                'primary_key': 'parent_id',
+                'columns': {
+                    'parent_id': {'sdtype': 'id'},
+                    'data': {'sdtype': 'numerical'},
+                    'grandparent_id': {'sdtype': 'id'},
+                },
+            },
+            'child': {
+                'primary_key': 'child_id',
+                'columns': {
+                    'child_id': {'sdtype': 'id'},
+                    'parent_id': {'sdtype': 'id'},
+                    'data': {'sdtype': 'numerical'},
+                },
+            },
+        },
+        'relationships': [
+            {
+                'parent_table_name': 'grandparent',
+                'parent_primary_key': 'grandparent_id',
+                'child_table_name': 'parent',
+                'child_foreign_key': 'grandparent_id',
+            },
+            {
+                'parent_table_name': 'parent',
+                'parent_primary_key': 'parent_id',
+                'child_table_name': 'child',
+                'child_foreign_key': 'parent_id',
+            },
+        ],
+    }
+    metadata = MultiTableMetadata.load_from_dict(metadata_dict)
+    metadata.validate()
+    metadata.validate_data(data)
+    synthesizer = HMASynthesizer(metadata=metadata, verbose=False)
+    synthesizer.fit(data)
+    synthetic_data = synthesizer.sample()
+    child_df = synthetic_data['child']
+    data_col_max = child_df['data'].max()
+    data_col_min = child_df['data'].min()
+    assert child_df[child_df['data'] == data_col_max].shape[0] == 2
+    assert child_df[child_df['data'] == data_col_min].shape[0] == 1
+
+
 parametrization = [
     ('update_column', {'table_name': 'departure', 'column_name': 'city', 'sdtype': 'categorical'}),
     ('set_primary_key', {'table_name': 'arrival', 'column_name': 'id_flight'}),