From 9d8b70f8429863685cbeef1965d6fe259c88a53d Mon Sep 17 00:00:00 2001
From: janezd
Date: Fri, 13 Dec 2019 16:47:03 +0100
Subject: [PATCH] Test and Score: Add comparison of models

---
 Orange/widgets/evaluate/owtestlearners.py     | 139 +++++++++++++-
 .../evaluate/tests/test_owtestlearners.py     | 174 +++++++++++++++++-
 requirements-core.txt                         |   1 +
 3 files changed, 311 insertions(+), 3 deletions(-)

diff --git a/Orange/widgets/evaluate/owtestlearners.py b/Orange/widgets/evaluate/owtestlearners.py
index d82796b4894..399813b61ef 100644
--- a/Orange/widgets/evaluate/owtestlearners.py
+++ b/Orange/widgets/evaluate/owtestlearners.py
@@ -1,5 +1,6 @@
 # pylint doesn't understand the Settings magic
 # pylint: disable=invalid-sequence-index
+# pylint: disable=too-many-lines,too-many-instance-attributes
 import abc
 import enum
 import logging
@@ -9,14 +10,17 @@
 from concurrent.futures import Future
 from collections import OrderedDict, namedtuple
+from itertools import count
 from typing import Any, Optional, List, Dict, Callable
 
 import numpy as np
+import baycomp
 
 from AnyQt import QtGui
-from AnyQt.QtGui import QStandardItem
 from AnyQt.QtCore import Qt, QSize, QThread
 from AnyQt.QtCore import pyqtSlot as Slot
+from AnyQt.QtGui import QStandardItem, QDoubleValidator
+from AnyQt.QtWidgets import QHeaderView, QTableWidget, QLabel
 
 from Orange.base import Learner
 import Orange.classification
@@ -35,7 +39,7 @@
 from Orange.widgets.utils.widgetpreview import WidgetPreview
 from Orange.widgets.utils.concurrent import ThreadExecutor, TaskState
 from Orange.widgets.widget import OWWidget, Msg, Input, Output
-
+from orangewidget.utils.itemmodels import PyListModel
 
 log = logging.getLogger(__name__)
@@ -175,6 +179,10 @@ class Outputs:
     fold_feature = settings.ContextSetting(None)
     fold_feature_selected = settings.ContextSetting(False)
 
+    use_rope = settings.Setting(False)
+    rope = settings.Setting(0.1)
+    comparison_criterion = settings.Setting(0)
+
     TARGET_AVERAGE = "(Average over classes)"
     class_selection = settings.ContextSetting(TARGET_AVERAGE)
@@ -275,13 +283,46 @@ def __init__(self):
             callback=self._on_target_class_changed,
             contentsLength=8)
 
+        self.modcompbox = box = gui.vBox(self.controlArea, "Model Comparison")
+        gui.comboBox(
+            box, self, "comparison_criterion", model=PyListModel(),
+            callback=self.update_comparison_table)
+
+        hbox = gui.hBox(box)
+        gui.checkBox(hbox, self, "use_rope",
+                     "Negligible difference: ",
+                     callback=self.update_comparison_table)
+        gui.lineEdit(hbox, self, "rope", validator=QDoubleValidator(),
+                     controlWidth=70, callback=self.update_comparison_table,
+                     alignment=Qt.AlignRight)
+
         gui.rubber(self.controlArea)
         self.score_table = ScoreTable(self)
         self.score_table.shownScoresChanged.connect(self.update_stats_model)
+        view = self.score_table.view
+        view.setSizeAdjustPolicy(view.AdjustToContents)
 
         box = gui.vBox(self.mainArea, "Evaluation Results")
         box.layout().addWidget(self.score_table.view)
 
+        self.compbox = box = gui.vBox(self.mainArea, box="Model comparison")
+        table = self.comparison_table = QTableWidget(
+            wordWrap=False, editTriggers=QTableWidget.NoEditTriggers,
+            selectionMode=QTableWidget.NoSelection)
+        table.setSizeAdjustPolicy(table.AdjustToContents)
+        table.verticalHeader().setSectionResizeMode(QHeaderView.Fixed)
+
+        header = table.horizontalHeader()
+        header.setSectionResizeMode(QHeaderView.ResizeToContents)
+        header.setDefaultAlignment(Qt.AlignLeft)
+        header.setStretchLastSection(False)
+        box.layout().addWidget(table)
+        box.layout().addWidget(QLabel(
+            "Table shows probabilities that the score for the model in "
+            "the row is higher than that of the model in the column. "
+            "Small numbers show the probability that the difference is "
+            "negligible.", wordWrap=True))
+
     @staticmethod
     def sizeHint():
         return QSize(780, 1)
@@ -440,6 +481,8 @@ def _update_scorers(self):
             self.scorers = []
             return
         self.scorers = usable_scorers(self.data.domain.class_var)
+        self.controls.comparison_criterion.model()[:] = \
+            [scorer.long_name or scorer.name for scorer in self.scorers]
 
     @Inputs.preprocessor
     def set_preprocessor(self, preproc):
@@ -470,6 +513,9 @@ def shuffle_split_changed(self):
         self._param_changed()
 
     def _param_changed(self):
+        is_kfold = self.resampling == OWTestLearners.KFold
+        self.modcompbox.setEnabled(is_kfold)
+        self.comparison_table.setEnabled(is_kfold)
         self._invalidate()
         self.__update()
@@ -562,6 +608,91 @@ def update_stats_model(self):
         self.error("\n".join(errors), shown=bool(errors))
         self.Warning.scores_not_computed(shown=has_missing_scores)
 
+    def update_comparison_table(self):
+        self.comparison_table.clearContents()
+        if self.resampling != OWTestLearners.KFold:
+            return
+
+        slots = self._successful_slots()
+        scores = self._scores_by_folds(slots)
+        self._fill_table(slots, scores)
+
+    def _successful_slots(self):
+        model = self.score_table.model
+        proxy = self.score_table.sorted_model
+
+        keys = (model.data(proxy.mapToSource(proxy.index(row, 0)), Qt.UserRole)
+                for row in range(proxy.rowCount()))
+        slots = [slot for slot in (self.learners[key] for key in keys)
+                 if slot.results is not None and slot.results.success]
+        return slots
+
+    def _scores_by_folds(self, slots):
+        scorer = self.scorers[self.comparison_criterion]()
+        self.compbox.setTitle(f"Model comparison by {scorer.name}")
+        if scorer.is_binary:
+            if self.class_selection != self.TARGET_AVERAGE:
+                class_var = self.data.domain.class_var
+                target_index = class_var.values.index(self.class_selection)
+                kw = dict(target=target_index)
+            else:
+                kw = dict(average='weighted')
+        else:
+            kw = {}
+
+        def call_scorer(results):
+            def thunked():
+                return scorer.scores_by_folds(results.value, **kw).flatten()
+
+            return thunked
+
+        scores = [Try(call_scorer(slot.results)) for slot in slots]
+        scores = [score.value if score.success else None for score in scores]
+        # `None in scores` doesn't work -- these are np.arrays
+        if any(score is None for score in scores):
+            self.Warning.scores_not_computed()
+        return scores
+
+    def _fill_table(self, slots, scores):
+        table = self.comparison_table
+        table.setRowCount(len(slots))
+        table.setColumnCount(len(slots))
+
+        names = [learner_name(slot.learner) for slot in slots]
+        table.setVerticalHeaderLabels(names)
+        table.setHorizontalHeaderLabels(names)
+
+        for row, row_name, row_scores in zip(count(), names, scores):
+            for col, col_name, col_scores in zip(range(row), names, scores):
+                if row_scores is None or col_scores is None:
+                    continue
+                if self.use_rope and self.rope:
+                    p0, rope, p1 = baycomp.two_on_single(
+                        row_scores, col_scores, self.rope)
+                    self._set_cell(table, row, col,
+                                   f"{p0:.3f}<br/><small>{rope:.3f}</small>",
{rope:.3f})", + f"p({row_name} > {col_name}) = {p0:.3f}\n" + f"p({row_name} = {col_name}) = {rope:.3f}") + self._set_cell(table, col, row, + f"{p1:.3f}
{rope:.3f}", + f"p({col_name} > {row_name}) = {p1:.3f}\n" + f"p({col_name} = {row_name}) = {rope:.3f}") + else: + p0, p1 = baycomp.two_on_single(row_scores, col_scores) + self._set_cell(table, row, col, + f"{p0:.3f}", + f"p({row_name} > {col_name}) = {p0:.3f}") + self._set_cell(table, col, row, + f"{p1:.3f}", + f"p({col_name} > {row_name}) = {p1:.3f}") + + @staticmethod + def _set_cell(table, row, col, label, tooltip): + item = QLabel(label) + item.setToolTip(tooltip) + item.setAlignment(Qt.AlignCenter) + table.setCellWidget(row, col, item) + def _update_class_selection(self): self.class_selection_combo.setCurrentIndex(-1) self.class_selection_combo.clear() @@ -585,6 +716,7 @@ def _update_class_selection(self): def _on_target_class_changed(self): self.update_stats_model() + self.update_comparison_table() def _invalidate(self, which=None): self.cancel() @@ -611,6 +743,8 @@ def _invalidate(self, which=None): item.setData(None, Qt.DisplayRole) item.setData(None, Qt.ToolTipRole) + self.comparison_table.clearContents() + self.__needupdate = True def commit(self): @@ -866,6 +1000,7 @@ def __task_complete(self, f: 'Future[Results]'): self.score_table.update_header(self.scorers) self.update_stats_model() + self.update_comparison_table() self.commit() diff --git a/Orange/widgets/evaluate/tests/test_owtestlearners.py b/Orange/widgets/evaluate/tests/test_owtestlearners.py index 4d356d542ac..8602a946fc0 100644 --- a/Orange/widgets/evaluate/tests/test_owtestlearners.py +++ b/Orange/widgets/evaluate/tests/test_owtestlearners.py @@ -1,18 +1,21 @@ # pylint: disable=missing-docstring # pylint: disable=protected-access import unittest +from unittest.mock import Mock, patch import warnings import numpy as np from AnyQt.QtCore import Qt from AnyQt.QtTest import QTest +import baycomp from Orange.classification import MajorityLearner, LogisticRegressionLearner from Orange.classification.majority import ConstantModel from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable -from Orange.evaluation import Results, TestOnTestData +from Orange.evaluation import Results, TestOnTestData, scoring from Orange.evaluation.scoring import ClassificationScore, RegressionScore, \ Score +from Orange.base import Learner from Orange.modelling import ConstantLearner from Orange.regression import MeanLearner from Orange.widgets.evaluate.owtestlearners import ( @@ -25,6 +28,11 @@ from Orange.tests import test_filename +class BadLearner(Learner): + def fit(self, *_, **_2): # pylint: disable=arguments-differ + return 1 / 0 + + class TestOWTestLearners(WidgetTest): def setUp(self): super().setUp() @@ -391,6 +399,170 @@ def test_no_pregressbar_warning(self): self.send_signal(self.widget.Inputs.learner, MajorityLearner(), 0) assert not w + def _set_comparison_score(self, score): + w = self.widget + control = w.controls.comparison_criterion + control.setCurrentText(score) + w.comparison_criterion = control.model().indexOf(score) + + def _set_three_majorities(self): + w = self.widget + data = Table("iris")[::15] + self.send_signal(w.Inputs.train_data, data) + for i, name in enumerate(["maja", "majb", "majc"]): + learner = MajorityLearner() + learner.name = name + self.send_signal(w.Inputs.learner, learner, i) + self.get_output(self.widget.Outputs.evaluations_results, wait=5000) + + @patch("baycomp.two_on_single", Mock(wraps=baycomp.two_on_single)) + def test_comparison_requires_cv(self): + w = self.widget + w.comparison_criterion = 1 + rbs = w.controls.resampling.buttons + + self._set_three_majorities() + 
+        baycomp.two_on_single.reset_mock()
+
+        rbs[OWTestLearners.KFold].click()
+        self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
+        self.assertIsNotNone(w.comparison_table.cellWidget(0, 1))
+        self.assertTrue(w.modcompbox.isEnabled())
+        self.assertTrue(w.comparison_table.isEnabled())
+        baycomp.two_on_single.assert_called()
+        baycomp.two_on_single.reset_mock()
+
+        rbs[OWTestLearners.LeaveOneOut].click()
+        self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
+        self.assertIsNone(w.comparison_table.cellWidget(0, 1))
+        self.assertFalse(w.modcompbox.isEnabled())
+        self.assertFalse(w.comparison_table.isEnabled())
+        baycomp.two_on_single.assert_not_called()
+        baycomp.two_on_single.reset_mock()
+
+        rbs[OWTestLearners.KFold].click()
+        self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
+        self.assertIsNotNone(w.comparison_table.cellWidget(0, 1))
+        self.assertTrue(w.modcompbox.isEnabled())
+        self.assertTrue(w.comparison_table.isEnabled())
+        baycomp.two_on_single.assert_called()
+        baycomp.two_on_single.reset_mock()
+
+    @patch("baycomp.two_on_single", Mock(wraps=baycomp.two_on_single))
+    def test_comparison_bad_slots(self):
+        w = self.widget
+        self._set_three_majorities()
+        self._set_comparison_score("Classification accuracy")
+        self.send_signal(w.Inputs.learner, BadLearner(), 2, wait=5000)
+        self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
+        self.assertIsNotNone(w.comparison_table.cellWidget(0, 1))
+        self.assertIsNone(w.comparison_table.cellWidget(0, 2))
+        self.assertEqual(len(w._successful_slots()), 2)
+
+    def test_comparison_bad_scores(self):
+        w = self.widget
+        self._set_three_majorities()
+        self._set_comparison_score("Classification accuracy")
+        self.get_output(self.widget.Outputs.evaluations_results, wait=5000)
+
+        score_calls = -1
+
+        def fail_on_first(*_, **_2):
+            nonlocal score_calls
+            score_calls += 1
+            return 1 / score_calls
+
+        with patch.object(scoring.CA, "compute_score", new=fail_on_first):
+            w.update_comparison_table()
+
+        self.assertIsNone(w.comparison_table.cellWidget(0, 1))
+        self.assertIsNone(w.comparison_table.cellWidget(0, 2))
+        self.assertIsNone(w.comparison_table.cellWidget(1, 0))
+        self.assertIsNone(w.comparison_table.cellWidget(2, 0))
+        self.assertIsNotNone(w.comparison_table.cellWidget(1, 2))
+        self.assertIsNotNone(w.comparison_table.cellWidget(2, 1))
+        self.assertTrue(w.Warning.scores_not_computed.is_shown())
+
+        score_calls = -1
+        with patch.object(scoring.CA, "compute_score", new=fail_on_first):
+            slots = w._successful_slots()
+            self.assertEqual(len(slots), 3)
+            scores = w._scores_by_folds(slots)
+            self.assertIsNone(scores[0])
+            self.assertEqual(scores[1][0], 1)
+            self.assertAlmostEqual(scores[2][0], 1 / 11)
+
+    def test_comparison_binary_score(self):
+        # false warning at call_arg.kwargs
+        # pylint: disable=unpacking-non-sequence
+        w = self.widget
+        self._set_three_majorities()
+        self._set_comparison_score("F1")
+        f1mock = Mock(wraps=scoring.F1)
+
+        iris = Table("iris")
+        with patch.object(scoring.F1, "compute_score", f1mock):
+            simulate.combobox_activate_item(w.controls.class_selection,
+                                            iris.domain.class_var.values[1])
+            _, kwargs = f1mock.call_args
+            self.assertEqual(kwargs["target"], 1)
+            self.assertFalse("average" in kwargs)
+
+            simulate.combobox_activate_item(w.controls.class_selection,
+                                            iris.domain.class_var.values[2])
+            _, kwargs = f1mock.call_args
+            self.assertEqual(kwargs["target"], 2)
+            self.assertFalse("average" in kwargs)
+
+            simulate.combobox_activate_item(w.controls.class_selection,
+                                            OWTestLearners.TARGET_AVERAGE)
+            _, kwargs = f1mock.call_args
+            self.assertEqual(kwargs["average"], "weighted")
+            self.assertFalse("target" in kwargs)
+
+    def test_fill_table(self):
+        w = self.widget
+        self._set_three_majorities()
+        scores = [object(), object(), object()]
+        slots = w._successful_slots()
+
+        def probs(p1, p2, rope):
+            p1 += 1
+            p2 += 1
+            norm = p1 + p2 + rope * (p1 + p2)
+            if rope == 0:
+                return p1 / norm, p2 / norm
+            else:
+                return p1 / norm, rope / norm, p2 / norm
+
+        def two_on_single(res1, res2, rope=0):
+            return probs(scores.index(res1), scores.index(res2), rope)
+
+        with patch("baycomp.two_on_single", new=two_on_single):
+            for w.use_rope, w.rope in ((True, 0), (False, 0.1)):
+                w._fill_table(slots, scores)
+                for row in range(3):
+                    for col in range(3):
+                        if row == col:
+                            continue
+                        label = w.comparison_table.cellWidget(row, col)
+                        self.assertEqual(label.text(),
+                                         f"{(row + 1) / (row + col + 2):.3f}")
+                        self.assertIn(f"{(row + 1) / (row + col + 2):.3f}",
+                                      label.toolTip())
+
+            w.use_rope = True
+            w.rope = 0.25
+            w._fill_table(slots, scores)
+            for row in range(3):
+                for col in range(3):
+                    if row == col:
+                        continue
+                    label = w.comparison_table.cellWidget(row, col)
+                    for text in (label.text(), label.toolTip()):
+                        self.assertIn(f"{probs(row, col, w.rope)[0]:.3f}", text)
+                        self.assertIn(f"{probs(row, col, w.rope)[1]:.3f}", text)
+
 
 class TestHelpers(unittest.TestCase):
     def test_results_one_vs_rest(self):
diff --git a/requirements-core.txt b/requirements-core.txt
index 1de8e80cd12..ffcf2d9cca8 100644
--- a/requirements-core.txt
+++ b/requirements-core.txt
@@ -19,6 +19,7 @@ networkx
 python-louvain>=0.13
 requests
 openTSNE>=0.3.11
+baycomp>=1.0.2
 pandas
 pyyaml
 openpyxl
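
Reviewer note, not part of the patch: the comparison table relies entirely on
baycomp.two_on_single. The sketch below shows the call contract this patch
assumes (baycomp >= 1.0.2, as pinned in requirements-core.txt); the fold
scores are invented for illustration and are not produced by this code.

    import numpy as np
    import baycomp

    # Per-fold scores of two models, e.g. CA over 10-fold cross validation.
    scores_a = np.array([.82, .79, .85, .81, .80, .83, .78, .84, .80, .82])
    scores_b = np.array([.76, .78, .74, .77, .75, .79, .73, .78, .76, .75])

    # Without a ROPE, the posterior is split two ways, as in the `else`
    # branch of _fill_table: P(A better), P(B better).
    p0, p1 = baycomp.two_on_single(scores_a, scores_b)

    # With a ROPE (the widget's "Negligible difference" threshold), it is
    # split three ways: P(A better), P(difference within ROPE), P(B better).
    p0, rope, p1 = baycomp.two_on_single(scores_a, scores_b, 0.1)
    print(f"p(A > B) = {p0:.3f}, p(A = B) = {rope:.3f}, p(B > A) = {p1:.3f}")

This also explains why the widget only fills the table under k-fold cross
validation: two_on_single models correlated per-fold scores of two models
evaluated on the same folds, which the other resampling modes do not provide.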