Skip to content

Commit

Permalink
Replace MultiTableMetadata with Metadata (#2146)
Browse files Browse the repository at this point in the history
Co-authored-by: gsheni <[email protected]>
  • Loading branch information
lajohn4747 and gsheni authored Aug 1, 2024
1 parent a3ed4c8 commit 034802b
Show file tree
Hide file tree
Showing 16 changed files with 179 additions and 62 deletions.
8 changes: 4 additions & 4 deletions sdv/io/local/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pandas as pd

from sdv.metadata import MultiTableMetadata
from sdv.metadata import Metadata


class BaseLocalHandler:
Expand All @@ -25,11 +25,11 @@ def create_metadata(self, data):
Dictionary of table names to dataframes.
Returns:
MultiTableMetadata:
An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata
Metadata:
An ``sdv.metadata.Metadata`` object with the detected metadata
properties from the data.
"""
metadata = MultiTableMetadata()
metadata = Metadata()
metadata.detect_from_dataframes(data)
return metadata

Expand Down
4 changes: 4 additions & 0 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
LOGGER = logging.getLogger(__name__)
MULTITABLEMETADATA_LOGGER = get_sdv_logger('MultiTableMetadata')
WARNINGS_COLUMN_ORDER = ['Table Name', 'Column Name', 'sdtype', 'datetime_format']
DEPRECATION_MSG = (
"The 'MultiTableMetadata' is deprecated. Please use the new "
"'Metadata' class for synthesizers."
)


class MultiTableMetadata:
Expand Down
2 changes: 1 addition & 1 deletion sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ def upgrade_metadata(cls, filepath):
if len(tables) > 1:
raise InvalidMetadataError(
'There are multiple tables specified in the JSON. '
'Try using the MultiTableMetadata class to upgrade this file.'
'Try using the Metadata class to upgrade this file.'
)

else:
Expand Down
13 changes: 9 additions & 4 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
SynthesizerInputError,
)
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import DEPRECATION_MSG, MultiTableMetadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer

SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
Expand All @@ -38,9 +40,8 @@ class BaseMultiTableSynthesizer:
multi table synthesizers need to implement, as well as common functionality.
Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
Multi table metadata representing the data tables that this synthesizer will be used
for.
metadata (sdv.metadata.Metadata):
Metadata representing the data tables that this synthesizer will utilize.
locales (list or str):
The default locale(s) to use for AnonymizedFaker transformers.
Defaults to ``['en_US']``.
Expand Down Expand Up @@ -71,8 +72,9 @@ def _initialize_models(self):
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
metadata = Metadata.load_from_dict(table_metadata.to_dict())
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata, locales=self.locales, **synthesizer_parameters
metadata=metadata, locales=self.locales, **synthesizer_parameters
)
self._table_synthesizers[table_name]._data_processor.table_name = table_name

Expand All @@ -97,6 +99,9 @@ def _check_metadata_updated(self):

def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self.metadata = metadata
if type(metadata) is MultiTableMetadata:
self.metadata = Metadata().load_from_dict(metadata.to_dict())
warnings.warn(DEPRECATION_MSG, FutureWarning)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message=r'.*column relationship.*')
self.metadata.validate()
Expand Down
4 changes: 2 additions & 2 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
"""Hierarchical Modeling Algorithm One.
Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
metadata (sdv.metadata.Metadata):
Multi table metadata representing the data tables that this synthesizer will be used
for.
locales (list or str):
Expand All @@ -47,7 +47,7 @@ def _get_num_data_columns(metadata):
"""Get the number of data columns, ie colums that are not id, for each table.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
"""
columns_per_table = {}
Expand Down
26 changes: 13 additions & 13 deletions sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _simplify_relationships_and_tables(metadata, tables_to_drop):
Removes the tables that are not direct child or grandchild of the root table.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
tables_to_drop (set):
Set of the tables that relationships will be removed.
Expand All @@ -149,7 +149,7 @@ def _simplify_grandchildren(metadata, grandchildren):
- Drop all modelables columns.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
grandchildren (set):
Set of the grandchildren of the root table.
Expand All @@ -174,7 +174,7 @@ def _get_num_column_to_drop(metadata, child_table, max_col_per_relationships):
- minimum number of column to drop = n + k - sqrt(k^2 + 1 + 2m)
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
child_table (str):
Name of the child table.
Expand Down Expand Up @@ -232,7 +232,7 @@ def _simplify_child(metadata, child_table, max_col_per_relationships):
"""Simplify the child table.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
child_table (str):
Name of the child table.
Expand All @@ -252,7 +252,7 @@ def _simplify_children(metadata, children, root_table, num_data_column):
- Drop some modelable columns to have at most MAX_NUMBER_OF_COLUMNS columns to model.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
children (set):
Set of the children of the root table.
Expand Down Expand Up @@ -288,11 +288,11 @@ def _simplify_metadata(metadata):
- Drop some modelable columns in the children to have at most 1000 columns to model.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
Returns:
MultiTableMetadata:
Metadata:
Simplified metadata.
"""
simplified_metadata = deepcopy(metadata)
Expand Down Expand Up @@ -330,7 +330,7 @@ def _simplify_data(data, metadata):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
Returns:
Expand Down Expand Up @@ -375,7 +375,7 @@ def _get_rows_to_drop(data, metadata):
This ensures that we preserve the referential integrity between all the relationships.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
data (dict):
Dictionary that maps each table name (string) to the data for that
Expand Down Expand Up @@ -470,7 +470,7 @@ def _subsample_table_and_descendants(data, metadata, table, num_rows):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
table (str):
Name of the table.
Expand All @@ -496,7 +496,7 @@ def _get_primary_keys_referenced(data, metadata):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
Returns:
Expand Down Expand Up @@ -568,7 +568,7 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
table (str):
Name of the table.
Expand Down Expand Up @@ -604,7 +604,7 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
referenced by the descendants and some unreferenced rows.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
data (dict):
Dictionary that maps each table name (string) to the data for that
Expand Down
2 changes: 1 addition & 1 deletion sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class BaseHierarchicalSampler:
"""Hierarchical sampler mixin.
Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
metadata (sdv.metadata.Metadata):
Multi-table metadata representing the data tables that this sampler will be used for.
table_synthesizers (dict):
Dictionary mapping each table to a synthesizer. Should be instantiated and passed to
Expand Down
2 changes: 1 addition & 1 deletion sdv/sampling/independent_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class BaseIndependentSampler:
"""Independent sampler mixin.
Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
metadata (sdv.metadata.Metadata):
Multi-table metadata representing the data tables that this sampler will be used for.
table_synthesizers (dict):
Dictionary mapping each table to a synthesizer. Should be instantiated and passed to
Expand Down
6 changes: 3 additions & 3 deletions sdv/utils/poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def simplify_schema(data, metadata, verbose=True):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
verbose (bool):
If True, print information about the simplification process.
Expand All @@ -50,7 +50,7 @@ def simplify_schema(data, metadata, verbose=True):
tuple:
dict:
Dictionary with the simplified dataframes.
MultiTableMetadata:
Metadata:
Simplified metadata.
"""
try:
Expand Down Expand Up @@ -93,7 +93,7 @@ def get_random_subset(data, metadata, main_table_name, num_rows, verbose=True):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
main_table_name (str):
Name of the main table.
Expand Down
2 changes: 1 addition & 1 deletion sdv/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
drop_missing_values (bool):
Boolean describing whether or not to also drop foreign keys with missing values
Expand Down
Loading

0 comments on commit 034802b

Please sign in to comment.