diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index ef63add34..e06c96864 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -97,10 +97,10 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False sequence_key = self.metadata.sequence_key self._sequence_key = list(_cast_to_iterable(sequence_key)) if sequence_key else None - if context_columns and not self._sequence_key: + if not self._sequence_key: raise SynthesizerInputError( - "No 'sequence_keys' are specified in the metadata. The PARSynthesizer cannot " - "model 'context_columns' in this case." + 'The PARSythesizer is designed for multi-sequence data, identifiable through a ' + 'sequence key. Your metadata does not include a sequence key.' ) self._sequence_index = self.metadata.sequence_index diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 69d613c92..b5173dd03 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -12,7 +12,6 @@ from sdv.sampling import Condition from sdv.sequential.par import PARSynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer -from tests.utils import DataFrameMatcher class TestPARSynthesizer: @@ -82,18 +81,19 @@ def test___init__(self): 'name': {'sdtype': 'id'} } - def test___init___context_columns_no_sequence_key(self): - """Test when there are context columns but no sequence keys. + def test___init___no_sequence_key(self): + """Test when there are no sequence keys. - If there are context columns and no sequence keys then an error should be raised. + If there are no sequence keys then an error should be raised. """ # Setup metadata = self.get_metadata(add_sequence_key=False) # Run and Assert error_message = ( - "No 'sequence_keys' are specified in the metadata. The PARSynthesizer cannot " - "model 'context_columns' in this case." + 'The PARSythesizer is designed for multi-sequence data, identifiable through a ' + 'sequence key. Your metadata does not include a sequence key.' + ) with pytest.raises(SynthesizerInputError, match=error_message): PARSynthesizer( @@ -130,7 +130,7 @@ def test_add_constraints(self, warnings_mock): def test_get_parameters(self): """Test that it returns every ``init`` parameter without the ``metadata``.""" # Setup - metadata = SingleTableMetadata() + metadata = self.get_metadata() instance = PARSynthesizer( metadata=metadata, enforce_min_max_values=True, @@ -164,7 +164,7 @@ def test_get_parameters(self): def test_get_metadata(self): """Test that it returns the ``metadata`` object.""" # Setup - metadata = SingleTableMetadata() + metadata = self.get_metadata() instance = PARSynthesizer( metadata=metadata, enforce_min_max_values=True, @@ -593,25 +593,6 @@ def test__fit_with_sequence_key(self): par._fit_context_model.assert_called_once_with(data) par._fit_sequence_columns.assert_called_once_with(data) - def test__fit_without_sequence_key(self): - """Test that the method doesn't fit the context synthesizer if there are no sequence keys. - - If there are no sequence keys, then only the ``PARModel`` needs to be fit. - """ - # Setup - metadata = self.get_metadata(add_sequence_key=False) - par = PARSynthesizer(metadata=metadata) - data = self.get_data() - par._fit_context_model = Mock() - par._fit_sequence_columns = Mock() - - # Run - par._fit(data) - - # Assert - par._fit_context_model.assert_not_called() - par._fit_sequence_columns.assert_called_once_with(data) - def test_get_loss_values(self): """Test the ``get_loss_values`` method from ``PARSynthesizer.""" # Setup @@ -621,7 +602,7 @@ def test_get_loss_values(self): 'Loss': [0.8, 0.6, 0.5] }) mock_model.loss_values = loss_values - metadata = SingleTableMetadata() + metadata = self.get_metadata() instance = PARSynthesizer(metadata) instance._model = mock_model instance._fitted = True @@ -635,7 +616,7 @@ def test_get_loss_values(self): def test_get_loss_values_error(self): """Test the ``get_loss_values`` errors if synthesizer has not been fitted.""" # Setup - metadata = SingleTableMetadata() + metadata = self.get_metadata() instance = PARSynthesizer(metadata) # Run / Assert @@ -652,20 +633,19 @@ def test__sample_from_par(self, tqdm_mock): sequences in a ``pandas.DataFrame``. """ # Setup - metadata = self.get_metadata(add_sequence_key=False) + metadata = self.get_metadata() par = PARSynthesizer(metadata=metadata) model_mock = Mock() par._model = model_mock - par._data_columns = ['time', 'gender', 'name', 'measurement'] + par._data_columns = ['time', 'gender', 'measurement'] par._output_columns = ['time', 'gender', 'name', 'measurement'] model_mock.sample_sequence.return_value = [ [18000, 20000, 22000], [1, 1, 1], - [.4, .7, .1], [55, 60, 65] ] - context_columns = pd.DataFrame(index=range(1)) - tqdm_mock.tqdm.return_value = context_columns.iterrows() + context_columns = pd.DataFrame({'name': ['John Doe']}) + tqdm_mock.tqdm.return_value = context_columns.set_index('name').iterrows() # Run sampled = par._sample_from_par(context_columns, 3) @@ -675,14 +655,14 @@ def test__sample_from_par(self, tqdm_mock): called_context_iterator_list = list(arg_list[0]) assert kwargs['disable'] is True assert kwargs['total'] == 1 - for i, row in enumerate(context_columns.iterrows()): + for i, row in enumerate(context_columns.set_index('name').iterrows()): called_row = called_context_iterator_list[i] pd.testing.assert_series_equal(row[1], called_row[1]) expected_output = pd.DataFrame({ 'time': [18000, 20000, 22000], 'gender': [1, 1, 1], - 'name': [.4, .7, .1], + 'name': ['John Doe', 'John Doe', 'John Doe'], 'measurement': [55, 60, 65] }) pd.testing.assert_frame_equal(sampled, expected_output) @@ -845,22 +825,6 @@ def test_sample_sequence_key_needs_to_be_filled_in(self): }) pd.testing.assert_frame_equal(context_columns, expected_context_columns, check_dtype=False) - def test_sample_no_sequence_key(self): - """Test that if there is no sequence key, a column is made to substitute context.""" - # Setup - metadata = self.get_metadata(add_sequence_key=False) - par = PARSynthesizer( - metadata=metadata - ) - par._context_synthesizer = Mock() - par._sample = Mock() - - # Run - par.sample(3, 2) - - # Assert - par._sample.assert_called_once_with(DataFrameMatcher(pd.DataFrame(index=range(3))), 2) - def test_sample_sequential_columns(self): """Test that the method uses the provided context columns to sample.""" # Setup @@ -910,7 +874,7 @@ def test_sample_sequential_columns_no_context_columns(self): used. """ # Setup - par = PARSynthesizer(metadata=self.get_metadata(add_sequence_key=False)) + par = PARSynthesizer(metadata=self.get_metadata()) par._sample = Mock() context_columns = pd.DataFrame({ 'gender': ['M', 'M', 'F']