Skip to content

Commit

Permalink
Merge pull request #2183 from janezd/fix-predictions-noclass
Browse files Browse the repository at this point in the history
[FIX] OWPredictions: Allow classification when data has no target column
  • Loading branch information
lanzagar authored Apr 6, 2017
2 parents 498e6b1 + 21ba467 commit 4b68ef5
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 44 deletions.
83 changes: 51 additions & 32 deletions Orange/widgets/evaluate/owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ class Warning(OWWidget.Warning):
empty_data = Msg("Empty data set")

class Error(OWWidget.Error):
predictor_failed = Msg("One or more predictors failed (see more...)\n{}")
predictor_failed = \
Msg("One or more predictors failed (see more...)\n{}")
predictors_target_mismatch = \
Msg("Predictors do not have the same target.")
data_target_mismatch = \
Msg("Data does not have the same target as predictors.")

settingsHandler = settings.ClassValuesContextHandler()
#: Display the full input dataset or only the target variable columns (if
Expand Down Expand Up @@ -182,32 +187,18 @@ def set_data(self, data):

self.data = data
if data is None:
self.class_var = class_var = None
self.dataview.setModel(None)
self.predictionsview.setModel(None)
self.predictionsview.setItemDelegate(PredictionsItemDelegate())
else:
# force full reset of the view's HeaderView state
self.class_var = class_var = data.domain.class_var
self.dataview.setModel(None)
model = TableModel(data, parent=None)
modelproxy = TableSortProxyModel()
modelproxy.setSourceModel(model)
self.dataview.setModel(modelproxy)
self._update_column_visibility()

discrete_class = class_var is not None and class_var.is_discrete
self.classification_options.setVisible(discrete_class)

self.closeContext()
if discrete_class:
self.class_values = list(class_var.values)
self.selected_classes = list(range(len(self.class_values)))
self.openContext(self.class_var)
else:
self.class_values = []
self.selected_classes = []

self._invalidate_predictions()

def set_predictor(self, predictor=None, id=None):
Expand All @@ -221,7 +212,36 @@ def set_predictor(self, predictor=None, id=None):
self.predictors[id] = \
PredictorSlot(predictor, predictor.name, None)

def set_class_var(self):
    """Derive the target variable from the predictors and refresh related state.

    Both target-mismatch errors are cleared first and ``self.class_var`` is
    reset to ``None``.  Then:

    * if the predictors report more than one distinct target, the
      ``predictors_target_mismatch`` error is shown and no target is set;
    * if they agree on exactly one target, it becomes ``self.class_var``,
      unless the data has a class variable of its own that differs from
      it, in which case ``data_target_mismatch`` is shown and the target
      is reset to ``None``.  Data without a class variable is accepted.

    Afterwards the classification options are shown only for a discrete
    target, the current settings context is closed, and ``class_values`` /
    ``selected_classes`` are rebuilt (a context is reopened only for a
    discrete target).
    """
    # Distinct targets across all connected predictors.
    pred_classes = {pred.predictor.domain.class_var
                    for pred in self.predictors.values()}

    self.Error.predictors_target_mismatch.clear()
    self.Error.data_target_mismatch.clear()
    self.class_var = None

    if len(pred_classes) > 1:
        self.Error.predictors_target_mismatch()
    if len(pred_classes) == 1:
        self.class_var = pred_classes.pop()
        # The data is checked only when the predictors agree; a data set
        # with no class variable never triggers the mismatch error.
        if self.data is not None and \
                self.data.domain.class_var is not None and \
                self.class_var != self.data.domain.class_var:
            self.Error.data_target_mismatch()
            self.class_var = None

    discrete_class = self.class_var is not None \
        and self.class_var.is_discrete
    self.classification_options.setVisible(discrete_class)

    # Close the old context before the values it refers to are replaced.
    self.closeContext()
    if discrete_class:
        self.class_values = list(self.class_var.values)
        # Every class value is selected before the context is opened.
        self.selected_classes = list(range(len(self.class_values)))
        self.openContext(self.class_var)
    else:
        self.class_values = []
        self.selected_classes = []

def handleNewSignals(self):
self.set_class_var()
if self.data is not None:
self._call_predictors()
self._update_predictions_model()
Expand All @@ -232,14 +252,9 @@ def handleNewSignals(self):

def _call_predictors(self):
for inputid, pred in self.predictors.items():
if pred.results is None:
if pred.results is None or numpy.isnan(pred.results[0]).all():
try:
predictor_class = pred.predictor.domain.class_var
if predictor_class != self.class_var:
results = "{}: mismatching target ({})".format(
pred.predictor.name, predictor_class.name)
else:
results = self.predict(pred.predictor, self.data)
results = self.predict(pred.predictor, self.data)
except ValueError as err:
results = "{}: {}".format(pred.predictor.name, err)
self.predictors[inputid] = pred._replace(results=results)
Expand Down Expand Up @@ -285,12 +300,16 @@ def _invalidate_predictions(self):
self.predictors[inputid] = pred._replace(results=None)

def _valid_predictors(self):
return [p for p in self.predictors.values()
if p.results is not None and not isinstance(p.results, str)]
if self.class_var is not None and \
self.data is not None:
return [p for p in self.predictors.values()
if p.results is not None and not isinstance(p.results, str)]
else:
return []

