diff --git a/nannyml/performance_estimation/confidence_based/metrics.py b/nannyml/performance_estimation/confidence_based/metrics.py index 13495d11..4596ec33 100644 --- a/nannyml/performance_estimation/confidence_based/metrics.py +++ b/nannyml/performance_estimation/confidence_based/metrics.py @@ -3133,29 +3133,32 @@ def _multiclass_confusion_matrix_alert_thresholds( return alert_thresholds def _multi_class_confusion_matrix_realized_performance(self, data: pd.DataFrame) -> Union[np.ndarray, float]: + # Build an all-NaN (num_classes x num_classes) matrix to return whenever the realized matrix cannot be computed + num_classes = len(self.classes) + nan_array = np.full(shape=(num_classes, num_classes), fill_value=np.nan) try: _list_missing([self.y_true, self.y_pred], data) except InvalidArgumentsException as ex: if "missing required columns" in str(ex): self._logger.debug(str(ex)) - return np.NaN + return nan_array else: raise ex data, empty = common_nan_removal(data, [self.y_true, self.y_pred]) if empty: warnings.warn(f"Too many missing values, cannot calculate {self.display_name}. " f"Returning NaN.") - return np.NaN + return nan_array y_true = data[self.y_true] if y_true.nunique() <= 1: warnings.warn(f"Too few unique values present in 'y_true', returning NaN as realized {self.display_name}.") - return np.NaN + return nan_array if data[self.y_pred].nunique() <= 1: warnings.warn( f"Too few unique values present in 'y_pred', returning NaN as realized {self.display_name} score." 
) - return np.NaN + return nan_array cm = confusion_matrix( data[self.y_true], data[self.y_pred], labels=self.classes, normalize=self.normalize_confusion_matrix diff --git a/tests/performance_estimation/CBPE/test_cbpe_metrics.py b/tests/performance_estimation/CBPE/test_cbpe_metrics.py index ebc30229..c2ae06cb 100644 --- a/tests/performance_estimation/CBPE/test_cbpe_metrics.py +++ b/tests/performance_estimation/CBPE/test_cbpe_metrics.py @@ -1,6 +1,7 @@ """Tests.""" import pandas as pd +import numpy as np import pytest from nannyml.chunk import DefaultChunker, SizeBasedChunker @@ -3462,3 +3463,120 @@ def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits f'{metric.display_name} lower threshold value -1 overridden by ' f'lower threshold value limit {metric.lower_threshold_value_limit}' in caplog.messages ) + + +@pytest.mark.parametrize( + 'calculator_opts, realized', + [ + ( + {'chunk_size': 20000}, + pd.DataFrame( + { + 'key': ['[0:19999]', '[20000:39999]', '[40000:59999]'], + 'realized_roc_auc': [0.909805, 0.840071, np.nan], + 'realized_f1': [0.759170, 0.658896, np.nan], + 'realized_precision': [0.759265, 0.660188, np.nan], + 'realized_recall': [0.759149, 0.658760, np.nan], + 'realized_specificity': [0.879632, 0.829581, np.nan], + 'realized_accuracy': [0.75925, 0.65950, np.nan], + 'realized_true_highstreet_card_pred_highstreet_card': [ + 4912.0, + 4702.0, + np.nan, + ], + 'realized_true_highstreet_card_pred_prepaid_card': [ + 870.0, + 1083.0, + np.nan, + ], + 'realized_true_highstreet_card_pred_upmarket_card': [ + 799.0, + 1009.0, + np.nan, + ], + 'realized_true_prepaid_card_pred_highstreet_card': [ + 846.0, + 1367.0, + np.nan, + ], + 'realized_true_prepaid_card_pred_prepaid_card': [ + 5203.0, + 3974.0, + np.nan, + ], + 'realized_true_prepaid_card_pred_upmarket_card': [ + 690.0, + 1080.0, + np.nan, + ], + 'realized_true_upmarket_card_pred_highstreet_card': [ + 837.0, + 1282.0, + np.nan, + ], + 
'realized_true_upmarket_card_pred_prepaid_card': [ + 773.0, + 989.0, + np.nan, + ], + 'realized_true_upmarket_card_pred_upmarket_card': [ + 5070.0, + 4514.0, + np.nan, + ], + } + ), + ), + ] +) +def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, realized): # noqa: D103 + """Test realized multiclass metrics, confusion matrix included, are NaN when a chunk has no targets.""" + reference, analysis, targets = load_synthetic_multiclass_classification_dataset() + analysis = analysis.merge(targets, left_index=True, right_index=True) + analysis.loc[analysis.index[-20_000:], 'y_true'] = np.nan + cbpe = CBPE( + y_pred_proba={ + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'prepaid_card': 'y_pred_proba_prepaid_card', + }, + y_pred='y_pred', + y_true='y_true', + problem_type='classification_multiclass', + metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'confusion_matrix'], + **calculator_opts, + ).fit(reference) + result = cbpe.estimate(analysis) + column_names = [(m.name, 'realized') for m in result.metrics] + column_names = [c for c in column_names if c[0] != 'confusion_matrix'] + column_names += [ + ('true_highstreet_card_pred_highstreet_card', 'realized'), + ('true_highstreet_card_pred_prepaid_card', 'realized'), + ('true_highstreet_card_pred_upmarket_card', 'realized'), + ('true_prepaid_card_pred_highstreet_card', 'realized'), + ('true_prepaid_card_pred_prepaid_card', 'realized'), + ('true_prepaid_card_pred_upmarket_card', 'realized'), + ('true_upmarket_card_pred_highstreet_card', 'realized'), + ('true_upmarket_card_pred_prepaid_card', 'realized'), + ('true_upmarket_card_pred_upmarket_card', 'realized'), + ] + sut = result.filter(period='analysis').to_df()[[('chunk', 'key')] + column_names] + sut.columns = [ + 'key', + 'realized_roc_auc', + 'realized_f1', + 'realized_precision', + 'realized_recall', + 'realized_specificity', + 'realized_accuracy', + 'realized_true_highstreet_card_pred_highstreet_card', + 
'realized_true_highstreet_card_pred_prepaid_card', + 'realized_true_highstreet_card_pred_upmarket_card', + 'realized_true_prepaid_card_pred_highstreet_card', + 'realized_true_prepaid_card_pred_prepaid_card', + 'realized_true_prepaid_card_pred_upmarket_card', + 'realized_true_upmarket_card_pred_highstreet_card', + 'realized_true_upmarket_card_pred_prepaid_card', + 'realized_true_upmarket_card_pred_upmarket_card', + ] + pd.testing.assert_frame_equal(realized, sut)