Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Nov 8, 2023
1 parent 6ba3c8a commit e78df2e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/unit/evaluation/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ def test_get_column_pair_plot(mock_plot):
mock_plot.return_value = 'plot'

# Run
plot = get_column_pair_plot(data1, data2, metadata, 'table', ['col1', 'col2'])
plot = get_column_pair_plot(data1, data2, metadata, 'table', ['col1', 'col2'], 2)

# Assert
call_metadata = metadata.tables['table']
mock_plot.assert_called_once_with(table1, table2, call_metadata, ['col1', 'col2'], None)
mock_plot.assert_called_once_with(table1, table2, call_metadata, ['col1', 'col2'], 2, None)
assert plot == 'plot'


Expand Down
72 changes: 72 additions & 0 deletions tests/unit/evaluation/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,75 @@ def test_get_column_pair_plot_with_invalid_sdtype_and_plot_type(mock_get_plot):
assert mock_get_plot.call_args[0][2] == columns
assert mock_get_plot.call_args[0][3] == 'heatmap'
assert plot == mock_get_plot.return_value


@patch('sdmetrics.visualization.get_column_pair_plot')
def test_get_column_pair_plot_with_sample_size(mock_get_plot):
"""Test ``get_column_pair_plot`` with sample_size parameter."""
# Setup
columns = ['amount', 'date']
real_data = pd.DataFrame({
'amount': [1, 2, 3],
'date': ['2021-01-01', '2022-01-01', '2023-01-01'],
})
synthetic_data = pd.DataFrame({
'amount': [1., 2., 3.],
'date': ['2021-01-01', '2022-01-01', '2023-01-01'],
})
metadata = SingleTableMetadata()
metadata.add_column('amount', sdtype='numerical')
metadata.add_column('date', sdtype='datetime')

# Run
plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=2)

# Assert
expected_real_data = pd.DataFrame({
'amount': [1, 2],
'date': pd.to_datetime(['2021-01-01', '2022-01-01']),
}, index=[0, 1])
expected_synth_data = pd.DataFrame({
'amount': [3., 1.],
'date': pd.to_datetime(['2023-01-01', '2021-01-01']),
}, index=[2, 0])
pd.testing.assert_frame_equal(mock_get_plot.call_args[0][0], expected_real_data)
pd.testing.assert_frame_equal(mock_get_plot.call_args[0][1], expected_synth_data)
assert mock_get_plot.call_args[0][2] == columns
assert mock_get_plot.call_args[0][3] == 'scatter'
assert plot == mock_get_plot.return_value


@patch('sdmetrics.visualization.get_column_pair_plot')
def test_get_column_pair_plot_with_sample_size_too_big(mock_get_plot):
"""Test ``get_column_pair_plot`` when sample_size is bigger than the lenght of the data."""
# Setup
columns = ['amount', 'date']
real_data = pd.DataFrame({
'amount': [1, 2, 3],
'date': ['2021-01-01', '2022-01-01', '2023-01-01'],
})
synthetic_data = pd.DataFrame({
'amount': [1., 2., 3.],
'date': ['2021-01-01', '2022-01-01', '2023-01-01'],
})
metadata = SingleTableMetadata()
metadata.add_column('amount', sdtype='numerical')
metadata.add_column('date', sdtype='datetime')

# Run
plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=10)

# Assert
expected_real_data = pd.DataFrame({
'amount': [1, 2, 3],
'date': pd.to_datetime(['2021-01-01', '2022-01-01', '2023-01-01']),
})
expected_synth_data = pd.DataFrame({
'amount': [1., 2., 3.],
'date': pd.to_datetime(['2021-01-01', '2022-01-01', '2023-01-01']),
})
pd.testing.assert_frame_equal(mock_get_plot.call_args[0][0], expected_real_data)
pd.testing.assert_frame_equal(mock_get_plot.call_args[0][1], expected_synth_data)
assert mock_get_plot.call_args[0][2] == columns
assert mock_get_plot.call_args[0][3] == 'scatter'
assert plot == mock_get_plot.return_value

0 comments on commit e78df2e

Please sign in to comment.