Skip to content

Commit

Permalink
fix based on feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
gsheni committed Jul 30, 2024
1 parent f44f59c commit 7c28e15
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
3 changes: 1 addition & 2 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ class BaseMultiTableSynthesizer:
Args:
metadata (sdv.metadata.Metadata):
Table metadata representing the data tables that this synthesizer will be used
for.
Metadata representing the data tables that this synthesizer will utilize.
locales (list or str):
The default locale(s) to use for AnonymizedFaker transformers.
Defaults to ``['en_US']``.
Expand Down
13 changes: 12 additions & 1 deletion tests/integration/metadata/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_single_table_compatibility(tmp_path):


def test_multi_table_compatibility(tmp_path):
"""Test if SingleMetadataTable still has compatibility with single table synthesizers."""
"""Test if MultiTableMetadata still has compatibility with multi table synthesizers."""
# Setup
data, _ = download_demo('multi_table', 'fake_hotels')
warn_msg = re.escape(
Expand Down Expand Up @@ -333,20 +333,31 @@ def test_multi_table_compatibility(tmp_path):
# Run
with pytest.warns(FutureWarning, match=warn_msg):
synthesizer = HMASynthesizer(metadata)

synthesizer.fit(data)
model_path = os.path.join(tmp_path, 'synthesizer.pkl')
synthesizer.save(model_path)

# Assert
assert os.path.exists(model_path)
assert os.path.isfile(model_path)

# Load HMASynthesizer
loaded_synthesizer = HMASynthesizer.load(model_path)

# Asserts
assert isinstance(synthesizer, HMASynthesizer)
assert loaded_synthesizer.get_info() == synthesizer.get_info()
assert isinstance(loaded_synthesizer.metadata, Metadata)

# Load Metadata
expected_metadata = metadata.to_dict()
expected_metadata['METADATA_SPEC_VERSION'] = 'V1'

# Asserts
assert loaded_synthesizer.metadata.to_dict() == expected_metadata

# Sample from loaded synthesizer
loaded_sample = loaded_synthesizer.sample(10)
synthesizer.validate(loaded_sample)

Expand Down

0 comments on commit 7c28e15

Please sign in to comment.