Skip to content

Commit

Permalink
Replace MultiTableMetadata with Metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Jul 25, 2024
1 parent d896bec commit bbd7667
Show file tree
Hide file tree
Showing 15 changed files with 119 additions and 54 deletions.
8 changes: 4 additions & 4 deletions sdv/io/local/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pandas as pd

from sdv.metadata import MultiTableMetadata
from sdv.metadata import Metadata


class BaseLocalHandler:
Expand All @@ -25,11 +25,11 @@ def create_metadata(self, data):
Dictionary of table names to dataframes.
Returns:
MultiTableMetadata:
An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata
Metadata:
An ``sdv.metadata.Metadata`` object with the detected metadata
properties from the data.
"""
metadata = MultiTableMetadata()
metadata = Metadata()
metadata.detect_from_dataframes(data)
return metadata

Expand Down
4 changes: 4 additions & 0 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
LOGGER = logging.getLogger(__name__)
MULTITABLEMETADATA_LOGGER = get_sdv_logger('MultiTableMetadata')
WARNINGS_COLUMN_ORDER = ['Table Name', 'Column Name', 'sdtype', 'datetime_format']
# Deprecation warning emitted (as a FutureWarning) when a synthesizer is
# constructed with the legacy ``MultiTableMetadata`` class instead of the
# unified ``Metadata`` class that replaces it.
DEPRECATION_MSG = (
    "The 'MultiTableMetadata' is deprecated. Please use the new "
    "'Metadata' class for synthesizers."
)


class MultiTableMetadata:
Expand Down
2 changes: 1 addition & 1 deletion sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,7 @@ def upgrade_metadata(cls, filepath):
if len(tables) > 1:
raise InvalidMetadataError(
'There are multiple tables specified in the JSON. '
'Try using the MultiTableMetadata class to upgrade this file.'
'Try using the Metadata class to upgrade this file.'
)

else:
Expand Down
12 changes: 9 additions & 3 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
SynthesizerInputError,
)
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import DEPRECATION_MSG, MultiTableMetadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer

SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
Expand All @@ -38,8 +40,8 @@ class BaseMultiTableSynthesizer:
multi table synthesizers need to implement, as well as common functionality.
Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
Multi table metadata representing the data tables that this synthesizer will be used
metadata (sdv.metadata.Metadata):
Table metadata representing the data tables that this synthesizer will be used
for.
locales (list or str):
The default locale(s) to use for AnonymizedFaker transformers.
Expand Down Expand Up @@ -71,8 +73,9 @@ def _initialize_models(self):
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
metadata = Metadata.load_from_dict(table_metadata.to_dict())
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata, locales=self.locales, **synthesizer_parameters
metadata=metadata, locales=self.locales, **synthesizer_parameters
)
self._table_synthesizers[table_name]._data_processor.table_name = table_name

Expand All @@ -97,6 +100,9 @@ def _check_metadata_updated(self):

def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self.metadata = metadata
if type(metadata) is MultiTableMetadata:
self.metadata = Metadata().load_from_dict(metadata.to_dict())
warnings.warn(DEPRECATION_MSG, FutureWarning)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message=r'.*column relationship.*')
self.metadata.validate()
Expand Down
4 changes: 2 additions & 2 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
"""Hierarchical Modeling Algorithm One.
Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
metadata (sdv.metadata.Metadata):
Multi table metadata representing the data tables that this synthesizer will be used
for.
locales (list or str):
Expand All @@ -47,7 +47,7 @@ def _get_num_data_columns(metadata):
"""Get the number of data columns, ie colums that are not id, for each table.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
"""
columns_per_table = {}
Expand Down
26 changes: 13 additions & 13 deletions sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _simplify_relationships_and_tables(metadata, tables_to_drop):
Removes the tables that are not direct child or grandchild of the root table.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
tables_to_drop (set):
Set of the tables that relationships will be removed.
Expand All @@ -149,7 +149,7 @@ def _simplify_grandchildren(metadata, grandchildren):
- Drop all modelables columns.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
grandchildren (set):
Set of the grandchildren of the root table.
Expand All @@ -174,7 +174,7 @@ def _get_num_column_to_drop(metadata, child_table, max_col_per_relationships):
- minimum number of column to drop = n + k - sqrt(k^2 + 1 + 2m)
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
child_table (str):
Name of the child table.
Expand Down Expand Up @@ -232,7 +232,7 @@ def _simplify_child(metadata, child_table, max_col_per_relationships):
"""Simplify the child table.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
child_table (str):
Name of the child table.
Expand All @@ -252,7 +252,7 @@ def _simplify_children(metadata, children, root_table, num_data_column):
- Drop some modelable columns to have at most MAX_NUMBER_OF_COLUMNS columns to model.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
children (set):
Set of the children of the root table.
Expand Down Expand Up @@ -288,11 +288,11 @@ def _simplify_metadata(metadata):
- Drop some modelable columns in the children to have at most 1000 columns to model.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
Returns:
MultiTableMetadata:
Metadata:
Simplified metadata.
"""
simplified_metadata = deepcopy(metadata)
Expand Down Expand Up @@ -330,7 +330,7 @@ def _simplify_data(data, metadata):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
Returns:
Expand Down Expand Up @@ -375,7 +375,7 @@ def _get_rows_to_drop(data, metadata):
This ensures that we preserve the referential integrity between all the relationships.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
data (dict):
Dictionary that maps each table name (string) to the data for that
Expand Down Expand Up @@ -470,7 +470,7 @@ def _subsample_table_and_descendants(data, metadata, table, num_rows):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
table (str):
Name of the table.
Expand All @@ -496,7 +496,7 @@ def _get_primary_keys_referenced(data, metadata):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
Returns:
Expand Down Expand Up @@ -568,7 +568,7 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
table (str):
Name of the table.
Expand Down Expand Up @@ -604,7 +604,7 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
referenced by the descendants and some unreferenced rows.
Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
data (dict):
Dictionary that maps each table name (string) to the data for that
Expand Down
2 changes: 1 addition & 1 deletion sdv/sampling/independent_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class BaseIndependentSampler:
"""Independent sampler mixin.
Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
metadata (sdv.metadata.Metadata):
Multi-table metadata representing the data tables that this sampler will be used for.
table_synthesizers (dict):
Dictionary mapping each table to a synthesizer. Should be instantiated and passed to
Expand Down
6 changes: 3 additions & 3 deletions sdv/utils/poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def simplify_schema(data, metadata, verbose=True):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
verbose (bool):
If True, print information about the simplification process.
Expand All @@ -50,7 +50,7 @@ def simplify_schema(data, metadata, verbose=True):
tuple:
dict:
Dictionary with the simplified dataframes.
MultiTableMetadata:
Metadata:
Simplified metadata.
"""
try:
Expand Down Expand Up @@ -93,7 +93,7 @@ def get_random_subset(data, metadata, main_table_name, num_rows, verbose=True):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
main_table_name (str):
Name of the main table.
Expand Down
2 changes: 1 addition & 1 deletion sdv/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
drop_missing_values (bool):
Boolean describing whether or not to also drop foreign keys with missing values
Expand Down
41 changes: 41 additions & 0 deletions tests/integration/metadata/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import re

