From 6c0a8d12c77c626fbc6f7824655a65526757a63a Mon Sep 17 00:00:00 2001 From: janezd Date: Thu, 26 Jan 2017 22:11:41 +0100 Subject: [PATCH] KMeans: Fix crashes when underlying algorithm fails While fixing the table, improve its layout --- Orange/widgets/unsupervised/owkmeans.py | 117 +++++++++++------- .../unsupervised/tests/test_owkmeans.py | 63 +++++++++- 2 files changed, 136 insertions(+), 44 deletions(-) diff --git a/Orange/widgets/unsupervised/owkmeans.py b/Orange/widgets/unsupervised/owkmeans.py index c9c8d0fe535..34a25f103ae 100644 --- a/Orange/widgets/unsupervised/owkmeans.py +++ b/Orange/widgets/unsupervised/owkmeans.py @@ -2,7 +2,7 @@ import numpy as np from AnyQt.QtWidgets import QGridLayout, QSizePolicy, QTableView -from AnyQt.QtGui import QStandardItemModel, QStandardItem, QIntValidator +from AnyQt.QtGui import QStandardItemModel, QStandardItem, QIntValidator, QBrush from AnyQt.QtCore import Qt, QTimer from Orange.clustering import KMeans @@ -24,15 +24,18 @@ class OWKMeans(widget.OWWidget): outputs = [("Annotated Data", Table, widget.Default), ("Centroids", Table)] + class Error(widget.OWWidget.Error): + failed = widget.Msg("Clustering failed\nError: {}") + INIT_KMEANS, INIT_RANDOM = range(2) INIT_METHODS = "Initialize with KMeans++", "Random initialization" SILHOUETTE, INTERCLUSTER, DISTANCES = range(3) - SCORING_METHODS = [("Silhouette", lambda km: km.silhouette, False), + SCORING_METHODS = [("Silhouette", lambda km: km.silhouette, False, True), ("Inter-cluster distance", - lambda km: km.inter_cluster, True), + lambda km: km.inter_cluster, True, False), ("Distance to centroids", - lambda km: km.inertia, True)] + lambda km: km.inertia, True, False)] OUTPUT_CLASS, OUTPUT_ATTRIBUTE, OUTPUT_META = range(3) OUTPUT_METHODS = ("Class", "Feature", "Meta") @@ -150,7 +153,8 @@ def __init__(self): table.setSelectionMode(QTableView.SingleSelection) table.setSelectionBehavior(QTableView.SelectRows) table.verticalHeader().hide() - table.setItemDelegateForColumn(1, 
gui.TableBarItem(self)) + self.bar_delegate = gui.ColoredBarItemDelegate(self, color=Qt.cyan) + table.setItemDelegateForColumn(1, self.bar_delegate) table.setModel(self.table_model) table.selectionModel().selectionChanged.connect( self.table_item_selected) @@ -219,6 +223,7 @@ def run_optimization(self): try: self.controlArea.setDisabled(True) self.optimization_runs = [] + error = "" if not self.check_data_size(self.k_from, self.Error): return self.check_data_size(self.k_to, self.Warning) @@ -231,7 +236,15 @@ def run_optimization(self): for k in range(self.k_from, k_to + 1): progress.advance() kmeans.params["n_clusters"] = k - self.optimization_runs.append((k, kmeans(self.data))) + try: + self.optimization_runs.append((k, kmeans(self.data))) + except Exception as exc: + error = str(exc) + self.optimization_runs.append((k, error)) + if all(isinstance(score, str) + for _, score in self.optimization_runs): + self.Error.failed(error) # Report just the last error + self.optimization_runs = [] finally: self.controlArea.setDisabled(False) self.show_results() @@ -240,11 +253,15 @@ def run_optimization(self): def cluster(self): if not self.check_data_size(self.k, self.Error): return - self.km = KMeans( - n_clusters=self.k, - init=['random', 'k-means++'][self.smart_init], - n_init=self.n_init, - max_iter=self.max_iterations)(self.data) + try: + self.km = KMeans( + n_clusters=self.k, + init=['random', 'k-means++'][self.smart_init], + n_init=self.n_init, + max_iter=self.max_iterations)(self.data) + except Exception as exc: + self.Error.failed(str(exc)) + self.km = None self.send_data() def run(self): @@ -260,40 +277,55 @@ def commit(self): self.run() def show_results(self): - minimize = self.SCORING_METHODS[self.scoring][2] - k_scores = [(k, self.SCORING_METHODS[self.scoring][1](run)) for - k, run in self.optimization_runs] - scores = list(zip(*k_scores))[1] - if minimize: - best_score, worst_score = min(scores), max(scores) + _, scoring_method, minimize, normal = 
self.SCORING_METHODS[self.scoring] + k_scores = [(k, + scoring_method(run) if not isinstance(run, str) else run) + for k, run in self.optimization_runs] + scores = [score for _, score in k_scores if not isinstance(score, str)] + + min_score, max_score = min(scores, default=0), max(scores, default=1) + best_score = min_score if minimize else max_score + if normal: + min_score, max_score = 0, 1 + nplaces = 3 else: - best_score, worst_score = max(scores), min(scores) + nplaces = min(5, np.floor(abs(math.log(max(max_score, 1e-10)))) + 2) + score_span = (max_score - min_score) or 1 + self.bar_delegate.scale = (min_score, max_score) + self.bar_delegate.float_fmt = "%%.%if" % int(nplaces) - best_run = scores.index(best_score) - score_span = (best_score - worst_score) or 1 - max_score = max(scores) - nplaces = min(5, np.floor(abs(math.log(max(max_score, 1e-10)))) + 2) - fmt = "{{:.{}f}}".format(int(nplaces)) model = self.table_model model.setRowCount(len(k_scores)) + no_selection = True for i, (k, score) in enumerate(k_scores): - item = model.item(i, 0) - if item is None: - item = QStandardItem() - item.setData(k, Qt.DisplayRole) - item.setTextAlignment(Qt.AlignCenter) - model.setItem(i, 0, item) - item = model.item(i, 1) - if item is None: - item = QStandardItem() - item.setData(fmt.format(score) if not np.isnan(score) else 'out-of-memory error', - Qt.DisplayRole) - bar_ratio = 0.95 * (score - worst_score) / score_span - item.setData(bar_ratio, gui.TableBarItem.BarRole) + item0 = model.item(i, 0) or QStandardItem() + item0.setData(k, Qt.DisplayRole) + item0.setTextAlignment(Qt.AlignCenter) + model.setItem(i, 0, item0) + item = model.item(i, 1) or QStandardItem() + if not isinstance(score, str): + item.setData(score, Qt.DisplayRole) + item.setData(None, Qt.ToolTipRole) + bar_ratio = 0.95 * (score - min_score) / score_span + item.setData(bar_ratio, gui.BarRatioRole) + if no_selection and score == best_score: + self.table_view.selectRow(i) + no_selection = False + color = 
Qt.black + flags = Qt.ItemIsEnabled | Qt.ItemIsSelectable + else: + item.setData("clustering failed", Qt.DisplayRole) + item.setData(score, Qt.ToolTipRole) + item.setData(None, gui.BarRatioRole) + color = Qt.gray + flags = Qt.NoItemFlags + item0.setData(QBrush(color), Qt.ForegroundRole) + item0.setFlags(flags) + item.setData(QBrush(color), Qt.ForegroundRole) + item.setFlags(flags) model.setItem(i, 1, item) self.table_view.resizeRowsToContents() - self.table_view.selectRow(best_run) self.table_view.show() if minimize: self.table_box.setTitle("Scoring (smaller is better)") @@ -314,13 +346,12 @@ def selected_row(self): def table_item_selected(self): row = self.selected_row() if row is not None: - self.send_data(row) + self.send_data() - def send_data(self, row=None): + def send_data(self): if self.optimize_k: - if row is None: - row = self.selected_row() - km = self.optimization_runs[row][1] + row = self.selected_row() if self.optimization_runs else None + km = self.optimization_runs[row][1] if row is not None else None else: km = self.km if not self.data or not km: @@ -356,6 +387,8 @@ def send_data(self, row=None): def set_data(self, data): self.data = data if data is None: + self.Error.clear() + self.Warning.clear() self.table_model.setRowCount(0) self.send("Annotated Data", None) self.send("Centroids", None) diff --git a/Orange/widgets/unsupervised/tests/test_owkmeans.py b/Orange/widgets/unsupervised/tests/test_owkmeans.py index 76a0914ea52..a519f71acc3 100644 --- a/Orange/widgets/unsupervised/tests/test_owkmeans.py +++ b/Orange/widgets/unsupervised/tests/test_owkmeans.py @@ -1,14 +1,17 @@ +from unittest.mock import patch + from AnyQt.QtWidgets import QRadioButton from Orange.widgets.tests.base import WidgetTest from Orange.widgets.unsupervised.owkmeans import OWKMeans +import Orange.clustering from Orange.data import Table class TestOWKMeans(WidgetTest): def setUp(self): - self.widget = self.create_widget(OWKMeans, - stored_settings={"auto_apply": False}) + 
self.widget = self.create_widget( + OWKMeans, stored_settings={"auto_apply": False}) # type: OWKMeans self.iris = Table("iris") def test_optimization_report_display(self): @@ -32,3 +35,59 @@ def test_data_on_output(self): self.send_signal("Data", None) # removing data should have cleared the output self.assertEqual(self.widget.data, None) + + class KMeansFail(Orange.clustering.KMeans): + fail_on = set() + + def fit(self, *args): + # when not optimizing, params is empty?! + k = self.params.get("n_clusters", 3) + if k in self.fail_on: + raise ValueError("k={} fails".format(k)) + return super().fit(*args) + + @patch("Orange.widgets.unsupervised.owkmeans.KMeans", new=KMeansFail) + def test_optimization_fails(self): + widget = self.widget + widget.k_from = 3 + widget.k_to = 8 + widget.scoring = 0 + widget.optimize_k = True + + self.KMeansFail.fail_on = {3, 5, 7} + self.send_signal("Data", self.iris) + self.assertIsInstance(widget.optimization_runs[0][1], str) + self.assertIsInstance(widget.optimization_runs[2][1], str) + self.assertIsInstance(widget.optimization_runs[4][1], str) + self.assertNotIsInstance(widget.optimization_runs[1][1], str) + self.assertNotIsInstance(widget.optimization_runs[3][1], str) + self.assertNotIsInstance(widget.optimization_runs[5][1], str) + self.assertFalse(widget.Error.failed.is_shown()) + self.assertEqual(widget.selected_row(), 1) + self.assertIsNotNone(self.get_output("Annotated Data")) + + self.KMeansFail.fail_on = set(range(3, 9)) + widget.run() + self.assertTrue(widget.Error.failed.is_shown()) + self.assertEqual(widget.optimization_runs, []) + self.assertIsNone(self.get_output("Annotated Data")) + + self.KMeansFail.fail_on = set() + widget.run() + self.assertFalse(widget.Error.failed.is_shown()) + self.assertEqual(widget.selected_row(), 0) + self.assertIsNotNone(self.get_output("Annotated Data")) + + @patch("Orange.widgets.unsupervised.owkmeans.KMeans", new=KMeansFail) + def test_run_fails(self): + self.widget.k = 3 + 
self.widget.optimize_k = False + self.KMeansFail.fail_on = {3} + self.send_signal("Data", self.iris) + self.assertTrue(self.widget.Error.failed.is_shown()) + self.assertIsNone(self.get_output("Annotated Data")) + + self.KMeansFail.fail_on = set() + self.widget.run() + self.assertFalse(self.widget.Error.failed.is_shown()) + self.assertIsNotNone(self.get_output("Annotated Data"))