def _update_predictions_model(self):
"""Update the prediction view model."""
if self.data is not None:
if self.data is not None and self.class_var is not None:
slots = self._valid_predictors()
results = []
class_var = self.class_var
Expand Down Expand Up @@ -323,7 +342,7 @@ def _update_predictions_model(self):

def _update_column_visibility(self):
"""Update data column visibility."""
if self.data is not None:
if self.data is not None and self.class_var is not None:
domain = self.data.domain
first_attr = len(domain.class_vars) + len(domain.metas)

Expand Down Expand Up @@ -415,12 +434,12 @@ def commit(self):
self._commit_evaluation_results()

def _commit_evaluation_results(self):
class_var = self.class_var
slots = self._valid_predictors()
if not slots:
if not slots or self.data.domain.class_var is None:
self.send("Evaluation Results", None)
return

class_var = self.class_var
nanmask = numpy.isnan(self.data.get_column_view(class_var)[0])
data = self.data[~nanmask]
N = len(data)
Expand All @@ -442,15 +461,15 @@ def _commit_predictions(self):
self.send("Predictions", None)
return

class_var = self.class_var
if class_var and class_var.is_discrete:
if self.class_var and self.class_var.is_discrete:
newmetas, newcolumns = self._classification_output_columns()
else:
newmetas, newcolumns = self._regression_output_columns()

attrs = list(self.data.domain.attributes) if self.output_attrs else []
metas = list(self.data.domain.metas) + newmetas
domain = Orange.data.Domain(attrs, class_var, metas=metas)
domain = \
Orange.data.Domain(attrs, self.data.domain.class_var, metas=metas)
predictions = self.data.from_table(domain, self.data)
if newcolumns:
newcolumns = numpy.hstack(
Expand Down Expand Up @@ -506,7 +525,7 @@ def merge_data_with_predictions():
[data_model.data(data_model.index(i, j))
for j in iter_data_cols]

if self.data is not None:
if self.data is not None and self.class_var is not None:
text = self.infolabel.text().replace('\n', '<br>')
if self.show_probabilities and self.selected_classes:
text += '<br>Showing probabilities for: '
Expand Down
57 changes: 45 additions & 12 deletions Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.evaluate.owpredictions import OWPredictions

from Orange.data import Table
from Orange.data import Table, Domain
from Orange.classification import MajorityLearner
from Orange.evaluation import Results

Expand Down Expand Up @@ -47,42 +47,75 @@ def test_nan_target_input(self):
self.assertEqual(len(evres.data), 0)

def test_mismatching_targets(self):
error = self.widget.Error

titanic = Table("titanic")
majority_titanic = MajorityLearner()(titanic)
majority_iris = MajorityLearner()(self.iris)

self.send_signal("Data", self.iris)
self.send_signal("Predictors", majority_iris, 1)
self.send_signal("Predictors", majority_titanic, 2)
self.assertTrue(self.widget.Error.predictor_failed.is_shown())
output = self.get_output("Predictions")
self.assertEqual(len(output.domain.metas), 4)
self.assertTrue(error.predictors_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Predictors", None, 1)
self.assertTrue(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertTrue(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Data", None)
self.assertFalse(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Predictors", None, 2)
self.assertFalse(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Predictors", majority_titanic, 2)
self.assertFalse(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Data", self.iris)
self.assertTrue(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertTrue(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

self.send_signal("Predictors", majority_iris, 2)
self.assertFalse(self.widget.Error.predictor_failed.is_shown())
self.assertFalse(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
output = self.get_output("Predictions")
self.assertEqual(len(output.domain.metas), 4)

self.send_signal("Predictors", majority_iris, 1)
self.send_signal("Predictors", majority_titanic, 3)
output = self.get_output("Predictions")
self.assertEqual(len(output.domain.metas), 8)
self.assertTrue(error.predictors_target_mismatch.is_shown())
self.assertFalse(error.data_target_mismatch.is_shown())
self.assertIsNone(self.get_output("Predictions"))

def test_no_class_on_test(self):
    """Allow test data with no class.

    Sends a predictor together with data that has no class variable and
    checks that predictions are output; then checks that a second
    predictor with a different target raises only the predictors-mismatch
    error, and that predictions are output again once that predictor is
    removed and data with a class is supplied.
    """
    error = self.widget.Error

    titanic = Table("titanic")
    majority_titanic = MajorityLearner()(titanic)
    majority_iris = MajorityLearner()(self.iris)

    # Same attributes as titanic, but the domain is built without a class.
    no_class = Table(Domain(titanic.domain.attributes, None), titanic)
    self.send_signal("Predictors", majority_titanic, 1)
    self.send_signal("Data", no_class)
    out = self.get_output("Predictions")
    np.testing.assert_allclose(out.get_column_view("majority")[0], 0)

    # A second predictor with another target: the predictors disagree,
    # while the class-less data must not be reported as mismatching.
    self.send_signal("Predictors", majority_iris, 2)
    self.assertTrue(error.predictors_target_mismatch.is_shown())
    self.assertFalse(error.data_target_mismatch.is_shown())
    self.assertIsNone(self.get_output("Predictions"))

    # Remove the conflicting predictor and send data that has a class.
    self.send_signal("Predictors", None, 2)
    self.send_signal("Data", titanic)
    out = self.get_output("Predictions")
    np.testing.assert_allclose(out.get_column_view("majority")[0], 0)

0 comments on commit 4b68ef5

Please sign in to comment.