Skip to content

Commit

Permalink
Add msg
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed May 16, 2024
1 parent c2a1a6a commit 7749f05
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
9 changes: 8 additions & 1 deletion sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,14 @@ def load(cls, filepath):
The loaded synthesizer.
"""
with open(filepath, 'rb') as f:
synthesizer = cloudpickle.load(f)
try:
synthesizer = cloudpickle.load(f)
except RuntimeError:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)

check_synthesizer_version(synthesizer)
check_sdv_versions_and_warn(synthesizer)
Expand Down
9 changes: 8 additions & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,14 @@ def load(cls, filepath):
The loaded synthesizer.
"""
with open(filepath, 'rb') as f:
synthesizer = cloudpickle.load(f)
try:
synthesizer = cloudpickle.load(f)
except RuntimeError:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)

check_synthesizer_version(synthesizer)
check_sdv_versions_and_warn(synthesizer)
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,3 +1603,19 @@ def test_load(self, mock_file, cloudpickle_mock,
'SYNTHESIZER CLASS NAME': 'Mock',
'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
})

@patch('builtins.open')
@patch('sdv.multi_table.base.cloudpickle')
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError

# Run and Assert
err_msg = re.escape(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)
with pytest.raises(SamplingError, match=err_msg):
BaseMultiTableSynthesizer.load('synth.pkl')
16 changes: 16 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,22 @@ def test_load_custom_constraint_classes(self):
['Custom', 'Constr', 'UpperPlus']
)

@patch('builtins.open')
@patch('sdv.single_table.base.cloudpickle')
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError

# Run and Assert
err_msg = re.escape(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)
with pytest.raises(SamplingError, match=err_msg):
BaseSingleTableSynthesizer.load('synth.pkl')

def test_add_custom_constraint_class(self):
"""Test that this method calls the ``DataProcessor``'s method."""
# Setup
Expand Down

0 comments on commit 7749f05

Please sign in to comment.