Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Jan 13, 2025
1 parent a213cfa commit c6d1942
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 11 deletions.
132 changes: 121 additions & 11 deletions tests/unit/evaluation/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def test_run_diagnostic_metadata():


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_continuous_data(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_continuous_data(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with continuous data.
Test that when we call ``get_column_plot`` with continuous data (datetime or numerical)
Expand All @@ -101,6 +102,7 @@ def test_get_column_plot_continuous_data(mock_get_plot):
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('col', 'table', sdtype='numerical')
mock_prepare.side_effect = [data1, data2]

# Run
plot = get_column_plot(data1, data2, metadata, 'col')
Expand All @@ -111,7 +113,8 @@ def test_get_column_plot_continuous_data(mock_get_plot):


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_continuous_data_metadata(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_continuous_data_metadata(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with continuous data.
Test that when we call ``get_column_plot`` with continuous data (datetime or numerical)
Expand All @@ -122,6 +125,7 @@ def test_get_column_plot_continuous_data_metadata(mock_get_plot):
data2 = pd.DataFrame({'col': [2, 1, 3]})
metadata_dict = {'columns': {'col': {'sdtype': 'numerical'}}}
metadata = Metadata.load_from_dict(metadata_dict)
mock_prepare.side_effect = [data1, data2]

# Run
plot = get_column_plot(data1, data2, metadata, 'col')
Expand All @@ -132,7 +136,8 @@ def test_get_column_plot_continuous_data_metadata(mock_get_plot):


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_discrete_data(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_discrete_data(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with discrete data.
Test that when we call ``get_column_plot`` with discrete data (categorical or boolean)
Expand All @@ -144,6 +149,7 @@ def test_get_column_plot_discrete_data(mock_get_plot):
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('col', 'table', sdtype='categorical')
mock_prepare.side_effect = [data1, data2]

# Run
plot = get_column_plot(data1, data2, metadata, 'col')
Expand All @@ -154,7 +160,8 @@ def test_get_column_plot_discrete_data(mock_get_plot):


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_discrete_data_metadata(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_discrete_data_metadata(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with discrete data.
Test that when we call ``get_column_plot`` with discrete data (categorical or boolean)
Expand All @@ -165,6 +172,7 @@ def test_get_column_plot_discrete_data_metadata(mock_get_plot):
data2 = pd.DataFrame({'col': ['a', 'b', 'c']})
metadata_dict = {'columns': {'col': {'sdtype': 'categorical'}}}
metadata = Metadata.load_from_dict(metadata_dict)
mock_prepare.side_effect = [data1, data2]

# Run
plot = get_column_plot(data1, data2, metadata, 'col')
Expand All @@ -175,7 +183,8 @@ def test_get_column_plot_discrete_data_metadata(mock_get_plot):


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_discrete_data_with_distplot(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_discrete_data_with_distplot(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with discrete data.
Test that when we call ``get_column_plot`` with discrete data (categorical or boolean)
Expand All @@ -188,6 +197,7 @@ def test_get_column_plot_discrete_data_with_distplot(mock_get_plot):
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('col', 'table', sdtype='categorical')
mock_prepare.side_effect = [data1, data2]

# Run
plot = get_column_plot(data1, data2, metadata, 'col', plot_type='distplot')
Expand All @@ -198,7 +208,8 @@ def test_get_column_plot_discrete_data_with_distplot(mock_get_plot):


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_discrete_data_with_distplot_metadata(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_discrete_data_with_distplot_metadata(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with discrete data.
Test that when we call ``get_column_plot`` with discrete data (categorical or boolean)
Expand All @@ -210,6 +221,7 @@ def test_get_column_plot_discrete_data_with_distplot_metadata(mock_get_plot):
data2 = pd.DataFrame({'col': ['a', 'b', 'c']})
metadata_dict = {'columns': {'col': {'sdtype': 'categorical'}}}
metadata = Metadata.load_from_dict(metadata_dict)
mock_prepare.side_effect = [data1, data2]

# Run
plot = get_column_plot(data1, data2, metadata, 'col', plot_type='distplot')
Expand All @@ -220,7 +232,8 @@ def test_get_column_plot_discrete_data_with_distplot_metadata(mock_get_plot):


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_invalid_sdtype(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_invalid_sdtype(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with sdtype that can't be plotted.
Test that when we call ``get_column_plot`` with an sdtype that can't be plotted, this raises
Expand All @@ -232,6 +245,7 @@ def test_get_column_plot_invalid_sdtype(mock_get_plot):
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('col', 'table', sdtype='id')
mock_prepare.side_effect = [data1, data2]

# Run and Assert
error_msg = re.escape(
Expand All @@ -243,7 +257,8 @@ def test_get_column_plot_invalid_sdtype(mock_get_plot):


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_invalid_sdtype_metadata(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_invalid_sdtype_metadata(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with sdtype that can't be plotted.
Test that when we call ``get_column_plot`` with an sdtype that can't be plotted, this raises
Expand All @@ -254,6 +269,7 @@ def test_get_column_plot_invalid_sdtype_metadata(mock_get_plot):
data2 = pd.DataFrame({'col': ['a', 'b', 'c']})
metadata_dict = {'columns': {'col': {'sdtype': 'id'}}}
metadata = Metadata.load_from_dict(metadata_dict)
mock_prepare.side_effect = [data1, data2]

# Run and Assert
error_msg = re.escape(
Expand All @@ -265,7 +281,8 @@ def test_get_column_plot_invalid_sdtype_metadata(mock_get_plot):


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_invalid_sdtype_with_plot_type(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_invalid_sdtype_with_plot_type(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with sdtype that can't be plotted.
Test that when we call ``get_column_plot`` with an sdtype that can't be plotted, but passing
Expand All @@ -277,6 +294,7 @@ def test_get_column_plot_invalid_sdtype_with_plot_type(mock_get_plot):
metadata = Metadata()
metadata.add_table('table')
metadata.add_column('col', 'table', sdtype='id')
mock_prepare.side_effect = [data1, data2]

# Run
plot = get_column_plot(data1, data2, metadata, 'col', plot_type='bar')
Expand All @@ -287,7 +305,8 @@ def test_get_column_plot_invalid_sdtype_with_plot_type(mock_get_plot):


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_invalid_sdtype_with_plot_type_metadata(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_plot_invalid_sdtype_with_plot_type_metadata(mock_prepare, mock_get_plot):
"""Test the ``get_column_plot`` with sdtype that can't be plotted.
Test that when we call ``get_column_plot`` with an sdtype that can't be plotted, but passing
Expand All @@ -298,6 +317,7 @@ def test_get_column_plot_invalid_sdtype_with_plot_type_metadata(mock_get_plot):
data2 = pd.DataFrame({'col': ['a', 'b', 'c']})
metadata_dict = {'columns': {'col': {'sdtype': 'id'}}}
metadata = Metadata.load_from_dict(metadata_dict)
mock_prepare.side_effect = [data1, data2]

# Run
plot = get_column_plot(data1, data2, metadata, 'col', plot_type='bar')
Expand All @@ -307,6 +327,42 @@ def test_get_column_plot_invalid_sdtype_with_plot_type_metadata(mock_get_plot):
assert plot == mock_get_plot.return_value


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_real_data_none(mock_get_plot):
    """Test ``get_column_plot`` when ``real_data`` is None.

    The mocked ``sdmetrics`` plotting function is expected to be called with ``None`` as
    its first positional argument and a frame equal to the synthetic data as its second.
    The value returned by ``get_column_plot`` is expected to be the mock's return value.
    """
    # Setup
    data = pd.DataFrame({'col': [1, 2, 3]})
    metadata = Metadata()
    metadata.add_table('table')
    metadata.add_column('col', 'table', sdtype='numerical')

    # Run
    plot = get_column_plot(None, data, metadata, 'col')

    # Assert
    positional_args = mock_get_plot.call_args[0]
    assert positional_args[0] is None
    # Use an asserting comparison: a bare ``.equals(...)`` call discards its result
    # and would let a wrong frame pass silently.
    pd.testing.assert_frame_equal(positional_args[1], data)
    assert plot == mock_get_plot.return_value


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_synthetic_data_none(mock_get_plot):
    """Test ``get_column_plot`` when ``synthetic_data`` is None.

    The mocked ``sdmetrics`` plotting function is expected to be called with a frame
    equal to the real data as its first positional argument and ``None`` as its second.
    The value returned by ``get_column_plot`` is expected to be the mock's return value.
    """
    # Setup
    data = pd.DataFrame({'col': [1, 2, 3]})
    metadata = Metadata()
    metadata.add_table('table')
    metadata.add_column('col', 'table', sdtype='numerical')

    # Run
    plot = get_column_plot(data, None, metadata, 'col')

    # Assert
    positional_args = mock_get_plot.call_args[0]
    # Use an asserting comparison: a bare ``.equals(...)`` call discards its result
    # and would let a wrong frame pass silently.
    pd.testing.assert_frame_equal(positional_args[0], data)
    assert positional_args[1] is None
    assert plot == mock_get_plot.return_value


@patch('sdmetrics.visualization.get_column_plot')
def test_get_column_plot_with_datetime_sdtype(mock_get_plot):
"""Test the ``get_column_plot`` with datetime sdtype.
Expand Down Expand Up @@ -377,7 +433,8 @@ def test_get_column_pair_plot_with_continous_data(mock_get_plot):


@patch('sdmetrics.visualization.get_column_pair_plot')
def test_get_column_pair_plot_with_discrete_data(mock_get_plot):
@patch('sdv.evaluation.single_table._prepare_data_vizualisation')
def test_get_column_pair_plot_with_discrete_data(mock_prepare, mock_get_plot):
"""Test the ``get_column_pair_plot`` when using discrete data.
Test that the ``get_column_pair_plot`` will automatically use ``heatmap`` if the data
Expand All @@ -391,6 +448,7 @@ def test_get_column_pair_plot_with_discrete_data(mock_get_plot):
metadata.add_table('table')
metadata.add_column('name', 'table', sdtype='categorical')
metadata.add_column('subscriber', 'table', sdtype='boolean')
mock_prepare.side_effect = [real_data, synthetic_data]

# Run
plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns)
Expand Down Expand Up @@ -626,3 +684,55 @@ def test_get_column_pair_plot_with_sample_size_too_big(mock_get_plot):
assert mock_get_plot.call_args[0][2] == columns
assert mock_get_plot.call_args[0][3] == 'scatter'
assert plot == mock_get_plot.return_value


@patch('sdmetrics.visualization.get_column_pair_plot')
def test__get_column_pair_plot_with_real_data_none(mock_get_plot):
    """Test ``get_column_pair_plot`` when ``real_data`` is None.

    Checks that the mocked plotting function receives ``None`` for the real data, the
    synthetic frame, the column names and the ``'scatter'`` plot type, and that its
    return value is handed back to the caller.
    """
    # Setup
    columns = ['amount', 'price']
    metadata = Metadata()
    metadata.add_table('table')
    for column_name in columns:
        metadata.add_column(column_name, 'table', sdtype='numerical')

    synthetic_data = pd.DataFrame({
        'amount': [1.0, 2.0, 3.0],
        'price': [11.0, 22.0, 33.0],
    })

    # Run
    plot = get_column_pair_plot(None, synthetic_data, metadata, columns)

    # Assert
    passed_args = mock_get_plot.call_args[0]
    assert passed_args[0] is None
    pd.testing.assert_frame_equal(passed_args[1], synthetic_data)
    assert passed_args[2] == columns
    assert passed_args[3] == 'scatter'
    assert plot == mock_get_plot.return_value


@patch('sdmetrics.visualization.get_column_pair_plot')
def test__get_column_pair_plot_with_synthetic_data_none(mock_get_plot):
    """Test ``get_column_pair_plot`` when ``synthetic_data`` is None.

    Checks that the mocked plotting function receives the real frame, ``None`` for the
    synthetic data, the column names and the ``'scatter'`` plot type, and that its
    return value is handed back to the caller.
    """
    # Setup
    columns = ['amount', 'price']
    metadata = Metadata()
    metadata.add_table('table')
    for column_name in columns:
        metadata.add_column(column_name, 'table', sdtype='numerical')

    real_data = pd.DataFrame({
        'amount': [1, 2, 3],
        'price': [10, 20, 30],
    })

    # Run
    plot = get_column_pair_plot(real_data, None, metadata, columns)

    # Assert
    passed_args = mock_get_plot.call_args[0]
    pd.testing.assert_frame_equal(passed_args[0], real_data)
    assert passed_args[1] is None
    assert passed_args[2] == columns
    assert passed_args[3] == 'scatter'
    assert plot == mock_get_plot.return_value
32 changes: 32 additions & 0 deletions tests/unit/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_get_root_tables,
_is_datetime_type,
_is_numerical,
_prepare_data_vizualisation,
_validate_foreign_keys_not_null,
check_sdv_versions_and_warn,
check_synthesizer_version,
Expand Down Expand Up @@ -788,3 +789,34 @@ def test__is_numerical_string():
# Assert
assert str_result is False
assert datetime_result is False


def test__prepare_data_vizualisation():
    """Test that ``_prepare_data_vizualisation`` returns the expected prepared frame.

    With the NumPy seed fixed to 0 and a sample size of 2 on a three-row frame, the
    result is expected to contain the rows indexed 2 and 1, in that order, with
    ``col1`` holding datetime values.
    """
    # Setup
    np.random.seed(0)
    columns_metadata = {
        'col1': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'},
        'col2': {'sdtype': 'numerical'},
    }
    metadata = SingleTableMetadata.load_from_dict({'columns': columns_metadata})
    data = pd.DataFrame({
        'col1': ['2021-01-01', '2021-02-01', '2021-03-01'],
        'col2': [4, 5, 6],
    })

    # Run
    result = _prepare_data_vizualisation(data, metadata, ['col1', 'col2'], 2)

    # Assert
    expected_dates = pd.to_datetime(['2021-03-01', '2021-02-01'])
    expected_result = pd.DataFrame(
        {'col1': expected_dates, 'col2': [6, 5]},
        index=[2, 1],
    )
    pd.testing.assert_frame_equal(result, expected_result)

0 comments on commit c6d1942

Please sign in to comment.