Skip to content

Commit

Permalink
fix: prts version specification until CompML/PRTS#77 is merged and add skip-tests conditions if not installed
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianSchmidl committed Aug 13, 2024
1 parent e8849fa commit d500a6f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 29 deletions.
12 changes: 3 additions & 9 deletions aeon/performance_metrics/anomaly_detection/_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ def range_precision(
1920–30. 2018.
http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf.
"""
_check_soft_dependencies(
"prts>=1.0.0.3", obj="range_precision", suppress_import_stdout=True
)
_check_soft_dependencies("prts", obj="range_precision", suppress_import_stdout=True)

from prts import ts_precision

Expand Down Expand Up @@ -117,9 +115,7 @@ def range_recall(
1920–30. 2018.
http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf.
"""
_check_soft_dependencies(
"prts>=1.0.0.3", obj="range_recall", suppress_import_stdout=True
)
_check_soft_dependencies("prts", obj="range_recall", suppress_import_stdout=True)

from prts import ts_recall

Expand Down Expand Up @@ -187,9 +183,7 @@ def range_f_score(
1920–30. 2018.
http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf.
"""
_check_soft_dependencies(
"prts>=1.0.0.3", obj="range_recall", suppress_import_stdout=True
)
_check_soft_dependencies("prts", obj="range_recall", suppress_import_stdout=True)

from prts import ts_fscore

Expand Down
4 changes: 2 additions & 2 deletions aeon/performance_metrics/anomaly_detection/_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def f_score_at_k_ranges(
Function used to find the threshold.
"""
_check_soft_dependencies(
"prts>=1.0.0.3", obj="f_score_at_k_ranges", suppress_import_stdout=True
"prts", obj="f_score_at_k_ranges", suppress_import_stdout=True
)

from prts import ts_fscore
Expand Down Expand Up @@ -231,7 +231,7 @@ def rp_rr_auc_score(
http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf.
"""
_check_soft_dependencies(
"prts>=1.0.0.3", obj="f_score_at_k_ranges", suppress_import_stdout=True
"prts", obj="f_score_at_k_ranges", suppress_import_stdout=True
)

from prts import ts_precision, ts_recall
Expand Down
74 changes: 56 additions & 18 deletions aeon/performance_metrics/anomaly_detection/tests/test_ad_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,46 @@
rp_rr_auc_score,
)
from aeon.testing.data_generation import make_example_1d_numpy
from aeon.utils.validation._dependencies import _check_soft_dependencies

pr_metrics = [pr_auc_score, rp_rr_auc_score]
pr_metrics = [pr_auc_score]
range_metrics = [
range_roc_auc_score,
range_pr_auc_score,
range_roc_vus_score,
range_pr_vus_score,
range_recall,
range_precision,
range_f_score,
]
other_metrics = [
roc_auc_score,
f_score_at_k_points,
f_score_at_k_ranges,
rp_rr_auc_score,
]
continuous_metrics = [*pr_metrics, *other_metrics, *range_metrics]
binary_metrics = []

if _check_soft_dependencies("prts", severity="none"):
pr_metrics.append(rp_rr_auc_score)
range_metrics.extend(
[
range_recall,
range_precision,
range_f_score,
]
)
other_metrics.extend(
[
f_score_at_k_ranges,
rp_rr_auc_score,
]
)
continuous_metrics.extend(
[
rp_rr_auc_score,
f_score_at_k_ranges,
]
)
binary_metrics = [range_recall, range_precision, range_f_score]

metrics = [*pr_metrics, *range_metrics, *other_metrics]
continuous_metrics = [
*pr_metrics,
range_roc_auc_score,
range_pr_auc_score,
range_roc_vus_score,
range_pr_vus_score,
roc_auc_score,
rp_rr_auc_score,
f_score_at_k_points,
f_score_at_k_ranges,
]
binary_metrics = [range_recall, range_precision, range_f_score]


@pytest.mark.parametrize("metric", metrics, ids=[m.__name__ for m in metrics])
Expand Down Expand Up @@ -140,6 +150,10 @@ def test_edge_cases_pr_metrics(metric):
assert score <= 0.2, f"{metric.__name__}(y_true, y_inverted)={score} is not <= 0.2"


@pytest.mark.skipif(
not _check_soft_dependencies("prts", severity="none"),
reason="required soft dependency prts not available",
)
def test_range_based_f1():
"""Test range-based F1 score."""
y_pred = np.array([0, 1, 1, 0])
Expand All @@ -148,6 +162,10 @@ def test_range_based_f1():
np.testing.assert_almost_equal(result, 0.66666, decimal=4)


@pytest.mark.skipif(
not _check_soft_dependencies("prts", severity="none"),
reason="required soft dependency prts not available",
)
def test_range_based_precision():
"""Test range-based precision."""
y_pred = np.array([0, 1, 1, 0])
Expand All @@ -156,6 +174,10 @@ def test_range_based_precision():
assert result == 0.5


@pytest.mark.skipif(
not _check_soft_dependencies("prts", severity="none"),
reason="required soft dependency prts not available",
)
def test_range_based_recall():
"""Test range-based recall."""
y_pred = np.array([0, 1, 1, 0])
Expand All @@ -164,6 +186,10 @@ def test_range_based_recall():
assert result == 1


@pytest.mark.skipif(
not _check_soft_dependencies("prts", severity="none"),
reason="required soft dependency prts not available",
)
def test_rf1_value_error():
"""Test range-based F1 score raises ValueError on binary predictions."""
y_pred = np.array([0, 0.2, 0.7, 0])
Expand All @@ -187,6 +213,10 @@ def test_pr_curve_auc():
# np.testing.assert_almost_equal(result, 0.8333, decimal=4)


@pytest.mark.skipif(
not _check_soft_dependencies("prts", severity="none"),
reason="required soft dependency prts not available",
)
def test_range_based_p_range_based_r_curve_auc():
"""Test range-based precision-recall curve AUC."""
y_pred = np.array([0, 0.1, 1.0, 0.5, 0.1, 0])
Expand All @@ -195,6 +225,10 @@ def test_range_based_p_range_based_r_curve_auc():
np.testing.assert_almost_equal(result, 0.9792, decimal=4)


@pytest.mark.skipif(
not _check_soft_dependencies("prts", severity="none"),
reason="required soft dependency prts not available",
)
def test_range_based_p_range_based_r_auc_perfect_hit():
"""Test range-based precision-recall curve AUC with perfect hit."""
y_pred = np.array([0, 0, 0.5, 0.5, 0, 0])
Expand All @@ -203,6 +237,10 @@ def test_range_based_p_range_based_r_auc_perfect_hit():
np.testing.assert_almost_equal(result, 1.0000, decimal=4)


@pytest.mark.skipif(
not _check_soft_dependencies("prts", severity="none"),
reason="required soft dependency prts not available",
)
def test_f_score_at_k_ranges():
"""Test range-based F1 score at k ranges."""
y_pred = np.array([0.4, 0.1, 1.0, 0.5, 0.1, 0, 0.4, 0.5])
Expand Down

0 comments on commit d500a6f

Please sign in to comment.