From d500a6f142a9233a0d77f8e3988d33e278a40250 Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Tue, 13 Aug 2024 13:45:56 +0200 Subject: [PATCH] fix: prts version specification until https://github.com/CompML/PRTS/pull/77 is merged and add skiptests conditions if not installed --- .../anomaly_detection/_binary.py | 12 +-- .../anomaly_detection/_continuous.py | 4 +- .../tests/test_ad_metrics.py | 74 ++++++++++++++----- 3 files changed, 61 insertions(+), 29 deletions(-) diff --git a/aeon/performance_metrics/anomaly_detection/_binary.py b/aeon/performance_metrics/anomaly_detection/_binary.py index 17d01a23a6..bc167c53ce 100644 --- a/aeon/performance_metrics/anomaly_detection/_binary.py +++ b/aeon/performance_metrics/anomaly_detection/_binary.py @@ -55,9 +55,7 @@ def range_precision( 1920–30. 2018. http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf. """ - _check_soft_dependencies( - "prts>=1.0.0.3", obj="range_precision", suppress_import_stdout=True - ) + _check_soft_dependencies("prts", obj="range_precision", suppress_import_stdout=True) from prts import ts_precision @@ -117,9 +115,7 @@ def range_recall( 1920–30. 2018. http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf. """ - _check_soft_dependencies( - "prts>=1.0.0.3", obj="range_recall", suppress_import_stdout=True - ) + _check_soft_dependencies("prts", obj="range_recall", suppress_import_stdout=True) from prts import ts_recall @@ -187,9 +183,7 @@ def range_f_score( 1920–30. 2018. http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf. """ - _check_soft_dependencies( - "prts>=1.0.0.3", obj="range_recall", suppress_import_stdout=True - ) + _check_soft_dependencies("prts", obj="range_recall", suppress_import_stdout=True) from prts import ts_fscore diff --git a/aeon/performance_metrics/anomaly_detection/_continuous.py b/aeon/performance_metrics/anomaly_detection/_continuous.py index 7f24f013b3..6626782fed 100644 --- a/aeon/performance_metrics/anomaly_detection/_continuous.py +++ b/aeon/performance_metrics/anomaly_detection/_continuous.py @@ -157,7 +157,7 @@ def f_score_at_k_ranges( Function used to find the threshold. """ _check_soft_dependencies( - "prts>=1.0.0.3", obj="f_score_at_k_ranges", suppress_import_stdout=True + "prts", obj="f_score_at_k_ranges", suppress_import_stdout=True ) from prts import ts_fscore @@ -231,7 +231,7 @@ def rp_rr_auc_score( http://papers.nips.cc/paper/7462-precision-and-recall-for-time-series.pdf. """ _check_soft_dependencies( - "prts>=1.0.0.3", obj="f_score_at_k_ranges", suppress_import_stdout=True + "prts", obj="f_score_at_k_ranges", suppress_import_stdout=True ) from prts import ts_precision, ts_recall diff --git a/aeon/performance_metrics/anomaly_detection/tests/test_ad_metrics.py b/aeon/performance_metrics/anomaly_detection/tests/test_ad_metrics.py index e3a0659a7a..617e55027a 100644 --- a/aeon/performance_metrics/anomaly_detection/tests/test_ad_metrics.py +++ b/aeon/performance_metrics/anomaly_detection/tests/test_ad_metrics.py @@ -18,36 +18,46 @@ rp_rr_auc_score, ) from aeon.testing.data_generation import make_example_1d_numpy +from aeon.utils.validation._dependencies import _check_soft_dependencies -pr_metrics = [pr_auc_score, rp_rr_auc_score] +pr_metrics = [pr_auc_score] range_metrics = [ range_roc_auc_score, range_pr_auc_score, range_roc_vus_score, range_pr_vus_score, - range_recall, - range_precision, - range_f_score, ] other_metrics = [ roc_auc_score, f_score_at_k_points, - f_score_at_k_ranges, - rp_rr_auc_score, ] +continuous_metrics = [*pr_metrics, *other_metrics, *range_metrics] +binary_metrics = [] + +if _check_soft_dependencies("prts", severity="none"): + pr_metrics.append(rp_rr_auc_score) + range_metrics.extend( + [ + range_recall, + range_precision, + range_f_score, + ] + ) + other_metrics.extend( + [ + f_score_at_k_ranges, + rp_rr_auc_score, + ] + ) + continuous_metrics.extend( + [ + rp_rr_auc_score, + f_score_at_k_ranges, + ] + ) + binary_metrics = [range_recall, range_precision, range_f_score] + metrics = [*pr_metrics, *range_metrics, *other_metrics] -continuous_metrics = [ - *pr_metrics, - range_roc_auc_score, - range_pr_auc_score, - range_roc_vus_score, - range_pr_vus_score, - roc_auc_score, - rp_rr_auc_score, - f_score_at_k_points, - f_score_at_k_ranges, -] -binary_metrics = [range_recall, range_precision, range_f_score] @pytest.mark.parametrize("metric", metrics, ids=[m.__name__ for m in metrics]) @@ -140,6 +150,10 @@ def test_edge_cases_pr_metrics(metric): assert score <= 0.2, f"{metric.__name__}(y_true, y_inverted)={score} is not <= 0.2" +@pytest.mark.skipif( + not _check_soft_dependencies("prts", severity="none"), + reason="required soft dependency prts not available", +) def test_range_based_f1(): """Test range-based F1 score.""" y_pred = np.array([0, 1, 1, 0]) @@ -148,6 +162,10 @@ def test_range_based_f1(): np.testing.assert_almost_equal(result, 0.66666, decimal=4) +@pytest.mark.skipif( + not _check_soft_dependencies("prts", severity="none"), + reason="required soft dependency prts not available", +) def test_range_based_precision(): """Test range-based precision.""" y_pred = np.array([0, 1, 1, 0]) @@ -156,6 +174,10 @@ def test_range_based_precision(): assert result == 0.5 +@pytest.mark.skipif( + not _check_soft_dependencies("prts", severity="none"), + reason="required soft dependency prts not available", +) def test_range_based_recall(): """Test range-based recall.""" y_pred = np.array([0, 1, 1, 0]) @@ -164,6 +186,10 @@ def test_range_based_recall(): assert result == 1 +@pytest.mark.skipif( + not _check_soft_dependencies("prts", severity="none"), + reason="required soft dependency prts not available", +) def test_rf1_value_error(): """Test range-based F1 score raises ValueError on binary predictions.""" y_pred = np.array([0, 0.2, 0.7, 0]) @@ -187,6 +213,10 @@ def test_pr_curve_auc(): # np.testing.assert_almost_equal(result, 0.8333, decimal=4) +@pytest.mark.skipif( + not _check_soft_dependencies("prts", severity="none"), + reason="required soft dependency prts not available", +) def test_range_based_p_range_based_r_curve_auc(): """Test range-based precision-recall curve AUC.""" y_pred = np.array([0, 0.1, 1.0, 0.5, 0.1, 0]) @@ -195,6 +225,10 @@ def test_range_based_p_range_based_r_curve_auc(): np.testing.assert_almost_equal(result, 0.9792, decimal=4) +@pytest.mark.skipif( + not _check_soft_dependencies("prts", severity="none"), + reason="required soft dependency prts not available", +) def test_range_based_p_range_based_r_auc_perfect_hit(): """Test range-based precision-recall curve AUC with perfect hit.""" y_pred = np.array([0, 0, 0.5, 0.5, 0, 0]) @@ -203,6 +237,10 @@ def test_range_based_p_range_based_r_auc_perfect_hit(): np.testing.assert_almost_equal(result, 1.0000, decimal=4) +@pytest.mark.skipif( + not _check_soft_dependencies("prts", severity="none"), + reason="required soft dependency prts not available", +) def test_f_score_at_k_ranges(): """Test range-based F1 score at k ranges.""" y_pred = np.array([0.4, 0.1, 1.0, 0.5, 0.1, 0, 0.4, 0.5])