Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Missing error message if the user forgets to add a sequence_key when using PARSynthesizer #1909

Merged
merged 1 commit into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False

sequence_key = self.metadata.sequence_key
self._sequence_key = list(_cast_to_iterable(sequence_key)) if sequence_key else None
if context_columns and not self._sequence_key:
if not self._sequence_key:
raise SynthesizerInputError(
"No 'sequence_keys' are specified in the metadata. The PARSynthesizer cannot "
"model 'context_columns' in this case."
'The PARSythesizer is designed for multi-sequence data, identifiable through a '
'sequence key. Your metadata does not include a sequence key.'
)

self._sequence_index = self.metadata.sequence_index
Expand Down
70 changes: 17 additions & 53 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from sdv.sampling import Condition
from sdv.sequential.par import PARSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from tests.utils import DataFrameMatcher


class TestPARSynthesizer:
Expand Down Expand Up @@ -82,18 +81,19 @@ def test___init__(self):
'name': {'sdtype': 'id'}
}

def test___init___context_columns_no_sequence_key(self):
"""Test when there are context columns but no sequence keys.
def test___init___no_sequence_key(self):
"""Test when there are no sequence keys.

If there are context columns and no sequence keys then an error should be raised.
If there are no sequence keys then an error should be raised.
"""
# Setup
metadata = self.get_metadata(add_sequence_key=False)

# Run and Assert
error_message = (
"No 'sequence_keys' are specified in the metadata. The PARSynthesizer cannot "
"model 'context_columns' in this case."
'The PARSythesizer is designed for multi-sequence data, identifiable through a '
'sequence key. Your metadata does not include a sequence key.'

)
with pytest.raises(SynthesizerInputError, match=error_message):
PARSynthesizer(
Expand Down Expand Up @@ -130,7 +130,7 @@ def test_add_constraints(self, warnings_mock):
def test_get_parameters(self):
"""Test that it returns every ``init`` parameter without the ``metadata``."""
# Setup
metadata = SingleTableMetadata()
metadata = self.get_metadata()
instance = PARSynthesizer(
metadata=metadata,
enforce_min_max_values=True,
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_get_parameters(self):
def test_get_metadata(self):
"""Test that it returns the ``metadata`` object."""
# Setup
metadata = SingleTableMetadata()
metadata = self.get_metadata()
instance = PARSynthesizer(
metadata=metadata,
enforce_min_max_values=True,
Expand Down Expand Up @@ -593,25 +593,6 @@ def test__fit_with_sequence_key(self):
par._fit_context_model.assert_called_once_with(data)
par._fit_sequence_columns.assert_called_once_with(data)

def test__fit_without_sequence_key(self):
"""Test that the method doesn't fit the context synthesizer if there are no sequence keys.

If there are no sequence keys, then only the ``PARModel`` needs to be fit.
"""
# Setup
metadata = self.get_metadata(add_sequence_key=False)
par = PARSynthesizer(metadata=metadata)
data = self.get_data()
par._fit_context_model = Mock()
par._fit_sequence_columns = Mock()

# Run
par._fit(data)

# Assert
par._fit_context_model.assert_not_called()
par._fit_sequence_columns.assert_called_once_with(data)

def test_get_loss_values(self):
"""Test the ``get_loss_values`` method from ``PARSynthesizer."""
# Setup
Expand All @@ -621,7 +602,7 @@ def test_get_loss_values(self):
'Loss': [0.8, 0.6, 0.5]
})
mock_model.loss_values = loss_values
metadata = SingleTableMetadata()
metadata = self.get_metadata()
instance = PARSynthesizer(metadata)
instance._model = mock_model
instance._fitted = True
Expand All @@ -635,7 +616,7 @@ def test_get_loss_values(self):
def test_get_loss_values_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = SingleTableMetadata()
metadata = self.get_metadata()
instance = PARSynthesizer(metadata)

# Run / Assert
Expand All @@ -652,20 +633,19 @@ def test__sample_from_par(self, tqdm_mock):
sequences in a ``pandas.DataFrame``.
"""
# Setup
metadata = self.get_metadata(add_sequence_key=False)
metadata = self.get_metadata()
par = PARSynthesizer(metadata=metadata)
model_mock = Mock()
par._model = model_mock
par._data_columns = ['time', 'gender', 'name', 'measurement']
par._data_columns = ['time', 'gender', 'measurement']
par._output_columns = ['time', 'gender', 'name', 'measurement']
model_mock.sample_sequence.return_value = [
[18000, 20000, 22000],
[1, 1, 1],
[.4, .7, .1],
[55, 60, 65]
]
context_columns = pd.DataFrame(index=range(1))
tqdm_mock.tqdm.return_value = context_columns.iterrows()
context_columns = pd.DataFrame({'name': ['John Doe']})
tqdm_mock.tqdm.return_value = context_columns.set_index('name').iterrows()

# Run
sampled = par._sample_from_par(context_columns, 3)
Expand All @@ -675,14 +655,14 @@ def test__sample_from_par(self, tqdm_mock):
called_context_iterator_list = list(arg_list[0])
assert kwargs['disable'] is True
assert kwargs['total'] == 1
for i, row in enumerate(context_columns.iterrows()):
for i, row in enumerate(context_columns.set_index('name').iterrows()):
called_row = called_context_iterator_list[i]
pd.testing.assert_series_equal(row[1], called_row[1])

expected_output = pd.DataFrame({
'time': [18000, 20000, 22000],
'gender': [1, 1, 1],
'name': [.4, .7, .1],
'name': ['John Doe', 'John Doe', 'John Doe'],
'measurement': [55, 60, 65]
})
pd.testing.assert_frame_equal(sampled, expected_output)
Expand Down Expand Up @@ -845,22 +825,6 @@ def test_sample_sequence_key_needs_to_be_filled_in(self):
})
pd.testing.assert_frame_equal(context_columns, expected_context_columns, check_dtype=False)

def test_sample_no_sequence_key(self):
"""Test that if there is no sequence key, a column is made to substitute context."""
# Setup
metadata = self.get_metadata(add_sequence_key=False)
par = PARSynthesizer(
metadata=metadata
)
par._context_synthesizer = Mock()
par._sample = Mock()

# Run
par.sample(3, 2)

# Assert
par._sample.assert_called_once_with(DataFrameMatcher(pd.DataFrame(index=range(3))), 2)

def test_sample_sequential_columns(self):
"""Test that the method uses the provided context columns to sample."""
# Setup
Expand Down Expand Up @@ -910,7 +874,7 @@ def test_sample_sequential_columns_no_context_columns(self):
used.
"""
# Setup
par = PARSynthesizer(metadata=self.get_metadata(add_sequence_key=False))
par = PARSynthesizer(metadata=self.get_metadata())
par._sample = Mock()
context_columns = pd.DataFrame({
'gender': ['M', 'M', 'F']
Expand Down
Loading