import pytest

from sdv.datasets.demo import download_demo
from sdv.metadata.metadata import DEFAULT_TABLE_NAME, Metadata
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.multi_table.hma import HMASynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer


Expand Down Expand Up @@ -255,3 +258,41 @@ def test_single_table_compatibility(tmp_path):
metadata_sample = synthesizer.sample(10)
assert loaded_synthesizer.metadata.to_dict() == synthesizer_2.metadata.to_dict()
assert metadata_sample.columns.to_list() == loaded_sample.columns.to_list()


def test_multi_table_compatibility(tmp_path):
    """Test that ``MultiTableMetadata`` remains compatible with multi-table synthesizers.

    Passing the deprecated ``MultiTableMetadata`` to ``HMASynthesizer`` should raise a
    ``FutureWarning``, the synthesizer should convert it to the unified ``Metadata``
    internally, and save/load round-trips as well as sampling should keep working.
    """
    # Setup
    data, metadata = download_demo('multi_table', 'fake_hotels')
    warn_msg = re.escape(
        "The 'MultiTableMetadata' is deprecated. Please use the new "
        "'Metadata' class for synthesizers."
    )

    # Run
    with pytest.warns(FutureWarning, match=warn_msg):
        synthesizer = HMASynthesizer(metadata)
    synthesizer.fit(data)
    model_path = tmp_path / 'synthesizer.pkl'
    synthesizer.save(model_path)

    # Assert
    assert model_path.exists()
    assert model_path.is_file()
    loaded_synthesizer = HMASynthesizer.load(model_path)
    assert isinstance(synthesizer, HMASynthesizer)
    assert loaded_synthesizer.get_info() == synthesizer.get_info()
    # The deprecated metadata must have been converted to the unified class.
    assert isinstance(loaded_synthesizer.metadata, Metadata)
    expected_metadata = metadata.to_dict()
    expected_metadata['METADATA_SPEC_VERSION'] = 'V1'
    assert loaded_synthesizer.metadata.to_dict() == expected_metadata
    loaded_sample = loaded_synthesizer.sample(10)
    synthesizer.validate(loaded_sample)

    # Run against Metadata
    synthesizer_2 = HMASynthesizer(Metadata._convert_to_unified_metadata(metadata))
    synthesizer_2.fit(data)
    # Sample from the Metadata-based synthesizer (not the deprecated-metadata one)
    # so this section actually exercises the new class end to end.
    metadata_sample = synthesizer_2.sample(10)
    assert loaded_synthesizer.metadata.to_dict() == synthesizer_2.metadata.to_dict()
    for table in metadata_sample:
        assert metadata_sample[table].columns.to_list() == loaded_sample[table].columns.to_list()
