diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 8db57467e..50a38c316 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -709,6 +709,7 @@ def load(cls, filepath): 'machine is CPU-only. This feature is currently unsupported. We recommend' ' sampling on the same GPU-enabled machine.' ) + raise e check_synthesizer_version(synthesizer) check_sdv_versions_and_warn(synthesizer) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index b670d5695..2fc8c18d0 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -507,6 +507,7 @@ def load(cls, filepath): 'machine is CPU-only. This feature is currently unsupported. We recommend' ' sampling on the same GPU-enabled machine.' ) + raise e check_synthesizer_version(synthesizer) check_sdv_versions_and_warn(synthesizer) diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index aa9eab456..6880de0ef 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -1624,3 +1624,14 @@ def test_load_runtime_error(self, cloudpickle_mock, mock_open): ) with pytest.raises(SamplingError, match=err_msg): BaseMultiTableSynthesizer.load('synth.pkl') + + @patch('builtins.open') + @patch('sdv.multi_table.base.cloudpickle') + def test_load_runtime_error_no_change(self, cloudpickle_mock, mock_open): + """Test that the synthesizer's load method errors with the correct message.""" + # Setup + cloudpickle_mock.load.side_effect = RuntimeError('Error') + + # Run and Assert + with pytest.raises(RuntimeError, match='Error'): + BaseMultiTableSynthesizer.load('synth.pkl') diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index caf9170e5..1ae03b541 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1935,6 +1935,17 @@ def test_load_runtime_error(self, cloudpickle_mock, mock_open): with pytest.raises(SamplingError, match=err_msg): BaseSingleTableSynthesizer.load('synth.pkl') + @patch('builtins.open') + @patch('sdv.single_table.base.cloudpickle') + def test_load_runtime_error_no_change(self, cloudpickle_mock, mock_open): + """Test that the synthesizer's load method errors with the correct message.""" + # Setup + cloudpickle_mock.load.side_effect = RuntimeError('Error') + + # Run and Assert + with pytest.raises(RuntimeError, match='Error'): + BaseSingleTableSynthesizer.load('synth.pkl') + def test_add_custom_constraint_class(self): """Test that this method calls the ``DataProcessor``'s method.""" # Setup