diff --git a/Orange/widgets/evaluate/owpredictions.py b/Orange/widgets/evaluate/owpredictions.py index 0fd3a5085af..3d57d9c5727 100644 --- a/Orange/widgets/evaluate/owpredictions.py +++ b/Orange/widgets/evaluate/owpredictions.py @@ -56,7 +56,12 @@ class Warning(OWWidget.Warning): empty_data = Msg("Empty data set") class Error(OWWidget.Error): - predictor_failed = Msg("One or more predictors failed (see more...)\n{}") + predictor_failed = \ + Msg("One or more predictors failed (see more...)\n{}") + predictors_target_mismatch = \ + Msg("Predictors do not have the same target.") + data_target_mismatch = \ + Msg("Data does not have the same target as predictors.") settingsHandler = settings.ClassValuesContextHandler() #: Display the full input dataset or only the target variable columns (if @@ -182,13 +187,11 @@ def set_data(self, data): self.data = data if data is None: - self.class_var = class_var = None self.dataview.setModel(None) self.predictionsview.setModel(None) self.predictionsview.setItemDelegate(PredictionsItemDelegate()) else: # force full reset of the view's HeaderView state - self.class_var = class_var = data.domain.class_var self.dataview.setModel(None) model = TableModel(data, parent=None) modelproxy = TableSortProxyModel() @@ -196,18 +199,6 @@ def set_data(self, data): self.dataview.setModel(modelproxy) self._update_column_visibility() - discrete_class = class_var is not None and class_var.is_discrete - self.classification_options.setVisible(discrete_class) - - self.closeContext() - if discrete_class: - self.class_values = list(class_var.values) - self.selected_classes = list(range(len(self.class_values))) - self.openContext(self.class_var) - else: - self.class_values = [] - self.selected_classes = [] - self._invalidate_predictions() def set_predictor(self, predictor=None, id=None): @@ -221,7 +212,36 @@ def set_predictor(self, predictor=None, id=None): self.predictors[id] = \ PredictorSlot(predictor, predictor.name, None) + def set_class_var(self): + pred_classes = set(pred.predictor.domain.class_var + for pred in self.predictors.values()) + self.Error.predictors_target_mismatch.clear() + self.Error.data_target_mismatch.clear() + self.class_var = None + if len(pred_classes) > 1: + self.Error.predictors_target_mismatch() + if len(pred_classes) == 1: + self.class_var = pred_classes.pop() + if self.data is not None and \ + self.data.domain.class_var is not None and \ + self.class_var != self.data.domain.class_var: + self.Error.data_target_mismatch() + self.class_var = None + + discrete_class = self.class_var is not None \ + and self.class_var.is_discrete + self.classification_options.setVisible(discrete_class) + self.closeContext() + if discrete_class: + self.class_values = list(self.class_var.values) + self.selected_classes = list(range(len(self.class_values))) + self.openContext(self.class_var) + else: + self.class_values = [] + self.selected_classes = [] + def handleNewSignals(self): + self.set_class_var() if self.data is not None: self._call_predictors() self._update_predictions_model() @@ -232,14 +252,9 @@ def handleNewSignals(self): def _call_predictors(self): for inputid, pred in self.predictors.items(): - if pred.results is None: + if pred.results is None or numpy.isnan(pred.results[0]).all(): try: - predictor_class = pred.predictor.domain.class_var - if predictor_class != self.class_var: - results = "{}: mismatching target ({})".format( - pred.predictor.name, predictor_class.name) - else: - results = self.predict(pred.predictor, self.data) + results = self.predict(pred.predictor, self.data) except ValueError as err: results = "{}: {}".format(pred.predictor.name, err) self.predictors[inputid] = pred._replace(results=results) @@ -285,12 +300,16 @@ def _invalidate_predictions(self): self.predictors[inputid] = pred._replace(results=None) def _valid_predictors(self): - return [p for p in self.predictors.values() - if p.results is not None and not isinstance(p.results, str)] + if self.class_var is not None and \ + self.data is not None: + return [p for p in self.predictors.values() + if p.results is not None and not isinstance(p.results, str)] + else: + return [] def _update_predictions_model(self): """Update the prediction view model.""" - if self.data is not None: + if self.data is not None and self.class_var is not None: slots = self._valid_predictors() results = [] class_var = self.class_var @@ -323,7 +342,7 @@ def _update_predictions_model(self): def _update_column_visibility(self): """Update data column visibility.""" - if self.data is not None: + if self.data is not None and self.class_var is not None: domain = self.data.domain first_attr = len(domain.class_vars) + len(domain.metas) @@ -415,12 +434,12 @@ def commit(self): self._commit_evaluation_results() def _commit_evaluation_results(self): - class_var = self.class_var slots = self._valid_predictors() - if not slots: + if not slots or self.data.domain.class_var is None: self.send("Evaluation Results", None) return + class_var = self.class_var nanmask = numpy.isnan(self.data.get_column_view(class_var)[0]) data = self.data[~nanmask] N = len(data) @@ -442,15 +461,15 @@ def _commit_predictions(self): self.send("Predictions", None) return - class_var = self.class_var - if class_var and class_var.is_discrete: + if self.class_var and self.class_var.is_discrete: newmetas, newcolumns = self._classification_output_columns() else: newmetas, newcolumns = self._regression_output_columns() attrs = list(self.data.domain.attributes) if self.output_attrs else [] metas = list(self.data.domain.metas) + newmetas - domain = Orange.data.Domain(attrs, class_var, metas=metas) + domain = \ + Orange.data.Domain(attrs, self.data.domain.class_var, metas=metas) predictions = self.data.from_table(domain, self.data) if newcolumns: newcolumns = numpy.hstack( @@ -506,7 +525,7 @@ def merge_data_with_predictions(): [data_model.data(data_model.index(i, j)) for j in iter_data_cols] - if self.data is not None: + if self.data is not None and self.class_var is not None: text = self.infolabel.text().replace('\n', '
') if self.show_probabilities and self.selected_classes: text += '
Showing probabilities for: ' diff --git a/Orange/widgets/evaluate/tests/test_owpredictions.py b/Orange/widgets/evaluate/tests/test_owpredictions.py index 8f366b7d2cd..bbae03bf3d8 100644 --- a/Orange/widgets/evaluate/tests/test_owpredictions.py +++ b/Orange/widgets/evaluate/tests/test_owpredictions.py @@ -5,7 +5,7 @@ from Orange.widgets.tests.base import WidgetTest from Orange.widgets.evaluate.owpredictions import OWPredictions -from Orange.data import Table +from Orange.data import Table, Domain from Orange.classification import MajorityLearner from Orange.evaluation import Results @@ -47,6 +47,8 @@ def test_nan_target_input(self): self.assertEqual(len(evres.data), 0) def test_mismatching_targets(self): + error = self.widget.Error + titanic = Table("titanic") majority_titanic = MajorityLearner()(titanic) majority_iris = MajorityLearner()(self.iris) @@ -54,35 +56,66 @@ def test_mismatching_targets(self): self.send_signal("Data", self.iris) self.send_signal("Predictors", majority_iris, 1) self.send_signal("Predictors", majority_titanic, 2) - self.assertTrue(self.widget.Error.predictor_failed.is_shown()) - output = self.get_output("Predictions") - self.assertEqual(len(output.domain.metas), 4) + self.assertTrue(error.predictors_target_mismatch.is_shown()) + self.assertIsNone(self.get_output("Predictions")) self.send_signal("Predictors", None, 1) - self.assertTrue(self.widget.Error.predictor_failed.is_shown()) + self.assertFalse(error.predictors_target_mismatch.is_shown()) + self.assertTrue(error.data_target_mismatch.is_shown()) self.assertIsNone(self.get_output("Predictions")) self.send_signal("Data", None) - self.assertFalse(self.widget.Error.predictor_failed.is_shown()) + self.assertFalse(error.predictors_target_mismatch.is_shown()) + self.assertFalse(error.data_target_mismatch.is_shown()) self.assertIsNone(self.get_output("Predictions")) self.send_signal("Predictors", None, 2) - self.assertFalse(self.widget.Error.predictor_failed.is_shown()) + self.assertFalse(error.predictors_target_mismatch.is_shown()) + self.assertFalse(error.data_target_mismatch.is_shown()) self.assertIsNone(self.get_output("Predictions")) self.send_signal("Predictors", majority_titanic, 2) - self.assertFalse(self.widget.Error.predictor_failed.is_shown()) + self.assertFalse(error.predictors_target_mismatch.is_shown()) + self.assertFalse(error.data_target_mismatch.is_shown()) self.assertIsNone(self.get_output("Predictions")) self.send_signal("Data", self.iris) - self.assertTrue(self.widget.Error.predictor_failed.is_shown()) + self.assertFalse(error.predictors_target_mismatch.is_shown()) + self.assertTrue(error.data_target_mismatch.is_shown()) self.assertIsNone(self.get_output("Predictions")) self.send_signal("Predictors", majority_iris, 2) - self.assertFalse(self.widget.Error.predictor_failed.is_shown()) + self.assertFalse(error.predictors_target_mismatch.is_shown()) + self.assertFalse(error.data_target_mismatch.is_shown()) + output = self.get_output("Predictions") self.assertEqual(len(output.domain.metas), 4) self.send_signal("Predictors", majority_iris, 1) self.send_signal("Predictors", majority_titanic, 3) - output = self.get_output("Predictions") - self.assertEqual(len(output.domain.metas), 8) + self.assertTrue(error.predictors_target_mismatch.is_shown()) + self.assertFalse(error.data_target_mismatch.is_shown()) + self.assertIsNone(self.get_output("Predictions")) + + def test_no_class_on_test(self): + """Allow test data with no class""" + error = self.widget.Error + + titanic = Table("titanic") + majority_titanic = MajorityLearner()(titanic) + majority_iris = MajorityLearner()(self.iris) + + no_class = Table(Domain(titanic.domain.attributes, None), titanic) + self.send_signal("Predictors", majority_titanic, 1) + self.send_signal("Data", no_class) + out = self.get_output("Predictions") + np.testing.assert_allclose(out.get_column_view("majority")[0], 0) + + self.send_signal("Predictors", majority_iris, 2) + self.assertTrue(error.predictors_target_mismatch.is_shown()) + self.assertFalse(error.data_target_mismatch.is_shown()) + self.assertIsNone(self.get_output("Predictions")) + + self.send_signal("Predictors", None, 2) + self.send_signal("Data", titanic) + out = self.get_output("Predictions") + np.testing.assert_allclose(out.get_column_view("majority")[0], 0)