Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include stack trace when sampling errors are surfaced #2329

Merged
merged 1 commit into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdv/single_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def handle_sampling_error(output_file_path, sampling_error):
)

if error_msg:
raise type(sampling_error)(error_msg + '\n' + str(sampling_error))
raise type(sampling_error)(error_msg) from sampling_error

raise sampling_error

Expand Down
6 changes: 4 additions & 2 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,15 +1546,17 @@ def test__sample_with_progress_bar_without_output_filepath(self):
instance._fitted = True
expected_message = re.escape(
'Error: Sampling terminated. No results were saved due to unspecified '
'"output_file_path".\nMocked Error'
'"output_file_path".'
)
instance._sample_in_batches.side_effect = RuntimeError('Mocked Error')

# Run and Assert
with pytest.raises(RuntimeError, match=expected_message):
with pytest.raises(RuntimeError, match=expected_message) as exception:
BaseSingleTableSynthesizer._sample_with_progress_bar(
instance, output_file_path=None, num_rows=10
)
assert isinstance(exception.value.__cause__, RuntimeError)
assert 'Mocked Error' in str(exception.value.__cause__)

@patch('sdv.single_table.base.datetime')
def test_sample(self, mock_datetime, caplog):
Expand Down
12 changes: 9 additions & 3 deletions tests/unit/single_table/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,13 @@ def test_unflatten_dict():
def test_handle_sampling_error_temp_file():
"""Test that an error is raised when temp dir is ``False``."""
# Run and Assert
error_msg = 'Error: Sampling terminated. Partial results are stored in test.csv.\nTest error'
with pytest.raises(ValueError, match=error_msg):
error_msg = 'Error: Sampling terminated. Partial results are stored in test.csv.'
with pytest.raises(ValueError, match=error_msg) as exception:
handle_sampling_error('test.csv', ValueError('Test error'))

assert isinstance(exception.value.__cause__, ValueError)
assert 'Test error' in str(exception.value.__cause__)


def test_handle_sampling_error_false_temp_file_none_output_file():
"""Test the ``handle_sampling_error`` function.
Expand All @@ -228,9 +231,12 @@ def test_handle_sampling_error_false_temp_file_none_output_file():
"""
# Run and Assert
error_msg = 'Test error'
with pytest.raises(ValueError, match=error_msg):
with pytest.raises(ValueError) as exception:
handle_sampling_error('test.csv', ValueError('Test error'))

assert isinstance(exception.value.__cause__, ValueError)
assert error_msg in str(exception.value.__cause__)


def test_handle_sampling_error_ignore():
"""Test that the error is raised if the error is the no rows error."""
Expand Down
Loading