-
-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Test & Score: Add comparison of models #4261
Changes from all commits
0779836
df4ed84
e1fad41
66bef42
a27cce6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# pylint doesn't understand the Settings magic | ||
# pylint: disable=invalid-sequence-index | ||
# pylint: disable=too-many-lines,too-many-instance-attributes | ||
import abc | ||
import enum | ||
import logging | ||
|
@@ -9,14 +10,17 @@ | |
|
||
from concurrent.futures import Future | ||
from collections import OrderedDict, namedtuple | ||
from itertools import count | ||
from typing import Any, Optional, List, Dict, Callable | ||
|
||
import numpy as np | ||
import baycomp | ||
|
||
from AnyQt import QtGui | ||
from AnyQt.QtGui import QStandardItem | ||
from AnyQt.QtCore import Qt, QSize, QThread | ||
from AnyQt.QtCore import pyqtSlot as Slot | ||
from AnyQt.QtGui import QStandardItem, QDoubleValidator | ||
from AnyQt.QtWidgets import QHeaderView, QTableWidget, QLabel | ||
|
||
from Orange.base import Learner | ||
import Orange.classification | ||
|
@@ -35,7 +39,7 @@ | |
from Orange.widgets.utils.widgetpreview import WidgetPreview | ||
from Orange.widgets.utils.concurrent import ThreadExecutor, TaskState | ||
from Orange.widgets.widget import OWWidget, Msg, Input, Output | ||
|
||
from orangewidget.utils.itemmodels import PyListModel | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
@@ -175,6 +179,10 @@ class Outputs: | |
fold_feature = settings.ContextSetting(None) | ||
fold_feature_selected = settings.ContextSetting(False) | ||
|
||
use_rope = settings.Setting(False) | ||
rope = settings.Setting(0.1) | ||
comparison_criterion = settings.Setting(0, schema_only=True) | ||
|
||
TARGET_AVERAGE = "(Average over classes)" | ||
class_selection = settings.ContextSetting(TARGET_AVERAGE) | ||
|
||
|
@@ -216,6 +224,7 @@ def __init__(self): | |
self.train_data_missing_vals = False | ||
self.test_data_missing_vals = False | ||
self.scorers = [] | ||
self.__pending_comparison_criterion = self.comparison_criterion | ||
|
||
#: An Ordered dictionary with current inputs and their testing results. | ||
self.learners = OrderedDict() # type: Dict[Any, Input] | ||
|
@@ -275,13 +284,55 @@ def __init__(self): | |
callback=self._on_target_class_changed, | ||
contentsLength=8) | ||
|
||
self.modcompbox = box = gui.vBox(self.controlArea, "Model Comparison") | ||
gui.comboBox( | ||
box, self, "comparison_criterion", model=PyListModel(), | ||
callback=self.update_comparison_table) | ||
|
||
hbox = gui.hBox(box) | ||
gui.checkBox(hbox, self, "use_rope", | ||
"Negligible difference: ", | ||
callback=self._on_use_rope_changed) | ||
gui.lineEdit(hbox, self, "rope", validator=QDoubleValidator(), | ||
controlWidth=70, callback=self.update_comparison_table, | ||
alignment=Qt.AlignRight) | ||
self.controls.rope.setEnabled(self.use_rope) | ||
|
||
gui.rubber(self.controlArea) | ||
self.score_table = ScoreTable(self) | ||
self.score_table.shownScoresChanged.connect(self.update_stats_model) | ||
view = self.score_table.view | ||
view.setSizeAdjustPolicy(view.AdjustToContents) | ||
|
||
box = gui.vBox(self.mainArea, "Evaluation Results") | ||
box.layout().addWidget(self.score_table.view) | ||
|
||
self.compbox = box = gui.vBox(self.mainArea, box="Model comparison") | ||
table = self.comparison_table = QTableWidget( | ||
wordWrap=False, editTriggers=QTableWidget.NoEditTriggers, | ||
selectionMode=QTableWidget.NoSelection) | ||
table.setSizeAdjustPolicy(table.AdjustToContents) | ||
header = table.verticalHeader() | ||
header.setSectionResizeMode(QHeaderView.Fixed) | ||
header.setSectionsClickable(False) | ||
|
||
header = table.horizontalHeader() | ||
header.setTextElideMode(Qt.ElideRight) | ||
header.setDefaultAlignment(Qt.AlignCenter) | ||
header.setSectionsClickable(False) | ||
header.setStretchLastSection(False) | ||
header.setSectionResizeMode(QHeaderView.ResizeToContents) | ||
avg_width = self.fontMetrics().averageCharWidth() | ||
header.setMinimumSectionSize(8 * avg_width) | ||
header.setMaximumSectionSize(15 * avg_width) | ||
header.setDefaultSectionSize(15 * avg_width) | ||
box.layout().addWidget(table) | ||
box.layout().addWidget(QLabel( | ||
"<small>Table shows probabilities that the score for the model in " | ||
"the row is higher than that of the model in the column. " | ||
"Small numbers show the probability that the difference is " | ||
"negligible.</small>", wordWrap=True)) | ||
|
||
@staticmethod | ||
def sizeHint(): | ||
return QSize(780, 1) | ||
|
@@ -436,10 +487,32 @@ def _which_missing_data(self): | |
# - we don't gain much with it | ||
# - it complicates the unit tests | ||
def _update_scorers(self): | ||
if self.data is None or self.data.domain.class_var is None: | ||
self.scorers = [] | ||
return | ||
self.scorers = usable_scorers(self.data.domain.class_var) | ||
if self.data and self.data.domain.class_var: | ||
new_scorers = usable_scorers(self.data.domain.class_var) | ||
else: | ||
new_scorers = [] | ||
# Don't unnecessarily reset the model because this would always reset | ||
# comparison_criterion; we alse set it explicitly, though, for clarity | ||
if new_scorers != self.scorers: | ||
self.scorers = new_scorers | ||
self.controls.comparison_criterion.model()[:] = \ | ||
[scorer.long_name or scorer.name for scorer in self.scorers] | ||
self.comparison_criterion = 0 | ||
if self.__pending_comparison_criterion is not None: | ||
# Check for the unlikely case that some scorers have been removed | ||
# from modules | ||
if self.__pending_comparison_criterion < len(self.scorers): | ||
self.comparison_criterion = self.__pending_comparison_criterion | ||
self.__pending_comparison_criterion = None | ||
self._update_compbox_title() | ||
|
||
def _update_compbox_title(self): | ||
criterion = self.comparison_criterion | ||
if criterion < len(self.scorers): | ||
scorer = self.scorers[criterion]() | ||
self.compbox.setTitle(f"Model Comparison by {scorer.name}") | ||
else: | ||
self.compbox.setTitle(f"Model Comparison") | ||
|
||
@Inputs.preprocessor | ||
def set_preprocessor(self, preproc): | ||
|
@@ -453,6 +526,7 @@ def handleNewSignals(self): | |
"""Reimplemented from OWWidget.handleNewSignals.""" | ||
self._update_class_selection() | ||
self.score_table.update_header(self.scorers) | ||
self._update_view_enabled() | ||
self.update_stats_model() | ||
if self.__needupdate: | ||
self.__update() | ||
|
@@ -470,9 +544,19 @@ def shuffle_split_changed(self): | |
self._param_changed() | ||
|
||
def _param_changed(self): | ||
self.modcompbox.setEnabled(self.resampling == OWTestLearners.KFold) | ||
self._update_view_enabled() | ||
self._invalidate() | ||
self.__update() | ||
|
||
def _update_view_enabled(self): | ||
self.comparison_table.setEnabled( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why disabling the table when nothing can be clicked anyway? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like disabling it because it shows the user that it's intentionally blank. Otherwise it looks like a bug when the upper table is filled and this one isn't (e.g. when using Leave one out). Hiding would also be an option, though I like disabling better -- like "something could be here, but currently isn't because I can't compute it in this situation". I can disable it when there is no data. But in this case we should do the same with the above table, I suppose. We need to discuss this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Both views now disable under same conditions. |
||
self.resampling == OWTestLearners.KFold | ||
and len(self.learners) > 1 | ||
and self.data is not None) | ||
self.score_table.view.setEnabled( | ||
self.data is not None) | ||
|
||
def update_stats_model(self): | ||
# Update the results_model with up to date scores. | ||
# Note: The target class specific scores (if requested) are | ||
|
@@ -494,8 +578,10 @@ def update_stats_model(self): | |
errors = [] | ||
has_missing_scores = False | ||
|
||
names = [] | ||
for key, slot in self.learners.items(): | ||
name = learner_name(slot.learner) | ||
names.append(name) | ||
head = QStandardItem(name) | ||
head.setData(key, Qt.UserRole) | ||
results = slot.results | ||
|
@@ -558,10 +644,123 @@ def update_stats_model(self): | |
header.sortIndicatorSection(), | ||
header.sortIndicatorOrder() | ||
) | ||
self._set_comparison_headers(names) | ||
|
||
self.error("\n".join(errors), shown=bool(errors)) | ||
self.Warning.scores_not_computed(shown=has_missing_scores) | ||
|
||
def _on_use_rope_changed(self): | ||
self.controls.rope.setEnabled(self.use_rope) | ||
self.update_comparison_table() | ||
|
||
def update_comparison_table(self): | ||
self.comparison_table.clearContents() | ||
slots = self._successful_slots() | ||
if not (slots and self.scorers): | ||
return | ||
names = [learner_name(slot.learner) for slot in slots] | ||
self._set_comparison_headers(names) | ||
if self.resampling == OWTestLearners.KFold: | ||
scores = self._scores_by_folds(slots) | ||
self._fill_table(names, scores) | ||
|
||
def _successful_slots(self): | ||
model = self.score_table.model | ||
proxy = self.score_table.sorted_model | ||
|
||
keys = (model.data(proxy.mapToSource(proxy.index(row, 0)), Qt.UserRole) | ||
for row in range(proxy.rowCount())) | ||
slots = [slot for slot in (self.learners[key] for key in keys) | ||
if slot.results is not None and slot.results.success] | ||
return slots | ||
|
||
def _set_comparison_headers(self, names): | ||
table = self.comparison_table | ||
try: | ||
# Prevent glitching during update | ||
table.setUpdatesEnabled(False) | ||
header = table.horizontalHeader() | ||
if len(names) > 2: | ||
header.setSectionResizeMode(QHeaderView.Stretch) | ||
else: | ||
header.setSectionResizeMode(QHeaderView.Fixed) | ||
table.setRowCount(len(names)) | ||
table.setColumnCount(len(names)) | ||
table.setVerticalHeaderLabels(names) | ||
table.setHorizontalHeaderLabels(names) | ||
finally: | ||
table.setUpdatesEnabled(True) | ||
|
||
def _scores_by_folds(self, slots): | ||
scorer = self.scorers[self.comparison_criterion]() | ||
VesnaT marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self._update_compbox_title() | ||
if scorer.is_binary: | ||
if self.class_selection != self.TARGET_AVERAGE: | ||
class_var = self.data.domain.class_var | ||
target_index = class_var.values.index(self.class_selection) | ||
kw = dict(target=target_index) | ||
else: | ||
kw = dict(average='weighted') | ||
else: | ||
kw = {} | ||
|
||
def call_scorer(results): | ||
def thunked(): | ||
return scorer.scores_by_folds(results.value, **kw).flatten() | ||
|
||
return thunked | ||
|
||
scores = [Try(call_scorer(slot.results)) for slot in slots] | ||
scores = [score.value if score.success else None for score in scores] | ||
# `None in scores doesn't work -- these are np.arrays) | ||
if any(score is None for score in scores): | ||
self.Warning.scores_not_computed() | ||
return scores | ||
|
||
def _fill_table(self, names, scores): | ||
table = self.comparison_table | ||
for row, row_name, row_scores in zip(count(), names, scores): | ||
for col, col_name, col_scores in zip(range(row), names, scores): | ||
if row_scores is None or col_scores is None: | ||
continue | ||
if self.use_rope and self.rope: | ||
p0, rope, p1 = baycomp.two_on_single( | ||
row_scores, col_scores, self.rope) | ||
if np.isnan(p0) or np.isnan(rope) or np.isnan(p1): | ||
self._set_cells_na(table, row, col) | ||
continue | ||
self._set_cell(table, row, col, | ||
f"{p0:.3f}<br/><small>{rope:.3f}</small>", | ||
f"p({row_name} > {col_name}) = {p0:.3f}\n" | ||
f"p({row_name} = {col_name}) = {rope:.3f}") | ||
self._set_cell(table, col, row, | ||
f"{p1:.3f}<br/><small>{rope:.3f}</small>", | ||
f"p({col_name} > {row_name}) = {p1:.3f}\n" | ||
f"p({col_name} = {row_name}) = {rope:.3f}") | ||
else: | ||
p0, p1 = baycomp.two_on_single(row_scores, col_scores) | ||
if np.isnan(p0) or np.isnan(p1): | ||
self._set_cells_na(table, row, col) | ||
continue | ||
self._set_cell(table, row, col, | ||
f"{p0:.3f}", | ||
f"p({row_name} > {col_name}) = {p0:.3f}") | ||
self._set_cell(table, col, row, | ||
f"{p1:.3f}", | ||
f"p({col_name} > {row_name}) = {p1:.3f}") | ||
|
||
@classmethod | ||
def _set_cells_na(cls, table, row, col): | ||
cls._set_cell(table, row, col, "NA", "comparison cannot be computed") | ||
cls._set_cell(table, col, row, "NA", "comparison cannot be computed") | ||
|
||
@staticmethod | ||
def _set_cell(table, row, col, label, tooltip): | ||
item = QLabel(label) | ||
item.setToolTip(tooltip) | ||
item.setAlignment(Qt.AlignCenter) | ||
table.setCellWidget(row, col, item) | ||
|
||
def _update_class_selection(self): | ||
self.class_selection_combo.setCurrentIndex(-1) | ||
self.class_selection_combo.clear() | ||
|
@@ -585,6 +784,7 @@ def _update_class_selection(self): | |
|
||
def _on_target_class_changed(self): | ||
self.update_stats_model() | ||
self.update_comparison_table() | ||
|
||
def _invalidate(self, which=None): | ||
self.cancel() | ||
|
@@ -611,6 +811,8 @@ def _invalidate(self, which=None): | |
item.setData(None, Qt.DisplayRole) | ||
item.setData(None, Qt.ToolTipRole) | ||
|
||
self.comparison_table.clearContents() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This only clears the contents, but retains the headers. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. |
||
|
||
self.__needupdate = True | ||
|
||
def commit(self): | ||
|
@@ -866,6 +1068,7 @@ def __task_complete(self, f: 'Future[Results]'): | |
|
||
self.score_table.update_header(self.scorers) | ||
self.update_stats_model() | ||
self.update_comparison_table() | ||
|
||
self.commit() | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this not a spinbox?
It should probably disabled when
use_rope
is not checked.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not a spin box because it has no defined range. It can refer to AUC, that is, between 0 and 1, or it can be RMSE, which is between 0 and infinity -- it can easily be 100000.
Disabling it would make sense, I'll do that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added disabling but didn't like it. Let's let the user change the line edit first and then enable it, if (s)he wishes.
I added a method
_on_use_rope_changed
. You can add a lineself.controls.rope.setEnabled(self.use_rope)
and see for yourself that you won't like it. :)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a won't fix. :)