Skip to content

Commit

Permalink
Missing error message if the user forgets to add a sequence_key when…
Browse files Browse the repository at this point in the history
… using PARSynthesizer (#1909)
  • Loading branch information
frances-h authored Apr 10, 2024
1 parent e480c3e commit f758807
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 56 deletions.
6 changes: 3 additions & 3 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False

sequence_key = self.metadata.sequence_key
self._sequence_key = list(_cast_to_iterable(sequence_key)) if sequence_key else None
if context_columns and not self._sequence_key:
if not self._sequence_key:
raise SynthesizerInputError(
"No 'sequence_keys' are specified in the metadata. The PARSynthesizer cannot "
"model 'context_columns' in this case."
'The PARSythesizer is designed for multi-sequence data, identifiable through a '
'sequence key. Your metadata does not include a sequence key.'
)

self._sequence_index = self.metadata.sequence_index
Expand Down
70 changes: 17 additions & 53 deletions tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from sdv.sampling import Condition
from sdv.sequential.par import PARSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from tests.utils import DataFrameMatcher


class TestPARSynthesizer:
Expand Down Expand Up @@ -82,18 +81,19 @@ def test___init__(self):
'name': {'sdtype': 'id'}
}

def test___init___context_columns_no_sequence_key(self):
"""Test when there are context columns but no sequence keys.
def test___init___no_sequence_key(self):
"""Test when there are no sequence keys.
If there are context columns and no sequence keys then an error should be raised.
If there are no sequence keys then an error should be raised.
"""
# Setup
metadata = self.get_metadata(add_sequence_key=False)

# Run and Assert
error_message = (
"No 'sequence_keys' are specified in the metadata. The PARSynthesizer cannot "
"model 'context_columns' in this case."
'The PARSythesizer is designed for multi-sequence data, identifiable through a '
'sequence key. Your metadata does not include a sequence key.'

)
with pytest.raises(SynthesizerInputError, match=error_message):
PARSynthesizer(
Expand Down Expand Up @@ -130,7 +130,7 @@ def test_add_constraints(self, warnings_mock):
def test_get_parameters(self):
"""Test that it returns every ``init`` parameter without the ``metadata``."""
# Setup
metadata = SingleTableMetadata()
metadata = self.get_metadata()
instance = PARSynthesizer(
metadata=metadata,
enforce_min_max_values=True,
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_get_parameters(self):
def test_get_metadata(self):
"""Test that it returns the ``metadata`` object."""
# Setup
metadata = SingleTableMetadata()
metadata = self.get_metadata()
instance = PARSynthesizer(
metadata=metadata,
enforce_min_max_values=True,
Expand Down Expand Up @@ -593,25 +593,6 @@ def test__fit_with_sequence_key(self):
par._fit_context_model.assert_called_once_with(data)
par._fit_sequence_columns.assert_called_once_with(data)

def test__fit_without_sequence_key(self):
"""Test that the method doesn't fit the context synthesizer if there are no sequence keys.
If there are no sequence keys, then only the ``PARModel`` needs to be fit.
"""
# Setup
metadata = self.get_metadata(add_sequence_key=False)
par = PARSynthesizer(metadata=metadata)
data = self.get_data()
par._fit_context_model = Mock()
par._fit_sequence_columns = Mock()

# Run
par._fit(data)

# Assert
par._fit_context_model.assert_not_called()
par._fit_sequence_columns.assert_called_once_with(data)

def test_get_loss_values(self):
"""Test the ``get_loss_values`` method from ``PARSynthesizer."""
# Setup
Expand All @@ -621,7 +602,7 @@ def test_get_loss_values(self):
'Loss': [0.8, 0.6, 0.5]
})
mock_model.loss_values = loss_values
metadata = SingleTableMetadata()
metadata = self.get_metadata()
instance = PARSynthesizer(metadata)
instance._model = mock_model
instance._fitted = True
Expand All @@ -635,7 +616,7 @@ def test_get_loss_values(self):
def test_get_loss_values_error(self):
"""Test the ``get_loss_values`` errors if synthesizer has not been fitted."""
# Setup
metadata = SingleTableMetadata()
metadata = self.get_metadata()
instance = PARSynthesizer(metadata)

# Run / Assert
Expand All @@ -652,20 +633,19 @@ def test__sample_from_par(self, tqdm_mock):
sequences in a ``pandas.DataFrame``.
"""
# Setup
metadata = self.get_metadata(add_sequence_key=False)
metadata = self.get_metadata()
par = PARSynthesizer(metadata=metadata)
model_mock = Mock()
par._model = model_mock
par._data_columns = ['time', 'gender', 'name', 'measurement']
par._data_columns = ['time', 'gender', 'measurement']
par._output_columns = ['time', 'gender', 'name', 'measurement']
model_mock.sample_sequence.return_value = [
[18000, 20000, 22000],
[1, 1, 1],
[.4, .7, .1],
[55, 60, 65]
]
context_columns = pd.DataFrame(index=range(1))
tqdm_mock.tqdm.return_value = context_columns.iterrows()
context_columns = pd.DataFrame({'name': ['John Doe']})
tqdm_mock.tqdm.return_value = context_columns.set_index('name').iterrows()

# Run
sampled = par._sample_from_par(context_columns, 3)
Expand All @@ -675,14 +655,14 @@ def test__sample_from_par(self, tqdm_mock):
called_context_iterator_list = list(arg_list[0])
assert kwargs['disable'] is True
assert kwargs['total'] == 1
for i, row in enumerate(context_columns.iterrows()):
for i, row in enumerate(context_columns.set_index('name').iterrows()):
called_row = called_context_iterator_list[i]
pd.testing.assert_series_equal(row[1], called_row[1])

expected_output = pd.DataFrame({
'time': [18000, 20000, 22000],
'gender': [1, 1, 1],
'name': [.4, .7, .1],
'name': ['John Doe', 'John Doe', 'John Doe'],
'measurement': [55, 60, 65]
})
pd.testing.assert_frame_equal(sampled, expected_output)
Expand Down Expand Up @@ -845,22 +825,6 @@ def test_sample_sequence_key_needs_to_be_filled_in(self):
})
pd.testing.assert_frame_equal(context_columns, expected_context_columns, check_dtype=False)

def test_sample_no_sequence_key(self):
"""Test that if there is no sequence key, a column is made to substitute context."""
# Setup
metadata = self.get_metadata(add_sequence_key=False)
par = PARSynthesizer(
metadata=metadata
)
par._context_synthesizer = Mock()
par._sample = Mock()

# Run
par.sample(3, 2)

# Assert
par._sample.assert_called_once_with(DataFrameMatcher(pd.DataFrame(index=range(3))), 2)

def test_sample_sequential_columns(self):
"""Test that the method uses the provided context columns to sample."""
# Setup
Expand Down Expand Up @@ -910,7 +874,7 @@ def test_sample_sequential_columns_no_context_columns(self):
used.
"""
# Setup
par = PARSynthesizer(metadata=self.get_metadata(add_sequence_key=False))
par = PARSynthesizer(metadata=self.get_metadata())
par._sample = Mock()
context_columns = pd.DataFrame({
'gender': ['M', 'M', 'F']
Expand Down

0 comments on commit f758807

Please sign in to comment.