diff --git a/Orange/widgets/evaluate/owpredictions.py b/Orange/widgets/evaluate/owpredictions.py
index 0fd3a5085af..3d57d9c5727 100644
--- a/Orange/widgets/evaluate/owpredictions.py
+++ b/Orange/widgets/evaluate/owpredictions.py
@@ -56,7 +56,12 @@ class Warning(OWWidget.Warning):
empty_data = Msg("Empty data set")
class Error(OWWidget.Error):
- predictor_failed = Msg("One or more predictors failed (see more...)\n{}")
+ predictor_failed = \
+ Msg("One or more predictors failed (see more...)\n{}")
+ predictors_target_mismatch = \
+ Msg("Predictors do not have the same target.")
+ data_target_mismatch = \
+ Msg("Data does not have the same target as predictors.")
settingsHandler = settings.ClassValuesContextHandler()
#: Display the full input dataset or only the target variable columns (if
@@ -182,13 +187,11 @@ def set_data(self, data):
self.data = data
if data is None:
- self.class_var = class_var = None
self.dataview.setModel(None)
self.predictionsview.setModel(None)
self.predictionsview.setItemDelegate(PredictionsItemDelegate())
else:
# force full reset of the view's HeaderView state
- self.class_var = class_var = data.domain.class_var
self.dataview.setModel(None)
model = TableModel(data, parent=None)
modelproxy = TableSortProxyModel()
@@ -196,18 +199,6 @@ def set_data(self, data):
self.dataview.setModel(modelproxy)
self._update_column_visibility()
- discrete_class = class_var is not None and class_var.is_discrete
- self.classification_options.setVisible(discrete_class)
-
- self.closeContext()
- if discrete_class:
- self.class_values = list(class_var.values)
- self.selected_classes = list(range(len(self.class_values)))
- self.openContext(self.class_var)
- else:
- self.class_values = []
- self.selected_classes = []
-
self._invalidate_predictions()
def set_predictor(self, predictor=None, id=None):
@@ -221,7 +212,36 @@ def set_predictor(self, predictor=None, id=None):
self.predictors[id] = \
PredictorSlot(predictor, predictor.name, None)
+ def set_class_var(self):
+ pred_classes = set(pred.predictor.domain.class_var
+ for pred in self.predictors.values())
+ self.Error.predictors_target_mismatch.clear()
+ self.Error.data_target_mismatch.clear()
+ self.class_var = None
+ if len(pred_classes) > 1:
+ self.Error.predictors_target_mismatch()
+ if len(pred_classes) == 1:
+ self.class_var = pred_classes.pop()
+ if self.data is not None and \
+ self.data.domain.class_var is not None and \
+ self.class_var != self.data.domain.class_var:
+ self.Error.data_target_mismatch()
+ self.class_var = None
+
+ discrete_class = self.class_var is not None \
+ and self.class_var.is_discrete
+ self.classification_options.setVisible(discrete_class)
+ self.closeContext()
+ if discrete_class:
+ self.class_values = list(self.class_var.values)
+ self.selected_classes = list(range(len(self.class_values)))
+ self.openContext(self.class_var)
+ else:
+ self.class_values = []
+ self.selected_classes = []
+
def handleNewSignals(self):
+ self.set_class_var()
if self.data is not None:
self._call_predictors()
self._update_predictions_model()
@@ -232,14 +252,9 @@ def handleNewSignals(self):
def _call_predictors(self):
for inputid, pred in self.predictors.items():
- if pred.results is None:
+ if pred.results is None or numpy.isnan(pred.results[0]).all():
try:
- predictor_class = pred.predictor.domain.class_var
- if predictor_class != self.class_var:
- results = "{}: mismatching target ({})".format(
- pred.predictor.name, predictor_class.name)
- else:
- results = self.predict(pred.predictor, self.data)
+ results = self.predict(pred.predictor, self.data)
except ValueError as err:
results = "{}: {}".format(pred.predictor.name, err)
self.predictors[inputid] = pred._replace(results=results)
@@ -285,12 +300,16 @@ def _invalidate_predictions(self):
self.predictors[inputid] = pred._replace(results=None)
def _valid_predictors(self):
- return [p for p in self.predictors.values()
- if p.results is not None and not isinstance(p.results, str)]
+ if self.class_var is not None and \
+ self.data is not None:
+ return [p for p in self.predictors.values()
+ if p.results is not None and not isinstance(p.results, str)]
+ else:
+ return []
def _update_predictions_model(self):
"""Update the prediction view model."""
- if self.data is not None:
+ if self.data is not None and self.class_var is not None:
slots = self._valid_predictors()
results = []
class_var = self.class_var
@@ -323,7 +342,7 @@ def _update_predictions_model(self):
def _update_column_visibility(self):
"""Update data column visibility."""
- if self.data is not None:
+ if self.data is not None and self.class_var is not None:
domain = self.data.domain
first_attr = len(domain.class_vars) + len(domain.metas)
@@ -415,12 +434,12 @@ def commit(self):
self._commit_evaluation_results()
def _commit_evaluation_results(self):
- class_var = self.class_var
slots = self._valid_predictors()
- if not slots:
+ if not slots or self.data.domain.class_var is None:
self.send("Evaluation Results", None)
return
+ class_var = self.class_var
nanmask = numpy.isnan(self.data.get_column_view(class_var)[0])
data = self.data[~nanmask]
N = len(data)
@@ -442,15 +461,15 @@ def _commit_predictions(self):
self.send("Predictions", None)
return
- class_var = self.class_var
- if class_var and class_var.is_discrete:
+ if self.class_var and self.class_var.is_discrete:
newmetas, newcolumns = self._classification_output_columns()
else:
newmetas, newcolumns = self._regression_output_columns()
attrs = list(self.data.domain.attributes) if self.output_attrs else []
metas = list(self.data.domain.metas) + newmetas
- domain = Orange.data.Domain(attrs, class_var, metas=metas)
+ domain = \
+ Orange.data.Domain(attrs, self.data.domain.class_var, metas=metas)
predictions = self.data.from_table(domain, self.data)
if newcolumns:
newcolumns = numpy.hstack(
@@ -506,7 +525,7 @@ def merge_data_with_predictions():
[data_model.data(data_model.index(i, j))
for j in iter_data_cols]
- if self.data is not None:
+ if self.data is not None and self.class_var is not None:
text = self.infolabel.text().replace('\n', '
')
if self.show_probabilities and self.selected_classes:
text += '
Showing probabilities for: '
diff --git a/Orange/widgets/evaluate/tests/test_owpredictions.py b/Orange/widgets/evaluate/tests/test_owpredictions.py
index 8f366b7d2cd..bbae03bf3d8 100644
--- a/Orange/widgets/evaluate/tests/test_owpredictions.py
+++ b/Orange/widgets/evaluate/tests/test_owpredictions.py
@@ -5,7 +5,7 @@
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.evaluate.owpredictions import OWPredictions
-from Orange.data import Table
+from Orange.data import Table, Domain
from Orange.classification import MajorityLearner
from Orange.evaluation import Results
@@ -47,6 +47,8 @@ def test_nan_target_input(self):
self.assertEqual(len(evres.data), 0)
def test_mismatching_targets(self):
+ error = self.widget.Error
+
titanic = Table("titanic")
majority_titanic = MajorityLearner()(titanic)
majority_iris = MajorityLearner()(self.iris)
@@ -54,35 +56,66 @@ def test_mismatching_targets(self):
self.send_signal("Data", self.iris)
self.send_signal("Predictors", majority_iris, 1)
self.send_signal("Predictors", majority_titanic, 2)
- self.assertTrue(self.widget.Error.predictor_failed.is_shown())
- output = self.get_output("Predictions")
- self.assertEqual(len(output.domain.metas), 4)
+ self.assertTrue(error.predictors_target_mismatch.is_shown())
+ self.assertIsNone(self.get_output("Predictions"))
self.send_signal("Predictors", None, 1)
- self.assertTrue(self.widget.Error.predictor_failed.is_shown())
+ self.assertFalse(error.predictors_target_mismatch.is_shown())
+ self.assertTrue(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))
self.send_signal("Data", None)
- self.assertFalse(self.widget.Error.predictor_failed.is_shown())
+ self.assertFalse(error.predictors_target_mismatch.is_shown())
+ self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))
self.send_signal("Predictors", None, 2)
- self.assertFalse(self.widget.Error.predictor_failed.is_shown())
+ self.assertFalse(error.predictors_target_mismatch.is_shown())
+ self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))
self.send_signal("Predictors", majority_titanic, 2)
- self.assertFalse(self.widget.Error.predictor_failed.is_shown())
+ self.assertFalse(error.predictors_target_mismatch.is_shown())
+ self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))
self.send_signal("Data", self.iris)
- self.assertTrue(self.widget.Error.predictor_failed.is_shown())
+ self.assertFalse(error.predictors_target_mismatch.is_shown())
+ self.assertTrue(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))
self.send_signal("Predictors", majority_iris, 2)
- self.assertFalse(self.widget.Error.predictor_failed.is_shown())
+ self.assertFalse(error.predictors_target_mismatch.is_shown())
+ self.assertFalse(error.data_target_mismatch.is_shown())
+ output = self.get_output("Predictions")
self.assertEqual(len(output.domain.metas), 4)
self.send_signal("Predictors", majority_iris, 1)
self.send_signal("Predictors", majority_titanic, 3)
- output = self.get_output("Predictions")
- self.assertEqual(len(output.domain.metas), 8)
+ self.assertTrue(error.predictors_target_mismatch.is_shown())
+ self.assertFalse(error.data_target_mismatch.is_shown())
+ self.assertIsNone(self.get_output("Predictions"))
+
+ def test_no_class_on_test(self):
+ """Allow test data with no class"""
+ error = self.widget.Error
+
+ titanic = Table("titanic")
+ majority_titanic = MajorityLearner()(titanic)
+ majority_iris = MajorityLearner()(self.iris)
+
+ no_class = Table(Domain(titanic.domain.attributes, None), titanic)
+ self.send_signal("Predictors", majority_titanic, 1)
+ self.send_signal("Data", no_class)
+ out = self.get_output("Predictions")
+ np.testing.assert_allclose(out.get_column_view("majority")[0], 0)
+
+ self.send_signal("Predictors", majority_iris, 2)
+ self.assertTrue(error.predictors_target_mismatch.is_shown())
+ self.assertFalse(error.data_target_mismatch.is_shown())
+ self.assertIsNone(self.get_output("Predictions"))
+
+ self.send_signal("Predictors", None, 2)
+ self.send_signal("Data", titanic)
+ out = self.get_output("Predictions")
+ np.testing.assert_allclose(out.get_column_view("majority")[0], 0)