28 changes: 10 additions & 18 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sdv.datasets.local import load_csvs
from sdv.errors import SamplingError, SynthesizerInputError, VersionError
from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.multi_table import HMASynthesizer
from tests.integration.single_table.custom_constraints import MyConstraint
Expand Down Expand Up @@ -393,6 +394,7 @@ def test_save_and_load(self, tmp_path):
"""Test saving and loading a multi-table synthesizer."""
# Setup
_, _, metadata = self.get_custom_constraint_data_and_metadata()
metadata = Metadata.load_from_dict(metadata.to_dict())
synthesizer = HMASynthesizer(metadata)
model_path = tmp_path / 'synthesizer.pkl'

Expand Down Expand Up @@ -460,6 +462,7 @@ def test_synthesize_multiple_tables_using_hma(self, tmp_path):
"""
# Loading the demo data
real_data, metadata = download_demo(modality='multi_table', dataset_name='fake_hotels')
metadata = Metadata.load_from_dict(metadata.to_dict())

# Creating a Synthesizer
synthesizer = HMASynthesizer(metadata)
Expand Down Expand Up @@ -1250,9 +1253,8 @@ def test_metadata_updated_no_warning(self, tmp_path):
instance.fit(data)

# Assert
for warning in captured_warnings:
assert warning.category is FutureWarning
assert len(captured_warnings) == 3
assert len(captured_warnings) == 1
assert captured_warnings[0].category is FutureWarning

# Run 2
metadata_detect = MultiTableMetadata()
Expand All @@ -1271,9 +1273,8 @@ def test_metadata_updated_no_warning(self, tmp_path):
instance.fit(data)

# Assert
for warning in captured_warnings:
assert warning.category is FutureWarning
assert len(captured_warnings) == 3
assert len(captured_warnings) == 1
assert captured_warnings[0].category is FutureWarning

# Run 3
instance = HMASynthesizer(metadata_detect)
Expand All @@ -1297,7 +1298,7 @@ def test_metadata_updated_warning_detect(self):
"""
# Setup
data, metadata = download_demo('multi_table', 'got_families')
metadata_detect = MultiTableMetadata()
metadata_detect = Metadata()
metadata_detect.detect_from_dataframes(data)

metadata_detect.relationships = metadata.relationships
Expand All @@ -1316,16 +1317,7 @@ def test_metadata_updated_warning_detect(self):
instance.fit(data)

# Assert
future_warnings = 0
user_warnings = 0
for warning in record:
if warning.category is FutureWarning:
future_warnings += 1
if warning.category is UserWarning:
user_warnings += 1
assert future_warnings == 3
assert user_warnings == 1
assert len(record) == 4
assert len(record) == 1

def test_null_foreign_keys(self):
"""Test that the synthesizer crashes when there are null foreign keys."""
Expand Down Expand Up @@ -1595,7 +1587,7 @@ def test_metadata_updated_warning(method, kwargs):
The warning should be raised during synthesizer initialization.
"""
metadata = MultiTableMetadata().load_from_dict({
metadata = Metadata().load_from_dict({
'tables': {
'departure': {
'primary_key': 'id',
Expand Down
Loading

0 comments on commit bbd7667

Please sign in to comment.