From 28a22ee0456305e89584fb5730dbee00adc42c23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Mon, 11 Jul 2016 16:04:20 +0200 Subject: [PATCH 01/12] Pythagorean tree: Integrate pythagorean tree and forest widget into orange --- .../visualize/icons/PythagoreanForest.svg | 22 + .../visualize/icons/PythagoreanTree.svg | 20 + Orange/widgets/visualize/owpythagorastree.py | 610 ++++++++++++++ .../widgets/visualize/owpythagoreanforest.py | 415 ++++++++++ .../widgets/visualize/pythagorastreeviewer.py | 756 ++++++++++++++++++ .../visualize/tests/test_owpythagorastree.py | 59 ++ .../widgets/visualize/widgetutils/__init__.py | 0 .../visualize/widgetutils/common/__init__.py | 0 .../visualize/widgetutils/common/owgrid.py | 280 +++++++ .../visualize/widgetutils/common/owlegend.py | 644 +++++++++++++++ .../visualize/widgetutils/common/scene.py | 25 + .../visualize/widgetutils/common/view.py | 180 +++++ .../visualize/widgetutils/tree/__init__.py | 0 .../visualize/widgetutils/tree/rules.py | 181 +++++ .../widgetutils/tree/skltreeadapter.py | 302 +++++++ .../widgetutils/tree/tests/__init__.py | 0 .../widgetutils/tree/tests/test_rules.py | 157 ++++ .../visualize/widgetutils/tree/treeadapter.py | 274 +++++++ 18 files changed, 3925 insertions(+) create mode 100644 Orange/widgets/visualize/icons/PythagoreanForest.svg create mode 100644 Orange/widgets/visualize/icons/PythagoreanTree.svg create mode 100644 Orange/widgets/visualize/owpythagorastree.py create mode 100644 Orange/widgets/visualize/owpythagoreanforest.py create mode 100644 Orange/widgets/visualize/pythagorastreeviewer.py create mode 100644 Orange/widgets/visualize/tests/test_owpythagorastree.py create mode 100644 Orange/widgets/visualize/widgetutils/__init__.py create mode 100644 Orange/widgets/visualize/widgetutils/common/__init__.py create mode 100644 Orange/widgets/visualize/widgetutils/common/owgrid.py create mode 100644 Orange/widgets/visualize/widgetutils/common/owlegend.py create mode 100644 Orange/widgets/visualize/widgetutils/common/scene.py create mode 100644 Orange/widgets/visualize/widgetutils/common/view.py create mode 100644 Orange/widgets/visualize/widgetutils/tree/__init__.py create mode 100644 Orange/widgets/visualize/widgetutils/tree/rules.py create mode 100644 Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py create mode 100644 Orange/widgets/visualize/widgetutils/tree/tests/__init__.py create mode 100644 Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py create mode 100644 Orange/widgets/visualize/widgetutils/tree/treeadapter.py diff --git a/Orange/widgets/visualize/icons/PythagoreanForest.svg b/Orange/widgets/visualize/icons/PythagoreanForest.svg new file mode 100644 index 00000000000..d612d9cb8cc --- /dev/null +++ b/Orange/widgets/visualize/icons/PythagoreanForest.svg @@ -0,0 +1,22 @@ + + + + + + + + + + + + + + + + + + + diff --git a/Orange/widgets/visualize/icons/PythagoreanTree.svg b/Orange/widgets/visualize/icons/PythagoreanTree.svg new file mode 100644 index 00000000000..10aba2963ae --- /dev/null +++ b/Orange/widgets/visualize/icons/PythagoreanTree.svg @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py new file mode 100644 index 00000000000..2ce4972e22a --- /dev/null +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -0,0 +1,610 @@ +"""Pythagorean tree viewer for visualizing trees. + +The widget currently supports 3 different kinds of trees that may come in handy +though only 2 are implemented - (GENERAL, CLASSIFICATION, REGRESSION) with +(CLASSIFICATION and REGRESSION implemented. + +These modes are necessary since it is impossible to have one tree viewer for +both classification and regression type trees, since they need specialized +tootips and colors. + +The general tree exists for the purpose of easy extension, since a new `Tree` +type has been introduced to Orange, which this widget accepts (both the +classification and regression tree extend that base Tree class). In case the +widget is ever required to support more general trees, the enum type is there +to be used, and the methods need to be implemented, then it should work for any +kind of trees. + +""" +from math import sqrt, log + +import numpy as np +from Orange.widgets.visualize.widgetutils.common.scene import \ + UpdateItemsOnSelectGraphicsScene +from Orange.widgets.visualize.widgetutils.common.view import ( + PannableGraphicsView, + ZoomableGraphicsView, + PreventDefaultWheelEvent +) +from Orange.widgets.visualize.widgetutils.tree.skltreeadapter import \ + SklTreeAdapter +from PyQt4 import QtGui + +from Orange.base import Tree +from Orange.classification.tree import TreeClassifier +from Orange.data.table import Table +from Orange.regression.tree import TreeRegressor +from Orange.widgets import gui, settings +from Orange.widgets.utils.colorpalette import ContinuousPaletteGenerator +from Orange.widgets.visualize.pythagorastreeviewer import ( + PythagorasTreeViewer, + SquareGraphicsItem +) +from Orange.widgets.visualize.widgetutils.common.owlegend import ( + AnchorableGraphicsView, + Anchorable, + OWDiscreteLegend, + OWContinuousLegend +) +from Orange.widgets.widget import OWWidget + + +class OWPythagorasTree(OWWidget): + name = 'Pythagorean Tree' + description = 'Pythagorean Tree visualization for tree like-structures.' + icon = 'icons/PythagoreanTree.svg' + + priority = 610 + + inputs = [('Tree', Tree, 'set_tree')] + outputs = [('Selected Data', Table)] + + # Enable the save as feature + graph_name = 'scene' + + # Settings + depth_limit = settings.ContextSetting(10) + target_class_index = settings.ContextSetting(0) + size_calc_idx = settings.Setting(0) + size_log_scale = settings.Setting(2) + tooltips_enabled = settings.Setting(True) + show_legend = settings.Setting(False) + + GENERAL, CLASSIFICATION, REGRESSION = range(3) + + LEGEND_OPTIONS = { + 'corner': Anchorable.BOTTOM_RIGHT, + 'offset': (10, 10), + } + + def __init__(self): + super().__init__() + # Instance variables + self.tree_type = self.GENERAL + self.model = None + self.instances = None + self.clf_dataset = None + # The tree adapter instance which is passed from the outside + self.tree_adapter = None + self.legend = None + + self.color_palette = None + + # Different methods to calculate the size of squares + self.SIZE_CALCULATION = [ + ('Normal', lambda x: x), + ('Square root', lambda x: sqrt(x)), + # The +1 is there so that we don't get division by 0 exceptions + ('Logarithmic', lambda x: log(x * self.size_log_scale + 1)), + ] + + # Color modes for regression trees + self.REGRESSION_COLOR_CALC = [ + ('None', lambda _, __: QtGui.QColor(255, 255, 255)), + ('Class mean', self._color_class_mean), + ('Standard deviation', self._color_stddev), + ] + + # CONTROL AREA + # Tree info area + box_info = gui.widgetBox(self.controlArea, 'Tree Info') + self.info = gui.widgetLabel(box_info, label='') + + # Display settings area + box_display = gui.widgetBox(self.controlArea, 'Display Settings') + self.depth_slider = gui.hSlider( + box_display, self, 'depth_limit', label='Depth', ticks=False, + callback=self.update_depth) + self.target_class_combo = gui.comboBox( + box_display, self, 'target_class_index', label='Target class', + orientation='horizontal', items=[], contentsLength=8, + callback=self.update_colors) + self.size_calc_combo = gui.comboBox( + box_display, self, 'size_calc_idx', label='Size', + orientation='horizontal', + items=list(zip(*self.SIZE_CALCULATION))[0], contentsLength=8, + callback=self.update_size_calc) + self.log_scale_box = gui.hSlider( + box_display, self, 'size_log_scale', + label='Log scale factor', minValue=1, maxValue=100, ticks=False, + callback=self.invalidate_tree) + + # Plot properties area + box_plot = gui.widgetBox(self.controlArea, 'Plot Properties') + gui.checkBox( + box_plot, self, 'tooltips_enabled', label='Enable tooltips', + callback=self.update_tooltip_enabled) + gui.checkBox( + box_plot, self, 'show_legend', label='Show legend', + callback=self.update_show_legend) + + # Stretch to fit the rest of the unsused area + gui.rubber(self.controlArea) + + self.controlArea.setSizePolicy( + QtGui.QSizePolicy.Preferred, QtGui.QSizePolicy.Expanding) + + # MAIN AREA + # The QGraphicsScene doesn't actually require a parent, but not linking + # the widget to the scene causes errors and a segfault on close due to + # the way Qt deallocates memory and deletes objects. + self.scene = TreeGraphicsScene(self) + self.scene.selectionChanged.connect(self.commit) + self.view = TreeGraphicsView(self.scene, padding=(150, 150)) + self.view.setRenderHint(QtGui.QPainter.Antialiasing, True) + self.mainArea.layout().addWidget(self.view) + + self.ptree = PythagorasTreeViewer() + self.scene.addItem(self.ptree) + self.view.set_central_widget(self.ptree) + + self.resize(800, 500) + # Clear the widget to correctly set the intial values + self.clear() + + def set_tree(self, model=None): + """When a different tree is given.""" + self.clear() + self.model = model + + if model is not None: + # We need to know what kind of tree we have in order to properly + # show colors and tooltips + if isinstance(model, TreeClassifier): + self.tree_type = self.CLASSIFICATION + elif isinstance(model, TreeRegressor): + self.tree_type = self.REGRESSION + else: + self.tree_type = self.GENERAL + + self.instances = model.instances + # this bit is important for the regression classifier + if self.instances is not None and \ + self.instances.domain != model.domain: + self.clf_dataset = Table.from_table( + self.model.domain, self.instances) + else: + self.clf_dataset = self.instances + + self.tree_adapter = self._get_tree_adapter(self.model) + self.color_palette = self._tree_specific('_get_color_palette')() + + self.ptree.clear() + self.ptree.set_tree(self.tree_adapter) + self.ptree.set_tooltip_func(self._tree_specific('_get_tooltip')) + self.ptree.set_node_color_func( + self._tree_specific('_get_node_color') + ) + + self._tree_specific('_update_legend_colors')() + self._update_legend_visibility() + + self._update_info_box() + self._update_depth_slider() + + self._tree_specific('_update_target_class_combo')() + + self._update_main_area() + + # Get meta variables describing pythagoras tree if given from + # forest. + if hasattr(model, 'meta_size_calc_idx'): + self.size_calc_idx = model.meta_size_calc_idx + if hasattr(model, 'meta_size_log_scale'): + self.size_log_scale = model.meta_size_log_scale + # Updating the size calc redraws the whole tree + if hasattr(model, 'meta_size_calc_idx') or \ + hasattr(model, 'meta_size_log_scale'): + self.update_size_calc() + # The target class can also be passed from the meta properties + if hasattr(model, 'meta_target_class_index'): + self.target_class_index = model.meta_target_class_index + self.update_colors() + # TODO this messes up the viewport in pythagoras tree viewer + # it seems the viewport doesn't reset its size if this is applied + # if hasattr(model, 'meta_depth_limit'): + # self.depth_limit = model.meta_depth_limit + # self.update_depth() + + def clear(self): + """Clear all relevant data from the widget.""" + self.model = None + self.instances = None + self.clf_dataset = None + self.tree_adapter = None + + if self.legend is not None: + self.scene.removeItem(self.legend) + self.legend = None + + self.ptree.clear() + self._clear_info_box() + self._clear_target_class_combo() + self._clear_depth_slider() + self._update_log_scale_slider() + + # CONTROL AREA CALLBACKS + def update_depth(self): + """This method should be called when the depth changes""" + self.ptree.set_depth_limit(self.depth_limit) + + def update_colors(self): + self.ptree.target_class_has_changed() + self._tree_specific('_update_legend_colors')() + + def update_size_calc(self): + self._update_log_scale_slider() + self.invalidate_tree() + + def invalidate_tree(self): + """When the tree needs to be recalculated. E.g. change of size calc.""" + if self.model is not None: + self.tree_adapter = self._get_tree_adapter(self.model) + self.ptree.set_tree(self.tree_adapter) + self.ptree.set_depth_limit(self.depth_limit) + self._update_main_area() + + def update_tooltip_enabled(self): + if self.tooltips_enabled: + self.ptree.set_tooltip_func( + self._tree_specific('_get_tooltip') + ) + else: + self.ptree.set_tooltip_func(lambda _: None) + self.ptree.tooltip_has_changed() + + def update_show_legend(self): + self._update_legend_visibility() + + # MODEL CHANGED CONTROL ELEMENTS UPDATE METHODS + def _update_info_box(self): + self.info.setText('Nodes: {}\nDepth: {}'.format( + self.tree_adapter.num_nodes, + self.tree_adapter.max_depth + )) + + def _update_depth_slider(self): + self.depth_slider.parent().setEnabled(True) + self.depth_slider.setMaximum(self.tree_adapter.max_depth) + self._set_max_depth() + + def _update_legend_visibility(self): + if self.legend is not None: + self.legend.setVisible(self.show_legend) + + def _update_log_scale_slider(self): + """On calc method combo box changed.""" + if self.SIZE_CALCULATION[self.size_calc_idx][0] == 'Logarithmic': + self.log_scale_box.parent().setEnabled(True) + else: + self.log_scale_box.parent().setEnabled(False) + + # MODEL REMOVED CONTROL ELEMENTS CLEAR METHODS + def _clear_info_box(self): + self.info.setText('No tree on input') + + def _clear_depth_slider(self): + self.depth_slider.parent().setEnabled(False) + self.depth_slider.setMaximum(0) + + def _clear_target_class_combo(self): + self.target_class_combo.clear() + self.target_class_index = 0 + self.target_class_combo.setCurrentIndex(self.target_class_index) + + # HELPFUL METHODS + def _set_max_depth(self): + """Set the depth to the max depth and update appropriate actors.""" + self.depth_limit = self.tree_adapter.max_depth + self.depth_slider.setValue(self.depth_limit) + + def _update_main_area(self): + # refresh the scene rect, cuts away the excess whitespace, and adds + # padding for panning. + self.scene.setSceneRect(self.view.central_widget_rect()) + # reset the zoom level + self.view.recalculate_and_fit() + self.view.update_anchored_items() + + def _get_tree_adapter(self, model): + return SklTreeAdapter( + model.tree, + model.domain, + adjust_weight=self.SIZE_CALCULATION[self.size_calc_idx][1], + ) + + def onDeleteWidget(self): + """When deleting the widget.""" + super().onDeleteWidget() + self.clear() + + def commit(self): + """Commit the selected data to output.""" + if self.instances is None: + self.send('Selected Data', None) + return + # this is taken almost directly from the owclassificationtreegraph.py + items = filter(lambda x: isinstance(x, SquareGraphicsItem), + self.scene.selectedItems()) + + data = self.tree_adapter.get_instances_in_nodes( + self.clf_dataset, [item.tree_node for item in items]) + self.send('Selected Data', data) + + def send_report(self): + self.report_plot() + + def _tree_specific(self, method): + """A best effort method getter that somewhat separates logic specific + to classification and regression trees. + This relies on conventional naming of specific methods, e.g. + a method name _get_tooltip would need to be defined like so: + _classification_get_tooltip and _regression_get_tooltip, since they are + both specific. + + Parameters + ---------- + method : str + Method name that we would like to call. + + Returns + ------- + callable or None + + """ + if self.tree_type == self.GENERAL: + return getattr(self, '_general' + method) + elif self.tree_type == self.CLASSIFICATION: + return getattr(self, '_classification' + method) + elif self.tree_type == self.REGRESSION: + return getattr(self, '_regression' + method) + else: + return None + + # CLASSIFICATION TREE SPECIFIC METHODS + def _classification_update_target_class_combo(self): + self._clear_target_class_combo() + self.target_class_combo.addItem('None') + values = [c.title() for c in + self.tree_adapter.domain.class_vars[0].values] + self.target_class_combo.addItems(values) + + def _classification_update_legend_colors(self): + if self.legend is not None: + self.scene.removeItem(self.legend) + + if self.target_class_index == 0: + self.legend = OWDiscreteLegend(domain=self.model.domain, + **self.LEGEND_OPTIONS) + else: + items = ( + (self.target_class_combo.itemText(self.target_class_index), + self.color_palette[self.target_class_index - 1] + ), + ('other', QtGui.QColor('#ffffff')) + ) + self.legend = OWDiscreteLegend(items=items, **self.LEGEND_OPTIONS) + + self.legend.setVisible(self.show_legend) + self.scene.addItem(self.legend) + + def _classification_get_color_palette(self): + return [QtGui.QColor(*c) for c in self.model.domain.class_var.colors] + + def _classification_get_node_color(self, adapter, tree_node): + # this is taken almost directly from the existing classification tree + # viewer + colors = self.color_palette + distribution = adapter.get_distribution(tree_node.label)[0] + total = np.sum(distribution) + + if self.target_class_index: + p = distribution[self.target_class_index - 1] / total + color = colors[self.target_class_index - 1].light(200 - 100 * p) + else: + modus = np.argmax(distribution) + p = distribution[modus] / (total or 1) + color = colors[int(modus)].light(400 - 300 * p) + return color + + def _classification_get_tooltip(self, node): + distribution = self.tree_adapter.get_distribution(node.label)[0] + total = int(np.sum(distribution)) + if self.target_class_index: + samples = distribution[self.target_class_index - 1] + text = '' + else: + modus = np.argmax(distribution) + samples = distribution[modus] + text = self.tree_adapter.domain.class_vars[0].values[modus] + \ + '
' + ratio = samples / np.sum(distribution) + + rules = self.tree_adapter.rules(node.label) + sorted_rules = sorted(rules[:-1], key=lambda rule: rule.attr_name) + rules_str = '' + if len(rules): + rules_str += '
'.join(str(rule) for rule in sorted_rules) + rules_str += '
%s' % rules[-1] + + splitting_attr = self.tree_adapter.attribute(node.label) + + return '

' \ + + text \ + + '{}/{} samples ({:2.3f}%)'.format( + int(samples), total, ratio * 100) \ + + '


' \ + + ('Split by ' + splitting_attr.name + if not self.tree_adapter.is_leaf(node.label) else '') \ + + ('

' if len(rules) and not self.tree_adapter.is_leaf( + node.label) else '') \ + + rules_str \ + + '

' + + # REGRESSION TREE SPECIFIC METHODS + def _regression_update_target_class_combo(self): + self._clear_target_class_combo() + self.target_class_combo.addItems( + list(zip(*self.REGRESSION_COLOR_CALC))[0]) + self.target_class_combo.setCurrentIndex(self.target_class_index) + + def _regression_update_legend_colors(self): + if self.legend is not None: + self.scene.removeItem(self.legend) + + def _get_colors_domain(domain): + class_var = domain.class_var + start, end, pass_through_black = class_var.colors + if pass_through_black: + lst_colors = [QtGui.QColor(*c) for c + in [start, (0, 0, 0), end]] + else: + lst_colors = [QtGui.QColor(*c) for c in [start, end]] + return lst_colors + + # Currently, the first index just draws the outline without any color + if self.target_class_index == 0: + self.legend = None + return + # The colors are the class mean + elif self.target_class_index == 1: + values = (np.min(self.clf_dataset.Y), np.max(self.clf_dataset.Y)) + colors = _get_colors_domain(self.model.domain) + while len(values) != len(colors): + values.insert(1, -1) + + self.legend = OWContinuousLegend(items=list(zip(values, colors)), + **self.LEGEND_OPTIONS) + # Colors are the stddev + elif self.target_class_index == 2: + values = (0, np.std(self.clf_dataset.Y)) + colors = _get_colors_domain(self.model.domain) + while len(values) != len(colors): + values.insert(1, -1) + + self.legend = OWContinuousLegend(items=list(zip(values, colors)), + **self.LEGEND_OPTIONS) + + self.legend.setVisible(self.show_legend) + self.scene.addItem(self.legend) + + def _regression_get_color_palette(self): + return ContinuousPaletteGenerator( + *self.tree_adapter.domain.class_var.colors) + + def _regression_get_node_color(self, adapter, tree_node): + return self.REGRESSION_COLOR_CALC[self.target_class_index][1]( + adapter, tree_node + ) + + def _color_class_mean(self, adapter, tree_node): + # calculate node colors relative to the mean of the node samples + min_mean = np.min(self.clf_dataset.Y) + max_mean = np.max(self.clf_dataset.Y) + instances = adapter.get_instances_in_nodes(self.clf_dataset, tree_node) + mean = np.mean(instances.Y) + + return self.color_palette[(mean - min_mean) / (max_mean - min_mean)] + + def _color_stddev(self, adapter, tree_node): + # calculate node colors relative to the standard deviation in the node + # samples + min_mean, max_mean = 0, np.std(self.clf_dataset.Y) + instances = adapter.get_instances_in_nodes(self.clf_dataset, tree_node) + std = np.std(instances.Y) + + return self.color_palette[(std - min_mean) / (max_mean - min_mean)] + + def _regression_get_tooltip(self, node): + total = self.tree_adapter.num_samples( + self.tree_adapter.parent(node.label)) + samples = self.tree_adapter.num_samples(node.label) + ratio = samples / total + + instances = self.tree_adapter.get_instances_in_nodes( + self.clf_dataset, node) + mean = np.mean(instances.Y) + std = np.std(instances.Y) + + rules = self.tree_adapter.rules(node.label) + sorted_rules = sorted(rules[:-1], key=lambda rule: rule.attr_name) + rules_str = '' + if len(rules): + rules_str += '
'.join(str(rule) for rule in sorted_rules) + rules_str += '
%s' % rules[-1] + + splitting_attr = self.tree_adapter.attribute(node.label) + + return '

Mean: {:2.3f}'.format(mean) \ + + '
Standard deviation: {:2.3f}'.format(std) \ + + '
{}/{} samples ({:2.3f}%)'.format( + int(samples), total, ratio * 100) \ + + '


' \ + + ('Split by ' + splitting_attr.name + if not self.tree_adapter.is_leaf(node.label) else '') \ + + ('

' if len(rules) and not self.tree_adapter.is_leaf( + node.label) else '') \ + + rules_str \ + + '

' + + +class TreeGraphicsView( + PannableGraphicsView, + ZoomableGraphicsView, + AnchorableGraphicsView, + PreventDefaultWheelEvent +): + pass + + +class TreeGraphicsScene(UpdateItemsOnSelectGraphicsScene): + pass + + +def main(): + import sys + import Orange + from Orange.classification.tree import TreeLearner + + argv = sys.argv + if len(argv) > 1: + filename = argv[1] + else: + filename = 'iris' + + app = QtGui.QApplication(argv) + ow = OWPythagorasTree() + data = Orange.data.Table(filename) + clf = TreeLearner(max_depth=1000)(data) + ow.set_tree(clf) + + ow.show() + ow.raise_() + ow.handleNewSignals() + app.exec_() + + sys.exit(0) + + +if __name__ == '__main__': + main() diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py new file mode 100644 index 00000000000..d007aa61eaf --- /dev/null +++ b/Orange/widgets/visualize/owpythagoreanforest.py @@ -0,0 +1,415 @@ +from math import log, sqrt + +import numpy as np +from PyQt4 import QtGui +from PyQt4.QtCore import Qt + +from Orange.base import RandomForest, Tree +from Orange.classification.random_forest import RandomForestClassifier +from Orange.classification.tree import TreeClassifier +from Orange.data import Table +from Orange.regression.random_forest import RandomForestRegressor +from Orange.regression.tree import TreeRegressor +from Orange.widgets import gui, settings +from Orange.widgets.utils.colorpalette import ContinuousPaletteGenerator +from Orange.widgets.visualize.pythagorastreeviewer import PythagorasTreeViewer +from Orange.widgets.visualize.widgetutils.common.owgrid import ( + OWGrid, + SelectableGridItem, + ZoomableGridItem +) +from Orange.widgets.visualize.widgetutils.tree.skltreeadapter import \ + SklTreeAdapter +from Orange.widgets.widget import OWWidget + + +class OWPythagoreanForest(OWWidget): + name = 'Pythagorean forest' + description = 'Pythagorean forest for visualising random forests.' + icon = 'icons/PythagoreanForest.svg' + + priority = 620 + + inputs = [('Random forest', RandomForest, 'set_rf')] + outputs = [('Tree', Tree)] + + # Enable the save as feature + graph_name = 'scene' + + # Settings + depth_limit = settings.ContextSetting(10) + target_class_index = settings.ContextSetting(0) + size_calc_idx = settings.Setting(0) + size_log_scale = settings.Setting(2) + zoom = settings.Setting(50) + selected_tree_index = settings.ContextSetting(-1) + + CLASSIFICATION, REGRESSION = range(2) + + def __init__(self): + super().__init__() + # Instance variables + self.forest_type = self.CLASSIFICATION + self.model = None + self.forest_adapter = None + self.dataset = None + self.clf_dataset = None + # We need to store refernces to the trees and grid items + self.grid_items, self.ptrees = [], [] + + self.color_palette = None + + # Different methods to calculate the size of squares + self.SIZE_CALCULATION = [ + ('Normal', lambda x: x), + ('Square root', lambda x: sqrt(x)), + ('Logarithmic', lambda x: log(x * self.size_log_scale)), + ] + + self.REGRESSION_COLOR_CALC = [ + ('None', lambda _, __: QtGui.QColor(255, 255, 255)), + ('Class mean', self._color_class_mean), + ('Standard deviation', self._color_stddev), + ] + + # CONTROL AREA + # Tree info area + box_info = gui.widgetBox(self.controlArea, 'Forest') + self.ui_info = gui.widgetLabel(box_info, label='') + + # Display controls area + box_display = gui.widgetBox(self.controlArea, 'Display') + self.ui_depth_slider = gui.hSlider( + box_display, self, 'depth_limit', label='Depth', ticks=False, + callback=self.max_depth_changed) + self.ui_target_class_combo = gui.comboBox( + box_display, self, 'target_class_index', label='Target class', + orientation='horizontal', items=[], contentsLength=8, + callback=self.target_colors_changed) + self.ui_size_calc_combo = gui.comboBox( + box_display, self, 'size_calc_idx', label='Size', + orientation='horizontal', + items=list(zip(*self.SIZE_CALCULATION))[0], contentsLength=8, + callback=self.size_calc_changed) + self.ui_zoom_slider = gui.hSlider( + box_display, self, 'zoom', label='Zoom', ticks=False, minValue=20, + maxValue=150, callback=self.zoom_changed, createLabel=False) + + # Stretch to fit the rest of the unsused area + gui.rubber(self.controlArea) + + self.controlArea.setSizePolicy( + QtGui.QSizePolicy.Preferred, QtGui.QSizePolicy.Expanding) + + # MAIN AREA + self.scene = QtGui.QGraphicsScene(self) + self.scene.selectionChanged.connect(self.commit) + self.grid = OWGrid() + self.grid.geometryChanged.connect(self._update_scene_rect) + self.scene.addItem(self.grid) + + self.view = QtGui.QGraphicsView(self.scene) + self.view.setRenderHint(QtGui.QPainter.Antialiasing, True) + self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn) + self.mainArea.layout().addWidget(self.view) + + self.resize(800, 500) + + self.clear() + + def set_rf(self, model=None): + """When a different forest is given.""" + self.clear() + self.model = model + + if model is not None: + if isinstance(model, RandomForestClassifier): + self.forest_type = self.CLASSIFICATION + elif isinstance(model, RandomForestRegressor): + self.forest_type = self.REGRESSION + else: + raise RuntimeError('Invalid type of forest.') + + self.forest_adapter = self._get_forest_adapter(self.model) + self.color_palette = self._type_specific('_get_color_palette')() + self._draw_trees() + + self.dataset = model.instances + # this bit is important for the regression classifier + if self.dataset is not None and \ + self.dataset.domain != model.domain: + self.clf_dataset = Table.from_table( + self.model.domain, self.dataset) + else: + self.clf_dataset = self.dataset + + self._update_info_box() + self._type_specific('_update_target_class_combo')() + self._update_depth_slider() + + self.selected_tree_index = -1 + + def clear(self): + """Clear all relevant data from the widget.""" + self.model = None + self.forest_adapter = None + self.ptrees = [] + self.grid_items = [] + self.grid.clear() + + self._clear_info_box() + self._clear_target_class_combo() + self._clear_depth_slider() + + # CONTROL AREA CALLBACKS + def max_depth_changed(self): + for tree in self.ptrees: + tree.set_depth_limit(self.depth_limit) + + def target_colors_changed(self): + for tree in self.ptrees: + tree.target_class_has_changed() + + def size_calc_changed(self): + if self.model is not None: + self.forest_adapter = self._get_forest_adapter(self.model) + self.grid.clear() + self._draw_trees() + # Keep the selected item + if self.selected_tree_index != -1: + self.grid_items[self.selected_tree_index].setSelected(True) + self.max_depth_changed() + + def zoom_changed(self): + for item in self.grid_items: + item.set_max_size(self._calculate_zoom(self.zoom)) + + width = (self.view.width() - self.view.verticalScrollBar().width()) + self.grid.reflow(width) + self.grid.setPreferredWidth(width) + + # MODEL CHANGED METHODS + def _update_info_box(self): + self.ui_info.setText( + 'Trees: {}'.format(len(self.forest_adapter.get_trees())) + ) + + def _update_depth_slider(self): + self.depth_limit = self._get_max_depth() + + self.ui_depth_slider.parent().setEnabled(True) + self.ui_depth_slider.setMaximum(self.depth_limit) + self.ui_depth_slider.setValue(self.depth_limit) + + # MODEL CLEARED METHODS + def _clear_info_box(self): + self.ui_info.setText('No forest on input.') + + def _clear_target_class_combo(self): + self.ui_target_class_combo.clear() + self.target_class_index = 0 + self.ui_target_class_combo.setCurrentIndex(self.target_class_index) + + def _clear_depth_slider(self): + self.ui_depth_slider.parent().setEnabled(False) + self.ui_depth_slider.setMaximum(0) + + # HELPFUL METHODS + def _get_max_depth(self): + return max([tree.tree_adapter.max_depth for tree in self.ptrees]) + + def _get_forest_adapter(self, model): + return SklRandomForestAdapter( + model, + model.domain, + adjust_weight=self.SIZE_CALCULATION[self.size_calc_idx][1], + ) + + def _draw_trees(self): + self.grid_items, self.ptrees = [], [] + + with self.progressBar(len(self.forest_adapter.get_trees())) as prg: + for tree in self.forest_adapter.get_trees(): + ptree = PythagorasTreeViewer( + None, tree, + node_color_func=self._type_specific('_get_node_color'), + interactive=False, padding=100) + self.grid_items.append(GridItem( + ptree, self.grid, max_size=self._calculate_zoom(self.zoom) + )) + self.ptrees.append(ptree) + prg.advance() + self.grid.set_items(self.grid_items) + # This is necessary when adding items for the first time + if self.grid: + width = (self.view.width() - + self.view.verticalScrollBar().width()) + self.grid.reflow(width) + self.grid.setPreferredWidth(width) + + @staticmethod + def _calculate_zoom(zoom_level): + """Calculate the max size for grid items from zoom level setting.""" + return zoom_level * 5 + + def onDeleteWidget(self): + """When deleting the widget.""" + super().onDeleteWidget() + self.clear() + + def commit(self): + """Commit the selected tree to output.""" + if len(self.scene.selectedItems()) == 0: + self.send('Tree', None) + # The selected tree index should only reset when model changes + if self.model is None: + self.selected_tree_index = -1 + return + + selected_item = self.scene.selectedItems()[0] + self.selected_tree_index = self.grid_items.index(selected_item) + tree = self.model.skl_model.estimators_[self.selected_tree_index] + + if self.forest_type == self.CLASSIFICATION: + obj = TreeClassifier(tree) + else: + obj = TreeRegressor(tree) + obj.domain = self.model.domain + obj.instances = self.model.instances + + obj.meta_target_class_index = self.target_class_index + obj.meta_size_calc_idx = self.size_calc_idx + obj.meta_size_log_scale = self.size_log_scale + obj.meta_depth_limit = self.depth_limit + + self.send('Tree', obj) + + def send_report(self): + self.report_plot() + + def _update_scene_rect(self): + self.scene.setSceneRect(self.scene.itemsBoundingRect()) + + def resizeEvent(self, ev): + width = (self.view.width() - self.view.verticalScrollBar().width()) + self.grid.reflow(width) + self.grid.setPreferredWidth(width) + + super().resizeEvent(ev) + + def _type_specific(self, method): + """A best effort method getter that somewhat separates logic specific + to classification and regression trees. + This relies on conventional naming of specific methods, e.g. + a method name _get_tooltip would need to be defined like so: + _classification_get_tooltip and _regression_get_tooltip, since they are + both specific. + + Parameters + ---------- + method : str + Method name that we would like to call. + + Returns + ------- + callable or None + + """ + if self.forest_type == self.CLASSIFICATION: + return getattr(self, '_classification' + method) + elif self.forest_type == self.REGRESSION: + return getattr(self, '_regression' + method) + else: + return None + + # CLASSIFICATION FOREST SPECIFIC METHODS + def _classification_update_target_class_combo(self): + self._clear_target_class_combo() + self.ui_target_class_combo.addItem('None') + values = [c.title() for c in + self.model.domain.class_vars[0].values] + self.ui_target_class_combo.addItems(values) + + def _classification_get_color_palette(self): + return [QtGui.QColor(*c) for c in self.model.domain.class_var.colors] + + def _classification_get_node_color(self, adapter, tree_node): + # this is taken almost directly from the existing classification tree + # viewer + colors = self.color_palette + distribution = adapter.get_distribution(tree_node.label)[0] + total = np.sum(distribution) + + if self.target_class_index: + p = distribution[self.target_class_index - 1] / total + color = colors[self.target_class_index - 1].light(200 - 100 * p) + else: + modus = np.argmax(distribution) + p = distribution[modus] / (total or 1) + color = colors[int(modus)].light(400 - 300 * p) + return color + + # REGRESSION FOREST SPECIFIC METHODS + def _regression_update_target_class_combo(self): + self._clear_target_class_combo() + self.ui_target_class_combo.addItems( + list(zip(*self.REGRESSION_COLOR_CALC))[0]) + self.ui_target_class_combo.setCurrentIndex(self.target_class_index) + + def _regression_get_color_palette(self): + return ContinuousPaletteGenerator( + *self.forest_adapter.domain.class_var.colors) + + def _regression_get_node_color(self, adapter, tree_node): + return self.REGRESSION_COLOR_CALC[self.target_class_index][1]( + adapter, tree_node + ) + + def _color_class_mean(self, adapter, tree_node): + # calculate node colors relative to the mean of the node samples + min_mean = np.min(self.clf_dataset.Y) + max_mean = np.max(self.clf_dataset.Y) + instances = adapter.get_instances_in_nodes(self.clf_dataset, tree_node) + mean = np.mean(instances.Y) + + return self.color_palette[(mean - min_mean) / (max_mean - min_mean)] + + def _color_stddev(self, adapter, tree_node): + # calculate node colors relative to the standard deviation in the node + # samples + min_mean, max_mean = 0, np.std(self.clf_dataset.Y) + instances = adapter.get_instances_in_nodes(self.clf_dataset, tree_node) + std = np.std(instances.Y) + + return self.color_palette[(std - min_mean) / (max_mean - min_mean)] + + +class GridItem(SelectableGridItem, ZoomableGridItem): + pass + + +class SklRandomForestAdapter: + def __init__(self, model, domain, adjust_weight=lambda x: x): + self._adapters = [] + + self._domain = domain + + self._trees = model.skl_model.estimators_ + self._domain = model.domain + self._adjust_weight = adjust_weight + + def get_trees(self): + if len(self._adapters) > 0: + return self._adapters + if len(self._trees) < 1: + return self._adapters + + self._adapters = [ + SklTreeAdapter(tree.tree_, self._domain, self._adjust_weight) + for tree in self._trees + ] + return self._adapters + + @property + def domain(self): + return self._domain diff --git a/Orange/widgets/visualize/pythagorastreeviewer.py b/Orange/widgets/visualize/pythagorastreeviewer.py new file mode 100644 index 00000000000..42e79119ac2 --- /dev/null +++ b/Orange/widgets/visualize/pythagorastreeviewer.py @@ -0,0 +1,756 @@ +""" +Pythagoras tree viewer for visualizing tree structures. + +The pythagoras tree viewer widget is a widget that can be plugged into any +existing widget given a tree adapter instance. It is simply a canvas that takes +and input tree adapter and takes care of all the drawing. + +Types +----- +Square : namedtuple (center, length, angle) + Since Pythagoras trees deal only with squares (they also deal with + rectangles in the generalized form, but are completely unreadable), this + is what all the squares are stored as. +Point : namedtuple (x, y) + Self exaplanatory. + +""" +from collections import namedtuple, defaultdict, deque +from math import pi, sqrt, cos, sin, degrees + +from PyQt4 import QtCore, QtGui +from PyQt4.QtCore import Qt + +from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter + +# z index range, increase if needed +Z_STEP = 5000000 + +Square = namedtuple('Square', ['center', 'length', 'angle']) +Point = namedtuple('Point', ['x', 'y']) + + +class PythagorasTreeViewer(QtGui.QGraphicsWidget): + """Pythagoras tree viewer graphics widget. + + Simply pass in a tree adapter instance and a valid scene object, and the + pythagoras tree will be added. + + Examples + -------- + Pass tree through constructor. + >>> tree_view = PythagorasTreeViewer(parent=scene, adapter=tree_adapter) + + Pass tree later through method. + >>> tree_adapter = TreeAdapter() + >>> scene = QtGui.QGraphicsScene() + This is where the magic happens + >>> tree_view = PythagorasTreeViewer(parent=scene) + >>> tree_view.set_tree(tree_adapter) + + Both these examples set the appropriate tree and add all the squares to the + widget instance. + + Parameters + ---------- + parent : QGraphicsItem, optional + The parent object that the graphics widget belongs to. Should be a + scene. + adapter : TreeAdapter, optional + Any valid tree adapter instance. + interacitive : bool, optional, + Specify whether the widget should have an interactive display. This + means special hover effects, selectable boxes. Default is true. + + Notes + ----- + .. Note:: The class contains two clear methods: `clear` and `clear_tree`. + Each has their own use. + `clear_tree` will clear out the tree and remove any graphics items. + `clear` will, on the other hand, clear everything, all settings + (tooltip and color calculation functions. + + This is useful because when we want to change the size calculation of + the Pythagora tree, we just want to clear the scene and it would be + inconvenient to have to set color and tooltip functions again. + On the other hand, when we want to draw a brand new tree, it is best + to clear all settings to avoid any strange bugs - we start with a blank + slate. + + """ + + def __init__(self, parent=None, adapter=None, depth_limit=0, padding=0, + **kwargs): + super().__init__(parent) + + # Instance variables + # The tree adapter parameter will be handled at the end of init + self.tree_adapter = None + # The root tree node instance which is calculated inside the class + self._tree = None + self._padding = padding + + self.setSizePolicy(QtGui.QSizePolicy.Expanding, + QtGui.QSizePolicy.Expanding) + + # Necessary settings that need to be set from the outside + self._depth_limit = depth_limit + # Provide a nice green default in case no color function is provided + self.__calc_node_color_func = kwargs.get('node_color_func') + self.__get_tooltip_func = kwargs.get('tooltip_func') + self._interactive = kwargs.get('interactive', True) + + self._square_objects = {} + self._drawn_nodes = deque() + self._frontier = deque() + + # If a tree adapter was passed, set and draw the tree + if adapter is not None: + self.set_tree(adapter) + + def set_tree(self, tree_adapter): + """Pass in a new tree adapter instance and perform updates to canvas. + + Parameters + ---------- + tree_adapter : TreeAdapter + The new tree adapter that is to be used. + + Returns + ------- + + """ + self.clear_tree() + self.tree_adapter = tree_adapter + + if self.tree_adapter is not None: + self._tree = self._calculate_tree(self.tree_adapter) + self.set_depth_limit(tree_adapter.max_depth) + self._draw_tree(self._tree) + + def set_depth_limit(self, depth): + """Update the drawing depth limit. + + The drawing stops when the depth is GT the limit. This means that at + depth 0, the root node will be drawn. + + Parameters + ---------- + depth : int + The maximum depth at which the nodes can still be drawn. + + Returns + ------- + + """ + self._depth_limit = depth + self._draw_tree(self._tree) + + def set_node_color_func(self, func): + """Set the function that will be used to calculate the node colors. + + The function must accept one parameter that represents the label of a + given node and return the appropriate QColor object that should be used + for the node. + + Parameters + ---------- + func : Callable + func :: label -> QtGui.QColor + + Returns + ------- + + """ + if func != self._calc_node_color: + self.__calc_node_color_func = func + self._update_node_colors() + + def _calc_node_color(self, *args): + """Get the node color with a nice default fallback.""" + if self.__calc_node_color_func is not None: + return self.__calc_node_color_func(*args) + return QtGui.QColor('#297A1F') + + def set_tooltip_func(self, func): + """Set the function that will be used the get the node tooltips. + + Parameters + ---------- + func : Callable + func :: label -> str + + Returns + ------- + + """ + if func != self._get_tooltip: + self.__get_tooltip_func = func + self._update_node_tooltips() + + def _get_tooltip(self, *args): + """Get the node tooltip with a nice default fallback.""" + if self.__get_tooltip_func is not None: + return self.__get_tooltip_func(*args) + return 'Tooltip' + + def target_class_has_changed(self): + self._update_node_colors() + self._update_node_tooltips() + + def tooltip_has_changed(self): + self._update_node_tooltips() + + def _update_node_colors(self): + """Update all the node colors. + + Should be called when the color method is changed and the nodes need to + be drawn with the new colors. + + Returns + ------- + + """ + for square in self._squares(): + square.setBrush(self._calc_node_color(self.tree_adapter, + square.tree_node)) + + def _update_node_tooltips(self): + """Update all the tooltips for the squares.""" + for square in self._squares(): + square.setToolTip(self._get_tooltip(square.tree_node)) + + def clear(self): + """Clear the entire widget state.""" + self.__calc_node_color_func = None + self.__get_tooltip_func = None + self.clear_tree() + + def clear_tree(self): + """Clear only the tree, keeping tooltip and color functions.""" + self.tree_adapter = None + self._tree = None + self._clear_scene() + + @staticmethod + def _calculate_tree(tree_adapter): + """Actually calculate the tree squares""" + tree_builder = PythagorasTree() + return tree_builder.pythagoras_tree( + tree_adapter, tree_adapter.root, Square(Point(0, 0), 200, -pi / 2) + ) + + def _draw_tree(self, root): + """Efficiently draw the tree with regards to the depth. + + If using a recursive approach, the tree had to be redrawn every time + the depth was changed, which was very impractical for larger trees, + since everything got very slow, very fast. + + In this approach, we use two queues to represent the tree frontier and + the nodes that have already been drawn. We also store the depth. This + way, when the max depth is increased, we do not redraw the whole tree + but only iterate throught the frontier and draw those nodes, and update + the frontier accordingly. + When decreasing the max depth, we reverse the process, we clear the + frontier, and remove nodes from the drawn nodes, and append those with + depth max_depth + 1 to the frontier, so the frontier doesn't get + cluttered. + + Parameters + ---------- + root : TreeNode + The root tree node. + + Returns + ------- + + """ + if self._tree is None: + return + # if this is the first time drawing the tree begin with root + if not self._drawn_nodes: + self._frontier.appendleft((0, root)) + # if the depth was decreased, we can clear the frontier, otherwise + # frontier gets cluttered with non-frontier nodes + was_decreased = self._depth_was_decreased() + if was_decreased: + self._frontier.clear() + # remove nodes from drawn and add to frontier if limit is decreased + while self._drawn_nodes: + depth, node = self._drawn_nodes.pop() + # check if the node is in the allowed limit + if depth <= self._depth_limit: + self._drawn_nodes.append((depth, node)) + break + if depth == self._depth_limit + 1: + self._frontier.appendleft((depth, node)) + + if node.label in self._square_objects: + self._square_objects[node.label].hide() + + # add nodes to drawn and remove from frontier if limit is increased + while self._frontier: + depth, node = self._frontier.popleft() + # check if the depth of the node is outside the allowed limit + if depth > self._depth_limit: + self._frontier.appendleft((depth, node)) + break + self._drawn_nodes.append((depth, node)) + self._frontier.extend((depth + 1, c) for c in node.children) + + if node.label in self._square_objects: + self._square_objects[node.label].show() + else: + square_obj = InteractiveSquareGraphicsItem \ + if self._interactive else SquareGraphicsItem + self._square_objects[node.label] = square_obj( + node, + parent=self, + brush=QtGui.QBrush( + self._calc_node_color(self.tree_adapter, node) + ), + tooltip=self._get_tooltip(node), + zvalue=depth, + ) + + def _depth_was_decreased(self): + if not self._drawn_nodes: + return False + # checks if the max depth was increased from the last change + depth, node = self._drawn_nodes.pop() + self._drawn_nodes.append((depth, node)) + # if the right most node in drawn nodes has appropriate depth, it must + # have been increased + return depth > self._depth_limit + + def _squares(self): + return [node.graphics_item for _, node in self._drawn_nodes] + + def _clear_scene(self): + for square in self._squares(): + self.scene().removeItem(square) + self._frontier.clear() + self._drawn_nodes.clear() + self._square_objects.clear() + + def boundingRect(self): + return self.childrenBoundingRect().adjusted( + -self._padding, -self._padding, self._padding, self._padding) + + def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs): + return self.boundingRect().size() + \ + QtCore.QSizeF(self._padding, self._padding) + + +class SquareGraphicsItem(QtGui.QGraphicsRectItem): + """Square Graphics Item. + + Square component to draw as components for the non-interactive Pythagoras + tree. + + Parameters + ---------- + tree_node : TreeNode + The tree node the square represents. + brush : QColor, optional + The brush to be used as the backgound brush. + pen : QPen, optional + The pen to be used for the border. + + """ + + def __init__(self, tree_node, parent=None, **kwargs): + self.tree_node = tree_node + self.tree_node.graphics_item = self + + center, length, angle = tree_node.square + self._center_point = center + self.center = QtCore.QPointF(*center) + self.length = length + self.angle = angle + super().__init__(self._get_rect_attributes(), parent) + self.setTransformOriginPoint(self.boundingRect().center()) + self.setRotation(degrees(angle)) + + self.setBrush(kwargs.get('brush', QtGui.QColor('#297A1F'))) + self.setPen(kwargs.get('pen', QtGui.QPen(QtGui.QColor('#000')))) + + self.setAcceptHoverEvents(True) + self.setZValue(kwargs.get('zvalue', 0)) + self.z_step = Z_STEP + + # calculate the correct z values based on the parent + if self.tree_node.parent != -1: + p = self.tree_node.parent + # override root z step + num_children = len(p.children) + own_index = [1 if c.label == self.tree_node.label else 0 + for c in p.children].index(1) + + self.z_step = int(p.graphics_item.z_step / num_children) + base_z = p.graphics_item.zValue() + + self.setZValue(base_z + own_index * self.z_step) + + def _get_rect_attributes(self): + """Get the rectangle attributes requrired to draw item. + + Compute the QRectF that a QGraphicsRect needs to be rendered with the + data passed down in the constructor. + + """ + height = width = self.length + x = self.center.x() - self.length / 2 + y = self.center.y() - self.length / 2 + return QtCore.QRectF(x, y, height, width) + + +class InteractiveSquareGraphicsItem(SquareGraphicsItem): + """Interactive square graphics items. + + This is different from the base square graphics item so that it is + selectable, and it can handle and react to hover events (highlight and + focus own branch). + + Parameters + ---------- + tree_node : TreeNode + The tree node the square represents. + brush : QColor, optional + The brush to be used as the backgound brush. + pen : QPen, optional + The pen to be used for the border. + + """ + + timer = QtCore.QTimer() + + MAX_OPACITY = 1. + SELECTION_OPACITY = .5 + HOVER_OPACITY = .1 + + def __init__(self, tree_node, parent=None, **kwargs): + super().__init__(tree_node, parent, **kwargs) + self.setFlag(QtGui.QGraphicsItem.ItemIsSelectable, True) + + self.initial_zvalue = self.zValue() + # The max z value changes if any item is selected + self.any_selected = False + + self.setToolTip(kwargs.get('tooltip', 'Tooltip')) + + self.timer.setSingleShot(True) + + def hoverEnterEvent(self, ev): + self.timer.stop() + + def fnc(graphics_item): + graphics_item.setZValue(Z_STEP) + if self.any_selected: + if graphics_item.isSelected(): + opacity = self.MAX_OPACITY + else: + opacity = self.SELECTION_OPACITY + else: + opacity = self.MAX_OPACITY + graphics_item.setOpacity(opacity) + + def other_fnc(graphics_item): + if graphics_item.isSelected(): + opacity = self.MAX_OPACITY + else: + opacity = self.HOVER_OPACITY + graphics_item.setOpacity(opacity) + graphics_item.setZValue(self.initial_zvalue) + + self._propagate_z_values(self, fnc, other_fnc) + + def hoverLeaveEvent(self, ev): + + def fnc(graphics_item): + # No need to set opacity in this branch since it was just selected + # and had the max value + graphics_item.setZValue(self.initial_zvalue) + + def other_fnc(graphics_item): + if self.any_selected: + if graphics_item.isSelected(): + opacity = self.MAX_OPACITY + else: + opacity = self.SELECTION_OPACITY + else: + opacity = self.MAX_OPACITY + graphics_item.setOpacity(opacity) + + self.timer.timeout.connect( + lambda: self._propagate_z_values(self, fnc, other_fnc)) + + self.timer.start(250) + + def _propagate_z_values(self, graphics_item, fnc, other_fnc): + self._propagate_to_children(graphics_item, fnc) + self._propagate_to_parents(graphics_item, fnc, other_fnc) + + def _propagate_to_children(self, graphics_item, fnc): + # propagate function that handles graphics item to appropriate children + fnc(graphics_item) + for c in graphics_item.tree_node.children: + self._propagate_to_children(c.graphics_item, fnc) + + def _propagate_to_parents(self, graphics_item, fnc, other_fnc): + # propagate function that handles graphics item to appropriate parents + if graphics_item.tree_node.parent != -1: + parent = graphics_item.tree_node.parent.graphics_item + # handle the non relevant children nodes + for c in parent.tree_node.children: + if c != graphics_item.tree_node: + self._propagate_to_children(c.graphics_item, other_fnc) + # handle the parent node + fnc(parent) + # propagate up the tree + self._propagate_to_parents(parent, fnc, other_fnc) + + def selection_changed(self): + # Handle selection changed + self.any_selected = len(self.scene().selectedItems()) > 0 + if self.any_selected: + if self.isSelected(): + self.setOpacity(self.MAX_OPACITY) + else: + if self.opacity() != self.HOVER_OPACITY: + self.setOpacity(self.SELECTION_OPACITY) + else: + self.setGraphicsEffect(None) + self.setOpacity(self.MAX_OPACITY) + + def paint(self, painter, option, widget=None): + # Override the default selected appearance + if self.isSelected(): + option.state ^= QtGui.QStyle.State_Selected + rect = self.rect() + # this must render before overlay due to order in which it's drawn + super().paint(painter, option, widget) + painter.save() + pen = QtGui.QPen(QtGui.QColor(Qt.black)) + pen.setWidth(4) + pen.setJoinStyle(Qt.MiterJoin) + painter.setPen(pen) + painter.drawRect(rect.adjusted(2, 2, -2, -2)) + painter.restore() + else: + super().paint(painter, option, widget) + + +class TreeNode: + """A node in the tree structure used to represent the tree adapter + + Parameters + ---------- + label : int + The label of the tree node, can be looked up in the original tree. + square : Square + The square the represents the tree node. + parent : TreeNode or object + The parent of the current node. In the case of root, an object + containing the root label of the tree adapter should be passed. + children : tuple of TreeNode, optional + All the children that belong to this node. + + """ + + def __init__(self, label, square, parent, children=()): + self.label = label + self.square = square + self.parent = parent + self.children = children + self.graphics_item = None + + def __str__(self): + return '({}) -> [{}]'.format(self.parent, self.label) + + +class PythagorasTree: + """Pythagoras tree. + + Contains all the logic that converts a given tree adapter to a tree + consisting of node classes. + + """ + + def __init__(self): + # store the previous angles of each square children so that slopes can + # be computed + self._slopes = defaultdict(list) + + def pythagoras_tree(self, tree, node, square): + """Get the Pythagoras tree representation in a graph like view. + + Constructs a graph using TreeNode into a tree structure. Each node in + graph contains the information required to plot the the tree. + + Parameters + ---------- + tree : TreeAdapter + A tree adapter instance where the original tree is stored. + node : int + The node label, the root node is denoted with 0. + square : Square + The initial square which will represent the root of the tree. + + Returns + ------- + TreeNode + The root node which contains the rest of the tree. + + """ + # make sure to clear out any old slopes if we are drawing a new tree + if node == tree.root: + self._slopes.clear() + + children = tuple( + self._compute_child(tree, square, child) + for child in tree.children(node) + ) + # make sure to pass a reference to parent to each child + obj = TreeNode(node, square, tree.parent(node), children) + # mutate the existing data stored in the created tree node + for c in children: + c.parent = obj + return obj + + def _compute_child(self, tree, parent_square, node): + """Compute all the properties for a single child. + + Parameters + ---------- + tree : TreeAdapter + A tree adapter instance where the original tree is stored. + parent_square : Square + The parent square of the given child. + node : int + The node label of the child. + + Returns + ------- + TreeNode + The tree node representation of the given child with the computed + subtree. + + """ + weight = tree.weight(node) + # the angle of the child from its parent + alpha = weight * pi + # the child side length + length = parent_square.length * sin(alpha / 2) + # the sum of the previous anlges + prev_angles = sum(self._slopes[parent_square]) + + center = self._compute_center( + parent_square, length, alpha, prev_angles + ) + # the angle of the square is dependent on the parent, the current + # angle and the previous angles. Subtract PI/2 so it starts drawing at + # 0rads. + angle = parent_square.angle - pi / 2 + prev_angles + alpha / 2 + square = Square(center, length, angle) + + self._slopes[parent_square].append(alpha) + + return self.pythagoras_tree(tree, node, square) + + def _compute_center(self, initial_square, length, alpha, base_angle=0): + """Compute the central point of a child square. + + Parameters + ---------- + initial_square : Square + The parent square representation where we will be drawing from. + length : float + The length of the side of the new square (the one we are computing + the center for). + alpha : float + The angle that defines the size of our new square (in radians). + base_angle : float, optional + If the square we want to find the center for is not the first child + i.e. its edges does not touch the base square, then we need the + initial angle that will act as the starting point for the new + square. + + Returns + ------- + Point + The central point to the new square. + + """ + parent_center, parent_length, parent_angle = initial_square + # get the point on the square side that will be the rotation origin + t0 = self._get_point_on_square_edge( + parent_center, parent_length, parent_angle) + # get the edge point that we will rotate around t0 + square_diagonal_length = sqrt(2 * parent_length ** 2) + edge = self._get_point_on_square_edge( + parent_center, square_diagonal_length, parent_angle - pi / 4) + # if the new square is not the first child, we need to rotate the edge + if base_angle != 0: + edge = self._rotate_point(edge, t0, base_angle) + + # rotate the edge point to the correct spot + t1 = self._rotate_point(edge, t0, alpha) + + # calculate the middle point between the rotated point and edge + t2 = Point((t1.x + edge.x) / 2, (t1.y + edge.y) / 2) + # calculate the slope of the new square + slope = parent_angle - pi / 2 + alpha / 2 + # using this data, we can compute the square center + return self._get_point_on_square_edge(t2, length, slope + base_angle) + + @staticmethod + def _rotate_point(point, around, alpha): + """Rotate a point around another point by some angle. + + Parameters + ---------- + point : Point + The point to rotate. + around : Point + The point to perform rotation around. + alpha : float + The angle to rotate by (in radians). + + Returns + ------- + Point: + The rotated point. + + """ + temp = Point(point.x - around.x, point.y - around.y) + temp = Point( + temp.x * cos(alpha) - temp.y * sin(alpha), + temp.x * sin(alpha) + temp.y * cos(alpha) + ) + return Point(temp.x + around.x, temp.y + around.y) + + @staticmethod + def _get_point_on_square_edge(center, length, angle): + """Calculate the central point on the drawing edge of the given square. + + Parameters + ---------- + center : Point + The square center point. + length : float + The square side length. + angle : float + The angle of the square. + + Returns + ------- + Point + A point on the center of the drawing edge of the given square. + + """ + return Point( + center.x + length / 2 * cos(angle), + center.y + length / 2 * sin(angle) + ) diff --git a/Orange/widgets/visualize/tests/test_owpythagorastree.py b/Orange/widgets/visualize/tests/test_owpythagorastree.py new file mode 100644 index 00000000000..c0f7d05ed40 --- /dev/null +++ b/Orange/widgets/visualize/tests/test_owpythagorastree.py @@ -0,0 +1,59 @@ +import math +import unittest +import Orange.widgets + +from Orange.widgets.visualize.pythagorastreeviewer import ( + PythagorasTree, + Point, + Square, +) + + +class TestPythagorasTree(unittest.TestCase): + def setUp(self): + self.builder = PythagorasTree() + + def test_get_point_on_square_edge_with_no_angle(self): + point = self.builder._get_point_on_square_edge( + center=Point(0, 0), length=2, angle=0 + ) + expected_point = Point(1, 0) + self.assertAlmostEqual(point.x, expected_point.x, places=1) + self.assertAlmostEqual(point.y, expected_point.y, places=1) + + def test_get_point_on_square_edge_with_non_zero_angle(self): + point = self.builder._get_point_on_square_edge( + center=Point(2.7, 2.77), length=1.65, angle=math.radians(20.97) + ) + expected_point = Point(3.48, 3.07) + self.assertAlmostEqual(point.x, expected_point.x, places=1) + self.assertAlmostEqual(point.y, expected_point.y, places=1) + + def test_compute_center_with_simple_square_angle(self): + initial_square = Square(Point(0, 0), length=2, angle=math.pi / 2) + point = self.builder._compute_center( + initial_square, length=1.13, alpha=math.radians(68.57)) + expected_point = Point(1.15, 1.78) + self.assertAlmostEqual(point.x, expected_point.x, places=1) + self.assertAlmostEqual(point.y, expected_point.y, places=1) + + def test_compute_center_with_complex_square_angle(self): + initial_square = Square( + Point(1.5, 1.5), length=2.24, angle=math.radians(63.43) + ) + point = self.builder._compute_center( + initial_square, length=1.65, alpha=math.radians(95.06)) + expected_point = Point(3.48, 3.07) + self.assertAlmostEqual(point.x, expected_point.x, places=1) + self.assertAlmostEqual(point.y, expected_point.y, places=1) + + def test_compute_center_with_complex_square_angle_with_base_angle(self): + initial_square = Square( + Point(1.5, 1.5), length=2.24, angle=math.radians(63.43) + ) + point = self.builder._compute_center( + initial_square, length=1.51, alpha=math.radians(180 - 95.06), + base_angle=math.radians(95.06)) + expected_point = Point(1.43, 3.98) + self.assertAlmostEqual(point.x, expected_point.x, places=1) + self.assertAlmostEqual(point.y, expected_point.y, places=1) diff --git a/Orange/widgets/visualize/widgetutils/__init__.py b/Orange/widgets/visualize/widgetutils/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/Orange/widgets/visualize/widgetutils/common/__init__.py b/Orange/widgets/visualize/widgetutils/common/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/Orange/widgets/visualize/widgetutils/common/owgrid.py b/Orange/widgets/visualize/widgetutils/common/owgrid.py new file mode 100644 index 00000000000..6e091288177 --- /dev/null +++ b/Orange/widgets/visualize/widgetutils/common/owgrid.py @@ -0,0 +1,280 @@ +from itertools import zip_longest + +from PyQt4 import QtGui, QtCore +from PyQt4.QtCore import Qt + + +class GridItem(QtGui.QGraphicsWidget): + """The base class for grid items, takes care of positioning in grid. + + Parameters + ---------- + widget : QtGui.QGraphicsWidget + parent : QtGui.QGraphicsWidget + + See Also + -------- + OWGrid + SelectableGridItem + ZoomableGridItem + + """ + + def __init__(self, widget, parent=None, **_): + super().__init__(parent) + # For some reason, the super constructor is not setting the parent + self.setParent(parent) + + self.widget = widget + if hasattr(self.widget, 'setParent'): + self.widget.setParentItem(self) + self.widget.setParent(self) + + # Move the child widget to (0, 0) so that bounding rects match up + # This is needed because the bounding rect is caluclated with the size + # hint from (0, 0), regardless of any method override + rect = self.widget.boundingRect() + self.widget.moveBy(-rect.topLeft().x(), -rect.topLeft().y()) + + def boundingRect(self): + return QtCore.QRectF(QtCore.QPointF(0, 0), + self.widget.boundingRectoundingRect().size()) + + def sizeHint(self, size_hint, size_constraint=None, **kwargs): + return self.widget.sizeHint(size_hint, size_constraint, **kwargs) + + +class SelectableGridItem(GridItem): + """Makes a grid item selectable. + + Parameters + ---------- + widget : QtGui.QGraphicsWidget + parent : QtGui.QgraphicsWidget + + See Also + -------- + OWGrid + GridItem + ZoomableGridItem + + """ + + def __init__(self, widget, parent=None, **kwargs): + super().__init__(widget, parent, **kwargs) + + self.setFlags(QtGui.QGraphicsWidget.ItemIsSelectable) + + def paint(self, painter, options, widget=None): + super().paint(painter, options, widget) + rect = self.boundingRect() + painter.save() + if self.isSelected(): + painter.setPen(QtGui.QPen(QtGui.QColor(125, 162, 206, 192))) + painter.setBrush(QtGui.QBrush(QtGui.QColor(217, 232, 252, 192))) + painter.drawRoundedRect(QtCore.QRectF( + rect.topLeft(), self.geometry().size()), 3, 3) + else: + painter.setPen(QtGui.QPen(QtGui.QColor('#ebebeb'))) + painter.drawRoundedRect(QtCore.QRectF( + rect.topLeft(), self.geometry().size()), 3, 3) + painter.restore() + + +class ZoomableGridItem(GridItem): + """Makes a grid item "zoomable" through the `set_max_size` method. + + Notes + ----- + .. Note:: This grid item will override any bounding box or size hint + defined in the class hierarchy with its own. + .. Note:: This makes the grid item square. + + Parameters + ---------- + widget : QtGui.QGraphicsWidget + parent : QtGui.QGraphicsWidget + max_size : int + The maximum size of the grid item. + + See Also + -------- + OWGrid + GridItem + SelectableGridItem + + """ + + def __init__(self, widget, parent=None, max_size=150, **kwargs): + self._max_size = QtCore.QSizeF(max_size, max_size) + # We store the offsets from the top left corner to move widget properly + self.__offset_x = self.__offset_y = 0 + + super().__init__(widget, parent, **kwargs) + + self._resize_widget() + + def set_max_size(self, max_size): + self.widget.resetTransform() + self._max_size = QtCore.QSizeF(max_size, max_size) + self._resize_widget() + + def _resize_widget(self): + w = self.widget + own_hint = self.sizeHint(Qt.PreferredSize) + + scale_w = own_hint.width() / w.boundingRect().width() + scale_h = own_hint.height() / w.boundingRect().height() + scale = scale_w if scale_w < scale_h else scale_h + + # Move the widget back to origin then perfom transformations + self.widget.moveBy(-self.__offset_x, -self.__offset_y) + # Move the tranform origin to top left, so it stays in place when + # scaling + w.setTransformOriginPoint(w.boundingRect().topLeft()) + w.setScale(scale) + # Then, move the scaled widget to the center of the bounding box + own_rect = self.boundingRect() + self.__offset_x = (own_rect.width() - w.boundingRect().width() * + scale) / 2 + self.__offset_y = (own_rect.height() - w.boundingRect().height() * + scale) / 2 + self.widget.moveBy(self.__offset_x, self.__offset_y) + # Finally, tell the world you've changed + self.updateGeometry() + + def boundingRect(self): + return QtCore.QRectF(QtCore.QPointF(0, 0), self._max_size) + + def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs): + return self._max_size + + +class OWGrid(QtGui.QGraphicsWidget): + """Responsive grid layout widget. + + Manages grid items for various window sizes. + + Accepts grid items as items. + + Parameters + ---------- + parent : QtGui.QGraphicsWidget + + Examples + -------- + >>> grid = OWGrid() + + It's a good idea to define what you want your grid items to do. For this + example, we will make them selectable and zoomable, so we define a class + that inherits from both: + >>> class MyGridItem(SelectableGridItem, ZoomableGridItem): + >>> pass + + We then take a list of items and wrap them into our new `MyGridItem` + instances. + >>> items = [QtGui.QGraphicsRectItem(0, 0, 10, 10)] + >>> grid_items = [MyGridItem(i, grid) for i in items] + + We can then set the items to be displayed + >>> grid.set_items(grid_items) + + """ + + def __init__(self, parent=None): + super().__init__(parent) + + self.setSizePolicy(QtGui.QSizePolicy.Maximum, + QtGui.QSizePolicy.Maximum) + self.setContentsMargins(10, 10, 10, 10) + + self.__layout = QtGui.QGraphicsGridLayout() + self.__layout.setContentsMargins(0, 0, 0, 0) + self.__layout.setSpacing(10) + self.setLayout(self.__layout) + + def set_items(self, items): + for i, item in enumerate(items): + # Place the items in some arbitrary order - they will be rearranged + # before user sees this ordering + self.__layout.addItem(item, i, 0) + + def setGeometry(self, rect): + super().setGeometry(rect) + self.reflow(self.size().width()) + + def reflow(self, width): + """Recalculate the layout and reposition the elements so they fit. + + Parameters + ---------- + width : int + The maximum width of the grid. + + Returns + ------- + + """ + # When setting the geometry when opened, the layout doesn't yet exist + if self.layout() is None: + return + + grid = self.__layout + + left, right, *_ = self.getContentsMargins() + width -= left + right + + # Get size hints with 32 as the minimum size for each cell + widths = [max(64, h.width()) for h in self._hints(Qt.PreferredSize)] + ncol = self._fit_n_cols(widths, grid.horizontalSpacing(), width) + + # The number of columns is already optimal + if ncol == grid.columnCount(): + return + + # remove all items from the layout, then re-add them back in updated + # positions + items = self._items() + + for item in items: + grid.removeItem(item) + + for i, item in enumerate(items): + grid.addItem(item, i // ncol, i % ncol) + grid.setAlignment(item, Qt.AlignCenter) + + def clear(self): + for item in self._items(): + self.__layout.removeItem(item) + item.setParent(None) + + @staticmethod + def _fit_n_cols(widths, spacing, constraint): + + def sliced(seq, n_col): + """Slice the widths into n lists that contain their respective + widths. E.g. [5, 5, 5], 2 => [[5, 5], [5]]""" + return [seq[i:i + n_col] for i in range(0, len(seq), n_col)] + + def flow_width(widths, spacing, ncol): + w = sliced(widths, ncol) + col_widths = map(max, zip_longest(*w, fillvalue=0)) + return sum(col_widths) + (ncol - 1) * spacing + + ncol_best = 1 + for ncol in range(2, len(widths) + 1): + width = flow_width(widths, spacing, ncol) + if width <= constraint: + ncol_best = ncol + else: + break + + return ncol_best + + def _items(self): + if not self.__layout: + return [] + return [self.__layout.itemAt(i) for i in range(self.__layout.count())] + + def _hints(self, which): + return [item.sizeHint(which) for item in self._items()] diff --git a/Orange/widgets/visualize/widgetutils/common/owlegend.py b/Orange/widgets/visualize/widgetutils/common/owlegend.py new file mode 100644 index 00000000000..1047d7a43b0 --- /dev/null +++ b/Orange/widgets/visualize/widgetutils/common/owlegend.py @@ -0,0 +1,644 @@ +""" +Legend classes to use with `QGraphicsScene` objects. +""" +import numpy as np +from PyQt4 import QtGui, QtCore +from PyQt4.QtCore import Qt + + +class Anchorable(QtGui.QGraphicsWidget): + + __corners = ['topLeft', 'topRight', 'bottomLeft', 'bottomRight'] + TOP_LEFT, TOP_RIGHT, BOTTOM_LEFT, BOTTOM_RIGHT = __corners + + def __init__(self, parent=None, corner='bottomRight', offset=(10, 10)): + super().__init__(parent) + + self.__corner_str = corner if corner in self.__corners else None + # The flag indicates whether or not the item has been drawn on yet. + # This is useful for determining the initial offset, due to the fact + # that dimensions are available in the resize event, which can occur + # multiple times. + self.__has_been_drawn = False + + if isinstance(offset, tuple) or isinstance(offset, list): + assert len(offset) == 2 + self.__offset = QtCore.QPoint(*offset) + elif isinstance(offset, QtCore.QPoint): + self.__offset = offset + + def moveEvent(self, ev): + super().moveEvent(ev) + # This check is needed because simply resizing the window will cause + # the item to move and trigger a `moveEvent` therefore we need to check + # that the movement was done intentionally by the user using the mouse + if QtGui.QApplication.mouseButtons() == Qt.LeftButton: + self.recalculate_offset() + + def resizeEvent(self, ev): + # When the item is first shown, we need to update its position + super().resizeEvent(ev) + if not self.__has_been_drawn: + self.__offset = self.__calculate_actual_offset(self.__offset) + self.update_pos() + self.__has_been_drawn = True + + def showEvent(self, ev): + # When the item is first shown, we need to update its position + super().showEvent(ev) + self.update_pos() + + def recalculate_offset(self): + # This is called whenever the item is being moved and needs to + # recalculate its offset + view = self.__get_view() + # Get the view box and position of legend relative to the view, + # not the scene + pos = view.mapFromScene(self.pos()) + view_box = self.__usable_viewbox() + + self.__corner_str = self.__get_closest_corner() + viewbox_corner = getattr(view_box, self.__corner_str)() + + self.__offset = viewbox_corner - pos + + def update_pos(self): + # This is called whenever something happened with the view that caused + # this item to move from its anchored position, so we have to adjust + # the position to maintain the effect of being anchored + view = self.__get_view() + if self.__corner_str and view is not None: + box = self.__usable_viewbox() + corner = getattr(box, self.__corner_str)() + new_pos = corner - self.__offset + self.setPos(view.mapToScene(new_pos)) + + def __calculate_actual_offset(self, offset): + """Take the offset specified in the constructor and calculate the + actual offset from the top left corner of the item so positioning can + be done correctly.""" + off_x, off_y = offset.x(), offset.y() + w, h = self.boundingRect().width(), self.boundingRect().height() + if self.__corner_str == self.TOP_LEFT: + return QtCore.QPoint(-off_x, -off_y) + elif self.__corner_str == self.TOP_RIGHT: + return QtCore.QPoint(off_x + w, -off_y) + elif self.__corner_str == self.BOTTOM_RIGHT: + return QtCore.QPoint(off_x + w, off_y + h) + elif self.__corner_str == self.BOTTOM_LEFT: + return QtCore.QPoint(-off_x, off_y + h) + + def __get_closest_corner(self): + view = self.__get_view() + # Get the view box and position of legend relative to the view, + # not the scene + pos = view.mapFromScene(self.pos()) + legend_box = QtCore.QRect(pos, self.size().toSize()) + view_box = QtCore.QRect(QtCore.QPoint(0, 0), view.size()) + + def distance(t1, t2): + # 2d euclidean distance + return np.sqrt((t1.x() - t2.x()) ** 2 + (t1.y() - t2.y()) ** 2) + + distances = [ + (distance(getattr(view_box, corner)(), + getattr(legend_box, corner)()), corner) + for corner in self.__corners + ] + _, corner = min(distances) + return corner + + def __get_own_corner(self): + view = self.__get_view() + pos = view.mapFromScene(self.pos()) + legend_box = QtCore.QRect(pos, self.size().toSize()) + return getattr(legend_box, self.__corner_str)() + + def __get_view(self): + if self.scene() is not None: + view, = self.scene().views() + return view + else: + return None + + def __usable_viewbox(self): + view = self.__get_view() + + if view.horizontalScrollBar().isVisible(): + h = view.horizontalScrollBar().size().height() + else: + h = 0 + + if view.verticalScrollBar().isVisible(): + w = view.verticalScrollBar().size().width() + else: + w = 0 + + size = view.size() - QtCore.QSize(w, h) + return QtCore.QRect(QtCore.QPoint(0, 0), size) + + +class AnchorableGraphicsView(QtGui.QGraphicsView): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.horizontalScrollBar().valueChanged.connect( + self.update_anchored_items) + self.verticalScrollBar().valueChanged.connect( + self.update_anchored_items) + + def resizeEvent(self, ev): + super().resizeEvent(ev) + self.update_anchored_items() + + def mousePressEvent(self, ev): + super().mousePressEvent(ev) + self.update_anchored_items() + + def wheelEvent(self, ev): + super().wheelEvent(ev) + self.update_anchored_items() + + def mouseMoveEvent(self, ev): + super().mouseMoveEvent(ev) + self.update_anchored_items() + + def update_anchored_items(self): + for item in self.__anchorable_items(): + item.update_pos() + + def __anchorable_items(self): + return [i for i in self.scene().items() if isinstance(i, Anchorable)] + + +class ColorIndicator(QtGui.QGraphicsWidget): + pass + + +class LegendItemSquare(ColorIndicator): + """Legend square item. + + The legend square item is a small colored square image that can be plugged + into the legend in front of the text object. + + This should only really be used in conjunction with ˙LegendItem˙. + + Parameters + ---------- + color : QtGui.QColor + The color of the square. + parent : QtGui.QGraphicsItem + + See Also + -------- + LegendItemCircle + + """ + + SIZE = QtCore.QSizeF(12, 12) + + def __init__(self, color, parent): + super().__init__(parent) + + height, width = self.SIZE.height(), self.SIZE.width() + self.__square = QtGui.QGraphicsRectItem(0, 0, height, width) + self.__square.setBrush(QtGui.QBrush(color)) + self.__square.setPen(QtGui.QPen(QtGui.QColor(0, 0, 0, 0))) + self.__square.setParentItem(self) + + def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs): + return QtCore.QSizeF(self.__square.boundingRect().size()) + + +class LegendItemCircle(ColorIndicator): + """Legend circle item. + + The legend circle item is a small colored circle image that can be plugged + into the legend in front of the text object. + + This should only really be used in conjunction with ˙LegendItem˙. + + Parameters + ---------- + color : QtGui.QColor + The color of the square. + parent : QtGui.QGraphicsItem + + See Also + -------- + LegendItemSquare + + """ + + SIZE = QtCore.QSizeF(12, 12) + + def __init__(self, color, parent): + super().__init__(parent) + + height, width = self.SIZE.height(), self.SIZE.width() + self.__circle = QtGui.QGraphicsEllipseItem(0, 0, height, width) + self.__circle.setBrush(QtGui.QBrush(color)) + self.__circle.setPen(QtGui.QPen(QtGui.QColor(0, 0, 0, 0))) + self.__circle.setParentItem(self) + + def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs): + return QtCore.QSizeF(self.__circle.boundingRect().size()) + + +class LegendItemTitle(QtGui.QGraphicsWidget): + """Legend item title - the text displayed in the legend. + + This should only really be used in conjunction with ˙LegendItem˙. + + Parameters + ---------- + text : str + parent : QtGui.QGraphicsItem + font : QtGui.QFont + This + + """ + + def __init__(self, text, parent, font): + super().__init__(parent) + + self.__text = QtGui.QGraphicsTextItem(text.title()) + self.__text.setParentItem(self) + self.__text.setFont(font) + + def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs): + return QtCore.QSizeF(self.__text.boundingRect().size()) + + +class LegendItem(QtGui.QGraphicsLinearLayout): + """Legend item - one entry in the legend. + + This represents one entry in the legend i.e. a color indicator and the text + beside it. + + Parameters + ---------- + color : QtGui.QColor + The color that the entry will represent. + title : str + The text that will be displayed for the color. + parent : QtGui.QGraphicsItem + color_indicator_cls : ColorIndicator + The type of `ColorIndicator` that will be used for the color. + font : QtGui.QFont, optional + + """ + + def __init__(self, color, title, parent, color_indicator_cls, font=None): + super().__init__() + + self.__parent = parent + self.__color_indicator = color_indicator_cls(color, parent) + self.__title_label = LegendItemTitle(title, parent, font=font) + + self.addItem(self.__color_indicator) + self.addItem(self.__title_label) + + # Make sure items are aligned properly, since the color box and text + # won't be the same height. + self.setAlignment(self.__color_indicator, Qt.AlignCenter) + self.setAlignment(self.__title_label, Qt.AlignCenter) + self.setContentsMargins(0, 0, 0, 0) + self.setSpacing(5) + + +class LegendGradient(QtGui.QGraphicsWidget): + """Gradient widget. + + A gradient square bar that can be used to display continuous values. + + Parameters + ---------- + palette : iterable[QtGui.QColor] + parent : QtGui.QGraphicsWidget + orientation : Qt.Orientation + + Notes + ----- + .. Note:: While the gradient does support any number of colors, any more + than 3 is not very readable. This should not be a problem, since Orange + only implements 2 or 3 colors. + + """ + + # Default sizes (assume gradient is vertical by default) + GRADIENT_WIDTH = 20 + GRADIENT_HEIGHT = 150 + + def __init__(self, palette, parent, orientation): + super().__init__(parent) + + self.__gradient = QtGui.QLinearGradient() + num_colors = len(palette) + for idx, stop in enumerate(palette): + self.__gradient.setColorAt(idx * (1. / (num_colors - 1)), stop) + + # We need to tell the gradient where it's start and stop points are + self.__gradient.setStart(QtCore.QPointF(0, 0)) + if orientation == Qt.Vertical: + final_stop = QtCore.QPointF(0, self.GRADIENT_HEIGHT) + else: + final_stop = QtCore.QPointF(self.GRADIENT_HEIGHT, 0) + self.__gradient.setFinalStop(final_stop) + + # Get the appropriate rectangle dimensions based on orientation + if orientation == Qt.Vertical: + w, h = self.GRADIENT_WIDTH, self.GRADIENT_HEIGHT + elif orientation == Qt.Horizontal: + w, h = self.GRADIENT_HEIGHT, self.GRADIENT_WIDTH + + self.__rect_item = QtGui.QGraphicsRectItem(0, 0, w, h, self) + self.__rect_item.setPen(QtGui.QPen(QtGui.QColor(0, 0, 0, 0))) + self.__rect_item.setBrush(QtGui.QBrush(self.__gradient)) + + def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs): + return QtCore.QSizeF(self.__rect_item.boundingRect().size()) + + +class ContinuousLegendItem(QtGui.QGraphicsLinearLayout): + """Continuous legend item. + + Contains a gradient bar with the color ranges, as well as two labels - one + on each side of the gradient bar. + + Parameters + ---------- + palette : iterable[QtGui.QColor] + values : iterable[float...] + The number of values must match the number of colors in passed in the + color palette. + parent : QtGui.QGraphicsWidget + font : QtGui.QFont + orientation : Qt.Orientation + + """ + + def __init__(self, palette, values, parent, font=None, + orientation=Qt.Vertical): + if orientation == Qt.Vertical: + super().__init__(Qt.Horizontal) + else: + super().__init__(Qt.Vertical) + + self.__parent = parent + self.__palette = palette + self.__values = values + + self.__gradient = LegendGradient(palette, parent, orientation) + self.__labels_layout = QtGui.QGraphicsLinearLayout(orientation) + + str_vals = self._format_values(values) + + self.__start_label = LegendItemTitle(str_vals[0], parent, font=font) + self.__end_label = LegendItemTitle(str_vals[1], parent, font=font) + self.__labels_layout.addItem(self.__start_label) + self.__labels_layout.addStretch(1) + self.__labels_layout.addItem(self.__end_label) + + # Gradient should be to the left, then labels on the right if vertical + if orientation == Qt.Vertical: + self.addItem(self.__gradient) + self.addItem(self.__labels_layout) + # Gradient should be on the bottom, labels on top if horizontal + elif orientation == Qt.Horizontal: + self.addItem(self.__labels_layout) + self.addItem(self.__gradient) + + @staticmethod + def _format_values(values): + """Get the formatted values to output.""" + return ['{:.3f}'.format(v) for v in values] + + +class Legend(Anchorable): + """Base legend class. + + This class provides common attributes for any legend derivates: + - Behaviour on `QGraphicsScene` + - Appearance of legend + + If you have access to the `domain` property, the `LegendBuilder` class + can be used to automatically build a legend for you. + + Parameters + ---------- + parent : QtGui.QGraphicsItem, optional + orientation : Qt.Orientation, optional + The default orientation is vertical + domain : Orange.data.domain.Domain, optional + This field is left optional as in some cases, we may want to simply + pass in a list that represents the legend. + items : Iterable[QtGui.QColor, str] + bg_color : QtGui.QColor, optional + font : QtGui.QFont, optional + color_indicator_cls : ColorIndicator + The color indicator class that will be used to render the indicators. + + See Also + -------- + OWDiscreteLegend + OWContinuousLegend + OWContinuousLegend + + Notes + ----- + .. Warning:: If the domain parameter is supplied, the items parameter will + be ignored. + + """ + + def __init__(self, parent=None, orientation=Qt.Vertical, domain=None, + items=None, bg_color=QtGui.QColor(232, 232, 232, 196), + font=None, color_indicator_cls=LegendItemSquare, **kwargs): + super().__init__(parent, **kwargs) + + self.orientation = orientation + self.bg_color = QtGui.QBrush(bg_color) + self.color_indicator_cls = color_indicator_cls + + # Set default font if none is given + if font is None: + self.font = QtGui.QFont() + self.font.setPointSize(10) + else: + self.font = font + + self.setFlags(QtGui.QGraphicsWidget.ItemIsMovable | + QtGui.QGraphicsItem.ItemIgnoresTransformations) + + if domain is not None: + self.set_domain(domain) + elif items is not None: + self.set_items(items) + + def _clear_layout(self): + self._layout = None + for child in self.children(): + child.setParent(None) + + def _setup_layout(self): + self._clear_layout() + + self._layout = QtGui.QGraphicsLinearLayout(self.orientation) + self._layout.setContentsMargins(10, 5, 10, 5) + # If horizontal, there needs to be horizontal space between the items + if self.orientation == Qt.Horizontal: + self._layout.setSpacing(10) + # If vertical spacing, vertical space is provided by child layouts + else: + self._layout.setSpacing(0) + self.setLayout(self._layout) + + def set_domain(self, domain): + """Handle receiving the domain object. + + Parameters + ---------- + domain : Orange.data.domain.Domain + + Returns + ------- + + Raises + ------ + AttributeError + If the domain does not contain the correct type of class variable. + + """ + raise NotImplemented() + + def set_items(self, values): + """Handle receiving an array of items. + + Parameters + ---------- + values : iterable[object, QtGui.QColor] + + Returns + ------- + + """ + raise NotImplemented() + + @staticmethod + def _convert_to_color(obj): + if isinstance(obj, QtGui.QColor): + return obj + elif isinstance(obj, tuple) or isinstance(obj, list): + assert len(obj) in (3, 4) + return QtGui.QColor(*obj) + else: + return QtGui.QColor(obj) + + def paint(self, painter, options, widget=None): + painter.save() + pen = QtGui.QPen(QtGui.QColor(196, 197, 193, 200), 1) + brush = QtGui.QBrush(QtGui.QColor(self.bg_color)) + + painter.setPen(pen) + painter.setBrush(brush) + painter.drawRect(self.contentsRect()) + painter.restore() + + +class OWDiscreteLegend(Legend): + """Discrete legend. + + See Also + -------- + Legend + OWContinuousLegend + + """ + + def set_domain(self, domain): + class_var = domain.class_var + + if not class_var.is_discrete: + raise AttributeError('[OWDiscreteLegend] The class var provided ' + 'was not discrete.') + + self.set_items(zip(class_var.values, class_var.colors.tolist())) + + def set_items(self, values): + self._setup_layout() + for class_name, color in values: + legend_item = LegendItem( + color=self._convert_to_color(color), + title=class_name, + parent=self, + color_indicator_cls=self.color_indicator_cls, + font=self.font + ) + self._layout.addItem(legend_item) + + +class OWContinuousLegend(Legend): + """Continuous legend. + + See Also + -------- + Legend + OWDiscreteLegend + + """ + + def __init__(self, *args, **kwargs): + # Variables used in the `set_` methods must be set before calling super + self.__range = kwargs.get('range', ()) + + super().__init__(*args, **kwargs) + + self._layout.setContentsMargins(10, 10, 10, 10) + + def set_domain(self, domain): + class_var = domain.class_var + + if not class_var.is_continuous: + raise AttributeError('[OWContinuousLegend] The class var provided ' + 'was not continuous.') + + # The first and last values must represent the range, the rest should + # be dummy variables, as they are not shown anywhere + values = self.__range + + start, end, pass_through_black = class_var.colors + # If pass through black, push black in between and add index to vals + if pass_through_black: + colors = [self._convert_to_color(c) for c + in [start, '#000000', end]] + values.insert(1, -1) + else: + colors = [self._convert_to_color(c) for c in [start, end]] + + self.set_items(list(zip(values, colors))) + + def set_items(self, values): + vals, colors = list(zip(*values)) + + # If the orientation is vertical, it makes more sense for the smaller + # value to be shown on the bottom + if self.orientation == Qt.Vertical and vals[0] < vals[len(vals) - 1]: + colors, vals = list(reversed(colors)), list(reversed(vals)) + + self._setup_layout() + self._layout.addItem(ContinuousLegendItem( + palette=colors, + values=vals, + parent=self, + font=self.font, + orientation=self.orientation + )) + + +class OWBinnedContinuousLegend(Legend): + def set_domain(self, domain): + pass + + def set_items(self, values): + pass diff --git a/Orange/widgets/visualize/widgetutils/common/scene.py b/Orange/widgets/visualize/widgetutils/common/scene.py new file mode 100644 index 00000000000..ca10a1993fe --- /dev/null +++ b/Orange/widgets/visualize/widgetutils/common/scene.py @@ -0,0 +1,25 @@ +from PyQt4 import QtGui + + +class UpdateItemsOnSelectGraphicsScene(QtGui.QGraphicsScene): + """Calls the selection_changed method on items. + + Whenever the scene selection changes, this view will call the + ˙selection_changed˙ method on any item on the scene. + + Notes + ----- + ..Note:: I suspect this is completely unncessary, but have not been able to + find a reasonable way to keep the selection logic inside the actual + `QGraphicsItem` objects + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.selectionChanged.connect(self.__handle_selection) + + def __handle_selection(self): + for item in self.items(): + if hasattr(item, 'selection_changed'): + item.selection_changed() diff --git a/Orange/widgets/visualize/widgetutils/common/view.py b/Orange/widgets/visualize/widgetutils/common/view.py new file mode 100644 index 00000000000..e080828f757 --- /dev/null +++ b/Orange/widgets/visualize/widgetutils/common/view.py @@ -0,0 +1,180 @@ +from itertools import repeat + +import numpy as np +from PyQt4 import QtGui +from PyQt4.QtCore import Qt + + +class ZoomableGraphicsView(QtGui.QGraphicsView): + """Zoomable graphics view. + + Composable graphics view that adds zoom functionality. + + It also handles automatic resizing of content whenever the window is + resized. + + Right click will reset the zoom to a factor where the entire scene is + visible. + + Parameters + ---------- + scene : QtGui.QGraphicsScene + padding : int or tuple, optional + Specify the padding around the drawn widgets. Can be an int, or tuple, + the tuple can contain either 2 or 4 elements. + + Notes + ----- + - This view will consume wheel scrolling and right mouse click events. + + """ + + def __init__(self, scene, padding=(0, 0), **kwargs): + self.zoom = 1 + self.scale_factor = 1 / 16 + # zoomout limit prevents the zoom factor to become negative, which + # results in the canvas being flipped over the x axis + self.__zoomout_limit_reached = False + # Does the view need to recalculate the initial scale factor + self.__needs_to_recalculate_initial = True + self.__initial_zoom = -1 + + self.__central_widget = None + self.__set_padding(padding) + + super().__init__(scene, **kwargs) + + def resizeEvent(self, ev): + super().resizeEvent(ev) + self.__needs_to_recalculate_initial = True + + def wheelEvent(self, ev): + self.__handle_zoom(ev.delta()) + super().wheelEvent(ev) + + def mousePressEvent(self, ev): + # right click resets the zoom factor + if ev.button() == Qt.RightButton: + self.reset_zoom() + super().mousePressEvent(ev) + + def keyPressEvent(self, ev): + if ev.key() == Qt.Key_Plus: + self.__handle_zoom(1) + elif ev.key() == Qt.Key_Minus: + self.__handle_zoom(-1) + + super().keyPressEvent(ev) + + def __set_padding(self, padding): + # Allow for multiple formats of padding for convenience + if isinstance(padding, int): + padding = list(repeat(padding, 4)) + elif isinstance(padding, list) or isinstance(padding, tuple): + if len(padding) == 2: + padding = (*padding, *padding) + else: + padding = 0, 0, 0, 0 + + l, t, r, b = padding + self.__padding = -l, -t, r, b + + def __handle_zoom(self, direction): + """Handle zoom event, direction is positive if zooming in, otherwise + negative.""" + if self.__zooming_in(direction): + self.__reset_zoomout_limit() + if self.__zoomout_limit_reached and self.__zooming_out(direction): + return + + self.zoom += np.sign(direction) * self.scale_factor + if self.zoom <= 0: + self.__zoomout_limit_reached = True + self.zoom += self.scale_factor + else: + self.setTransformationAnchor(self.AnchorUnderMouse) + self.setTransform(QtGui.QTransform().scale(self.zoom, self.zoom)) + + @staticmethod + def __zooming_out(direction): + return direction < 0 + + def __zooming_in(self, ev): + return not self.__zooming_out(ev) + + def __reset_zoomout_limit(self): + self.__zoomout_limit_reached = False + + def set_central_widget(self, widget): + self.__central_widget = widget + + def central_widget_rect(self): + """Get the bounding box of the central widget. + + If a central widget and padding are set, this method calculates the + rect containing both of them. This is useful because if the padding was + added directly onto the widget, the padding would be rescaled as well. + + If the central widget is not set, return the scene rect instead. + + Returns + ------- + QtCore.QRectF + + """ + if self.__central_widget is None: + return self.scene().itemsBoundingRect().adjusted(*self.__padding) + return self.__central_widget.boundingRect().adjusted(*self.__padding) + + def recalculate_and_fit(self): + """Recalculate the optimal zoom and fits the content into view. + + Should be called if the scene contents change, so that the optimal zoom + can be recalculated. + + Returns + ------- + + """ + if self.__central_widget is not None: + self.fitInView(self.central_widget_rect(), Qt.KeepAspectRatio) + else: + self.fitInView(self.scene().sceneRect(), Qt.KeepAspectRatio) + + self.__initial_zoom = self.matrix().m11() + self.zoom = self.__initial_zoom + + def reset_zoom(self): + """Reset the zoom to the optimal factor.""" + self.zoom = self.__initial_zoom + self.__zoomout_limit_reached = False + + if self.__needs_to_recalculate_initial: + self.recalculate_and_fit() + else: + self.setTransform(QtGui.QTransform().scale(self.zoom, self.zoom)) + + +class PannableGraphicsView(QtGui.QGraphicsView): + """Pannable graphics view. + + Enables panning the graphics view. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setDragMode(QtGui.QGraphicsView.ScrollHandDrag) + + def enterEvent(self, ev): + self.viewport().setCursor(Qt.ArrowCursor) + super().enterEvent(ev) + + def mouseReleaseEvent(self, ev): + super().mouseReleaseEvent(ev) + self.viewport().setCursor(Qt.ArrowCursor) + + +class PreventDefaultWheelEvent(QtGui.QGraphicsView): + def wheelEvent(self, ev): + ev.accept() diff --git a/Orange/widgets/visualize/widgetutils/tree/__init__.py b/Orange/widgets/visualize/widgetutils/tree/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py new file mode 100644 index 00000000000..c113e096f55 --- /dev/null +++ b/Orange/widgets/visualize/widgetutils/tree/rules.py @@ -0,0 +1,181 @@ +class Rule: + """The base Rule class for tree rules.""" + + def merge_with(self, rule): + """Merge the current rule with the given rule. + + Parameters + ---------- + rule : Rule + + Returns + ------- + Rule + + """ + raise NotImplemented() + + +class DiscreteRule(Rule): + """Discrete rule class for handling Indicator rules. + + Parameters + ---------- + attr_name : str + eq : bool + Should indicate whether or not the rule equals the value or not. + value : object + + Examples + -------- + Age = 30 + >>> rule = DiscreteRule('age', True, 30) + Name ≠ John + >>> rule = DiscreteRule('name', False, 'John') + + Notes + ----- + - Merging discrete rules is currently not implemented, the new rule is + simply returned and a warning is printed to stderr. + + """ + + def __init__(self, attr_name, eq, value): + self.attr_name = attr_name + self.sign = eq + self.value = value + + def merge_with(self, rule): + # It does not make sense to merge discrete rules, since they can only + # be eq or not eq. + from sys import stderr + print('WARNING: Merged two discrete rules `%s` and `%s`' + % (self, rule), file=stderr) + return rule + + def __str__(self): + return '{} {} {}'.format( + self.attr_name, '=' if self.sign else '≠', self.value) + + +class ContinuousRule(Rule): + """Continuous rule class for handling numeric rules. + + Parameters + ---------- + attr_name : str + gt : bool + Should indicate whether the variable must be greater than the value. + value : int + inclusive : bool, optional + Should the variable range include the value or not + (LT <> LTE | GT <> GTE). Default is False. + + Examples + -------- + x ≤ 30 + >>> rule = ContinuousRule('age', False, 30, inclusive=True) + x > 30 + >>> rule = ContinuousRule('age', True, 30) + + Notes + ----- + - Continuous rules can currently only be merged with other continuous + rules. + + """ + + def __init__(self, attr_name, gt, value, inclusive=False): + self.attr_name = attr_name + self.sign = gt + self.value = value + self.inclusive = inclusive + + def merge_with(self, rule): + if not isinstance(rule, ContinuousRule): + raise NotImplemented('Continuous rules can currently only be ' + 'merged with other continuous rules') + # Handle when both have same sign + if self.sign == rule.sign: + # When both are GT + if self.sign is True: + larger = self.value if self.value > rule.value else rule.value + return ContinuousRule(self.attr_name, self.sign, larger) + # When both are LT + else: + smaller = self.value if self.value < rule.value else rule.value + return ContinuousRule(self.attr_name, self.sign, smaller) + # When they have different signs we need to return an interval rule + else: + lt_rule = self if self.sign is False else rule + gt_rule = self if lt_rule != self else rule + return IntervalRule(self.attr_name, gt_rule, lt_rule) + + def __str__(self): + return '%s %s %.3f' % ( + self.attr_name, '>' if self.sign else '≤', self.value) + + +class IntervalRule(Rule): + """Interval rule class for ranges of continuous values. + + Parameters + ---------- + attr_name : str + left_rule : ContinuousRule + The smaller (left) part of the interval. + right_rule : ContinuousRule + The larger (right) part of the interval. + + Examples + -------- + 1 ≤ x < 3 + >>> rule = IntervalRule('Rule', + >>> ContinuousRule('Rule', True, 1, inclusive=True), + >>> ContinuousRule('Rule', False, 3)) + + Notes + ----- + - Currently, only cases which appear in classification and regression + trees are implemented. An interval can not be made up of two parts + (e.g. (-∞, -1) ∪ (1, ∞)). + + """ + + def __init__(self, attr_name, left_rule, right_rule): + if not isinstance(left_rule, ContinuousRule): + raise AttributeError( + 'The left rule must be an instance of the `ContinuousRule` ' + 'class.') + if not isinstance(right_rule, ContinuousRule): + raise AttributeError( + 'The right rule must be an instance of the `ContinuousRule` ' + 'class.') + + self.attr_name = attr_name + self.left_rule = left_rule + self.right_rule = right_rule + + def merge_with(self, rule): + if isinstance(rule, ContinuousRule): + if rule.sign: + return IntervalRule( + self.attr_name, self.left_rule.merge_with(rule), + self.right_rule) + else: + return IntervalRule( + self.attr_name, self.left_rule, + self.right_rule.merge_with(rule)) + + elif isinstance(rule, IntervalRule): + return IntervalRule( + self.attr_name, + self.left_rule.merge_with(rule.left_rule), + self.right_rule.merge_with(rule.right_rule)) + + def __str__(self): + return '{} ∈ {}{:.3}, {:.3}{}'.format( + self.attr_name, + '[' if self.left_rule.inclusive else '(', self.left_rule.value, + self.right_rule.value, ']' if self.right_rule.inclusive else ')' + ) diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py new file mode 100644 index 00000000000..61886eab7ea --- /dev/null +++ b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py @@ -0,0 +1,302 @@ +from collections import OrderedDict +from functools import lru_cache + +import numpy as np +from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter + +from Orange.preprocess.transformation import Indicator +from Orange.widgets.visualize.widgetutils.tree.rules import ( + DiscreteRule, + ContinuousRule +) + + +class SklTreeAdapter(TreeAdapter): + """Sklear Tree Adapter. + + An abstraction on top of the scikit learn classification tree. + + Parameters + ---------- + tree : sklearn.tree._tree.Tree + The raw sklearn classification tree. + domain : Orange.base.domain + The Orange domain that comes with the model. + adjust_weight : Callable, optional + If you want finer control over the weights of individual nodes you can + pass in a function that takes the existsing weight and modifies it. + The given function must have signture :: Number -> Number + + """ + + ROOT_PARENT = -1 + NO_CHILD = -1 + FEATURE_UNDEFINED = -2 + + def __init__(self, tree, domain, adjust_weight=lambda x: x): + self._tree = tree + self._domain = domain + self._adjust_weight = adjust_weight + + # clear memoized functions + self.weight.cache_clear() + self._adjusted_child_weight.cache_clear() + self.parent.cache_clear() + + self._all_leaves = None + + @lru_cache(maxsize=1024) + def weight(self, node): + return self._adjust_weight(self.num_samples(node)) / \ + self._adjusted_child_weight(self.parent(node)) + + @lru_cache(maxsize=1024) + def _adjusted_child_weight(self, node): + """Helps when dealing with adjusted weights. + + It is needed when dealing with non linear weights e.g. when calculating + the log weight, the sum of logs of all the children will not be equal + to the log of all the data instances. + A simple example: log(2) + log(2) != log(4) + + Parameters + ---------- + node : int + The label of the node. + + Returns + ------- + float + The sum of all of the weights of the children of a given node. + + """ + return sum(self._adjust_weight(self.num_samples(c)) + for c in self.children(node)) \ + if self.has_children(node) else 0 + + def num_samples(self, node): + return self._tree.n_node_samples[node] + + @lru_cache(maxsize=1024) + def parent(self, node): + for children in (self._tree.children_left, self._tree.children_right): + try: + return (children == node).nonzero()[0][0] + except IndexError: + continue + return self.ROOT_PARENT + + def has_children(self, node): + return self._tree.children_left[node] != self.NO_CHILD \ + or self._tree.children_right[node] != self.NO_CHILD + + def children(self, node): + if self.has_children(node): + return self.__left_child(node), self.__right_child(node) + return () + + def __left_child(self, node): + return self._tree.children_left[node] + + def __right_child(self, node): + return self._tree.children_right[node] + + def get_distribution(self, node): + return self._tree.value[node] + + def get_impurity(self, node): + return self._tree.impurity[node] + + @property + def max_depth(self): + return self._tree.max_depth + + @property + def num_nodes(self): + return self._tree.node_count + + @property + def root(self): + return 0 + + @property + def domain(self): + return self._domain + + @lru_cache(maxsize=1024) + def rules(self, node): + if node != self.root: + parent = self.parent(node) + # Convert the parent list of rules into an ordered dict + pr = OrderedDict([(r.attr_name, r) for r in self.rules(parent)]) + + parent_attr = self.attribute(parent) + # Get the parent attribute type + parent_attr_cv = parent_attr.compute_value + + is_left_child = self.__left_child(parent) == node + + # The parent split variable is discrete + if isinstance(parent_attr_cv, Indicator) and \ + hasattr(parent_attr_cv.variable, 'values'): + values = parent_attr_cv.variable.values + attr_name = parent_attr_cv.variable.name + eq = not is_left_child * (len(values) != 2) + value = values[abs(parent_attr_cv.value - + is_left_child * (len(values) == 2))] + new_rule = DiscreteRule(attr_name, eq, value) + # Since discrete variables should appear in their own lines + # they must not be merged, so the dict key is set with the + # value, so the same keys can exist with different values + # e.g. #legs ≠ 2 and #legs ≠ 4 + attr_name = attr_name + '_' + value + # The parent split variable is continuous + else: + attr_name = parent_attr.name + sign = not is_left_child + value = self._tree.threshold[self.parent(node)] + new_rule = ContinuousRule(attr_name, sign, value, + inclusive=is_left_child) + + # Check if a rule with that attribute exists + if attr_name in pr: + pr[attr_name] = pr[attr_name].merge_with(new_rule) + pr.move_to_end(attr_name) + else: + pr[attr_name] = new_rule + + return list(pr.values()) + else: + return [] + + def attribute(self, node): + feature_idx = self.splitting_attribute(node) + if feature_idx != self.FEATURE_UNDEFINED: + return self.domain.attributes[self.splitting_attribute(node)] + + def splitting_attribute(self, node): + return self._tree.feature[node] + + @lru_cache(maxsize=1024) + def leaves(self, node): + start, stop = self._subnode_range(node) + if start == stop: + # leaf + return np.array([node], dtype=int) + else: + is_leaf = self._tree.children_left[start:stop] == self.NO_CHILD + assert np.flatnonzero(is_leaf).size > 0 + return start + np.flatnonzero(is_leaf) + + def _subnode_range(self, node): + """ + Get the range of indices where there are subnodes of the given node. + + See Also + -------- + Orange.widgets.classify.owclassificationtreegraph.OWTreeGraph + """ + + def find_largest_idx(n): + """It is necessary to locate the node with the largest index in the + children in order to get a good range. This is necessary with trees + that are not right aligned, which can happen when visualising + random forest trees.""" + if self._tree.children_left[n] == self.NO_CHILD: + return n + + l_node = find_largest_idx(self._tree.children_left[n]) + r_node = find_largest_idx(self._tree.children_right[n]) + + return l_node if l_node > r_node else r_node + + right = left = node + if self._tree.children_left[left] == self.NO_CHILD: + assert self._tree.children_right[node] == self.NO_CHILD + return node, node + else: + left = self._tree.children_left[left] + right = find_largest_idx(right) + + return left, right + 1 + + def get_samples_in_leaves(self, data): + """Get an array of instance indices that belong to each leaf. + + For a given dataset X, separate the instances out into an array, so + they are grouped together based on what leaf they belong to. + + Examples + -------- + Given a tree with two leaf nodes ( A <- R -> B ) and the dataset X = + [ 10, 20, 30, 40, 50, 60 ], where 10, 20 and 40 belong to leaf A, and + the rest to leaf B, the following structure will be returned (where + array is the numpy array): + [array([ 0, 1, 3 ]), array([ 2, 4, 5 ])] + + The first array represents the indices of the values that belong to the + first leaft, so calling X[ 0, 1, 3 ] = [ 10, 20, 40 ] + + Parameters + ---------- + data + A matrix containing the data instances. + + Returns + ------- + np.array + The indices of instances belonging to a given leaf. + + """ + + def assign(node_id, indices): + if self._tree.children_left[node_id] == self.NO_CHILD: + return [indices] + else: + feature_idx = self._tree.feature[node_id] + thresh = self._tree.threshold[node_id] + + column = data[indices, feature_idx] + leftmask = column <= thresh + leftind = assign(self._tree.children_left[node_id], + indices[leftmask]) + rightind = assign(self._tree.children_right[node_id], + indices[~leftmask]) + return list.__iadd__(leftind, rightind) + + # TODO this kind of cache can lead to all sorts of problems, but numpy + # arrays are unhashable, and this gives huge performance boosts + # also this would only become a problem if the function required to + # handle multiple datasets, which it doesn't, it just deals with the + # one the classification tree was fit to. + if self._all_leaves is not None: + return self._all_leaves + + n, _ = data.shape + + items = np.arange(n, dtype=int) + leaf_indices = assign(0, items) + self._all_leaves = leaf_indices + return leaf_indices + + def get_instances_in_nodes(self, dataset, nodes): + if not isinstance(nodes, (list, tuple)): + nodes = [nodes] + + node_leaves = [self.leaves(n.label) for n in nodes] + if len(node_leaves) > 0: + # get the leaves of the selected tree node + node_leaves = np.unique(np.hstack(node_leaves)) + + all_leaves = self.leaves(self.root) + + indices = np.searchsorted(all_leaves, node_leaves, side='left') + # all the leaf samples for each leaf + leaf_samples = self.get_samples_in_leaves(dataset.X) + # filter out the leaf samples array that are not selected + leaf_samples = [leaf_samples[i] for i in indices] + indices = np.hstack(leaf_samples) + else: + indices = [] + + return dataset[indices] if len(indices) else None diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/__init__.py b/Orange/widgets/visualize/widgetutils/tree/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py new file mode 100644 index 00000000000..a6a4a86e0e4 --- /dev/null +++ b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py @@ -0,0 +1,157 @@ +import unittest + +from Orange.widgets.visualize.widgetutils.tree.rules import ( + ContinuousRule, + IntervalRule, +) + + +class TestRules(unittest.TestCase): + # CONTINUOUS RULES + def test_merging_two_gt_continuous_rules(self): + """Merging `x > 1` and `x > 2` should produce `x > 2`.""" + rule1 = ContinuousRule('Rule', True, 1) + rule2 = ContinuousRule('Rule', True, 2) + new_rule = rule1.merge_with(rule2) + self.assertEqual(new_rule.value, 2) + + def test_merging_gt_with_gte_continuous_rule(self): + """Merging `x > 1` and `x ≥ 1` should produce `x > 1`.""" + rule1 = ContinuousRule('Rule', True, 1, inclusive=True) + rule2 = ContinuousRule('Rule', True, 1, inclusive=False) + new_rule = rule1.merge_with(rule2) + self.assertEqual(new_rule.inclusive, False) + + def test_merging_two_lt_continuous_rules(self): + """Merging `x < 1` and `x < 2` should produce `x < 1`.""" + rule1 = ContinuousRule('Rule', False, 1) + rule2 = ContinuousRule('Rule', False, 2) + new_rule = rule1.merge_with(rule2) + self.assertEqual(new_rule.value, 1) + + def test_merging_lt_with_lte_rule(self): + """Merging `x < 1` and `x ≤ 1` should produce `x < 1`.""" + rule1 = ContinuousRule('Rule', False, 1, inclusive=True) + rule2 = ContinuousRule('Rule', False, 1, inclusive=False) + new_rule = rule1.merge_with(rule2) + self.assertEqual(new_rule.inclusive, False) + + def test_merging_lt_with_gt_continuous_rules(self): + """Merging `x > 1` and `x < 2` should produce `1 < x < 2`.""" + rule1 = ContinuousRule('Rule', True, 1) + rule2 = ContinuousRule('Rule', False, 2) + new_rule = rule1.merge_with(rule2) + self.assertIsInstance(new_rule, IntervalRule) + self.assertEquals(new_rule.left_rule, rule1) + self.assertEquals(new_rule.right_rule, rule2) + + # INTERVAL RULES + def test_merging_interval_rule_with_smaller_continuous_rule(self): + """Merging `1 < x < 2` and `x < 3` should produce `1 < x < 2`.""" + rule1 = IntervalRule('Rule', + ContinuousRule('Rule', True, 1), + ContinuousRule('Rule', False, 2)) + rule2 = ContinuousRule('Rule', False, 2) + new_rule = rule1.merge_with(rule2) + self.assertIsInstance(new_rule, IntervalRule) + self.assertEquals(new_rule.right_rule.value, 2) + + def test_merging_interval_rule_with_larger_continuous_rule(self): + """Merging `1 < x < 2` and `x < 3` should produce `1 < x < 2`.""" + rule1 = IntervalRule('Rule', + ContinuousRule('Rule', True, 1), + ContinuousRule('Rule', False, 2)) + rule2 = ContinuousRule('Rule', False, 3) + new_rule = rule1.merge_with(rule2) + self.assertIsInstance(new_rule, IntervalRule) + self.assertEquals(new_rule.left_rule.value, 1) + + def test_merging_interval_rule_with_larger_lt_continuous_rule(self): + """Merging `0 < x < 3` and `x > 1` should produce `1 < x < 3`.""" + rule1 = IntervalRule('Rule', + ContinuousRule('Rule', True, 0), + ContinuousRule('Rule', False, 3)) + rule2 = ContinuousRule('Rule', True, 1) + new_rule = rule1.merge_with(rule2) + self.assertIsInstance(new_rule, IntervalRule) + self.assertEquals(new_rule.left_rule.value, 1) + + def test_merging_interval_rule_with_smaller_gt_continuous_rule(self): + """Merging `0 < x < 3` and `x < 2` should produce `0 < x < 2`.""" + rule1 = IntervalRule('Rule', + ContinuousRule('Rule', True, 0), + ContinuousRule('Rule', False, 3)) + rule2 = ContinuousRule('Rule', False, 2) + new_rule = rule1.merge_with(rule2) + self.assertIsInstance(new_rule, IntervalRule) + self.assertEquals(new_rule.right_rule.value, 2) + + def test_merging_interval_rules_with_smaller_lt_component(self): + """Merging `1 < x < 2` and `0 < x < 2` should produce `1 < x < 2`.""" + rule1 = IntervalRule('Rule', + ContinuousRule('Rule', True, 1), + ContinuousRule('Rule', False, 2)) + rule2 = IntervalRule('Rule', + ContinuousRule('Rule', True, 0), + ContinuousRule('Rule', False, 2)) + new_rule = rule1.merge_with(rule2) + self.assertEquals(new_rule.left_rule.value, 1) + self.assertEquals(new_rule.right_rule.value, 2) + + def test_merging_interval_rules_with_larger_lt_component(self): + """Merging `0 < x < 4` and `1 < x < 4` should produce `1 < x < 4`.""" + rule1 = IntervalRule('Rule', + ContinuousRule('Rule', True, 0), + ContinuousRule('Rule', False, 4)) + rule2 = IntervalRule('Rule', + ContinuousRule('Rule', True, 1), + ContinuousRule('Rule', False, 4)) + new_rule = rule1.merge_with(rule2) + self.assertEquals(new_rule.left_rule.value, 1) + self.assertEquals(new_rule.right_rule.value, 4) + + def test_merging_interval_rules_generally(self): + """Merging `0 < x < 4` and `2 < x < 6` should produce `2 < x < 4`.""" + rule1 = IntervalRule('Rule', + ContinuousRule('Rule', True, 0), + ContinuousRule('Rule', False, 4)) + rule2 = IntervalRule('Rule', + ContinuousRule('Rule', True, 2), + ContinuousRule('Rule', False, 6)) + new_rule = rule1.merge_with(rule2) + self.assertEquals(new_rule.left_rule.value, 2) + self.assertEquals(new_rule.right_rule.value, 4) + + # ALL RULES + def test_merge_commutativity_on_continuous_rules(self): + rule1 = ContinuousRule('Rule1', True, 1) + rule2 = ContinuousRule('Rule1', True, 2) + new_rule1 = rule1.merge_with(rule2) + new_rule2 = rule2.merge_with(rule1) + self.assertEqual(new_rule1.value, new_rule2.value) + + def test_merge_commutativity_on_interval_rules(self): + rule1 = IntervalRule('Rule', + ContinuousRule('Rule', True, 0), + ContinuousRule('Rule', False, 4)) + rule2 = IntervalRule('Rule', + ContinuousRule('Rule', True, 2), + ContinuousRule('Rule', False, 6)) + new_rule1 = rule1.merge_with(rule2) + new_rule2 = rule2.merge_with(rule1) + self.assertEquals(new_rule1.left_rule.value, + new_rule2.left_rule.value) + self.assertEquals(new_rule1.right_rule.value, + new_rule2.right_rule.value) + + def test_merge_keeps_sign_on_continuous_rules(self): + rule1 = ContinuousRule('Rule1', True, 1) + rule2 = ContinuousRule('Rule1', True, 2) + new_rule = rule1.merge_with(rule2) + self.assertEquals(new_rule.sign, True) + + def test_merge_keeps_attr_name_on_continuous_rules(self): + rule1 = ContinuousRule('Rule1', True, 1) + rule2 = ContinuousRule('Rule1', True, 2) + new_rule = rule1.merge_with(rule2) + self.assertEquals(new_rule.attr_name, 'Rule1') diff --git a/Orange/widgets/visualize/widgetutils/tree/treeadapter.py b/Orange/widgets/visualize/widgetutils/tree/treeadapter.py new file mode 100644 index 00000000000..7bf07368c5f --- /dev/null +++ b/Orange/widgets/visualize/widgetutils/tree/treeadapter.py @@ -0,0 +1,274 @@ +from abc import ABCMeta, abstractmethod + + +class TreeAdapter(metaclass=ABCMeta): + """Base class for tree representation. + + Any subclass should implement the methods listed in this base class. Note + that some simple methods do not need to reimplemented e.g. is_leaf since + it that is the opposite of has_children. + + """ + + ROOT_PARENT = -1 + NO_CHILD = -1 + FEATURE_UNDEFINED = -2 + + @abstractmethod + def weight(self, node): + """Get the weight of the given node. + + The weights of the children always sum up to 1. + + Parameters + ---------- + node : object + The label of the node. + + Returns + ------- + float + The weight of the node relative to its siblings. + + """ + pass + + @abstractmethod + def num_samples(self, node): + """Get the number of samples that a given node contains. + + Parameters + ---------- + node : object + A unique identifier of a node. + + Returns + ------- + int + + """ + pass + + @abstractmethod + def parent(self, node): + """Get the parent of a given node. Return -1 if the node is the root. + + Parameters + ---------- + node : object + + Returns + ------- + object + + """ + pass + + @abstractmethod + def has_children(self, node): + """Check if the given node has any children. + + Parameters + ---------- + node : object + + Returns + ------- + bool + + """ + pass + + def is_leaf(self, node): + """Check if the given node is a leaf node. + + Parameters + ---------- + node : object + + Returns + ------- + object + + """ + return not self.has_children(node) + + @abstractmethod + def children(self, node): + """Get all the children of a given node. + + Parameters + ---------- + node : object + + Returns + ------- + Iterable[object] + A iterable object containing the labels of the child nodes. + + """ + pass + + @abstractmethod + def get_distribution(self, node): + """Get the distribution of types for a given node. + + This may be the number of nodes that belong to each different classe in + a node. + + Parameters + ---------- + node : object + + Returns + ------- + Iterable[int, ...] + The return type is an iterable with as many fields as there are + different classes in the given node. The values of the fields are + the number of nodes that belong to a given class inside the node. + + """ + pass + + @abstractmethod + def get_impurity(self, node): + """Get the impurity of a given node. + + Parameters + ---------- + node : object + + Returns + ------- + object + + """ + pass + + @abstractmethod + def rules(self, node): + """Get a list of rules that define the given node. + + Parameters + ---------- + node : object + + Returns + ------- + Iterable[Rule] + A list of Rule objects, can be of any type. + + """ + pass + + @abstractmethod + def attribute(self, node): + """Get the attribute that splits the given tree. + + Parameters + ---------- + node + + Returns + ------- + + """ + pass + + def is_root(self, node): + """Check if a given node is the root node. + + Parameters + ---------- + node + + Returns + ------- + + """ + return node == self.root + + @abstractmethod + def leaves(self, node): + """Get all the leavse that belong to the subtree of a given node. + + Parameters + ---------- + node + + Returns + ------- + + """ + pass + + @abstractmethod + def get_instances_in_nodes(self, dataset, nodes): + """Get all the instances belonging to a set of nodes for a given + dataset. + + Parameters + ---------- + dataset : Table + A Orange Table dataset. + nodes : iterable[TreeNode] + A list of tree nodes for which we want the instances. + + Returns + ------- + + """ + pass + + @property + @abstractmethod + def max_depth(self): + """Get the maximum depth that the tree reaches. + + Returns + ------- + int + + """ + pass + + @property + @abstractmethod + def num_nodes(self): + """Get the total number of nodes that the tree contains. + + This does not mean the number of samples inside the entire tree, just + the number of nodes. + + Returns + ------- + int + + """ + pass + + @property + @abstractmethod + def root(self): + """Get the label of the root node. + + Returns + ------- + object + + """ + pass + + @property + @abstractmethod + def domain(self): + """Get the domain of the given tree. + + The domain contains information about the classes what the tree + represents. + + Returns + ------- + + """ + pass From df43036337d7a8ee6265bc8fb523f9f684aa5124 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Tue, 12 Jul 2016 10:31:30 +0200 Subject: [PATCH 02/12] Pythagorean tree: Add memoize_method to avoid memory leaks --- Orange/misc/cache.py | 41 +++++++++++++++++++ .../widgetutils/tree/skltreeadapter.py | 18 ++++---- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/Orange/misc/cache.py b/Orange/misc/cache.py index bf5f295b7d9..2dc9f39751d 100644 --- a/Orange/misc/cache.py +++ b/Orange/misc/cache.py @@ -1,3 +1,7 @@ +from functools import wraps, lru_cache +import weakref + + def single_cache(f): last_args = () last_kwargs = set() @@ -14,3 +18,40 @@ def cached(*args, **kwargs): return last_result return cached + + +def memoize_method(*lru_args, **lru_kwargs): + """Memoize methods without keeping reference to `self`. + + Parameters + ---------- + lru_args + lru_kwargs + + Returns + ------- + + See Also + -------- + https://stackoverflow.com/questions/33672412/python-functools-lru-cache-with-class-methods-release-object + + """ + def _decorator(func): + + @wraps(func) + def _wrapped_func(self, *args, **kwargs): + self_weak = weakref.ref(self) + # We're storing the wrapped method inside the instance. If we had + # a strong reference to self the instance would never die. + + @wraps(func) + @lru_cache(*lru_args, **lru_kwargs) + def _cached_method(*args, **kwargs): + return func(self_weak(), *args, **kwargs) + + setattr(self, func.__name__, _cached_method) + return _cached_method(*args, **kwargs) + + return _wrapped_func + + return _decorator diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py index 61886eab7ea..307fad83d1c 100644 --- a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py +++ b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py @@ -1,7 +1,8 @@ from collections import OrderedDict -from functools import lru_cache import numpy as np + +from Orange.misc.cache import memoize_method from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter from Orange.preprocess.transformation import Indicator @@ -38,19 +39,14 @@ def __init__(self, tree, domain, adjust_weight=lambda x: x): self._domain = domain self._adjust_weight = adjust_weight - # clear memoized functions - self.weight.cache_clear() - self._adjusted_child_weight.cache_clear() - self.parent.cache_clear() - self._all_leaves = None - @lru_cache(maxsize=1024) + @memoize_method(maxsize=1024) def weight(self, node): return self._adjust_weight(self.num_samples(node)) / \ self._adjusted_child_weight(self.parent(node)) - @lru_cache(maxsize=1024) + @memoize_method(maxsize=1024) def _adjusted_child_weight(self, node): """Helps when dealing with adjusted weights. @@ -77,7 +73,7 @@ def _adjusted_child_weight(self, node): def num_samples(self, node): return self._tree.n_node_samples[node] - @lru_cache(maxsize=1024) + @memoize_method(maxsize=1024) def parent(self, node): for children in (self._tree.children_left, self._tree.children_right): try: @@ -123,7 +119,7 @@ def root(self): def domain(self): return self._domain - @lru_cache(maxsize=1024) + @memoize_method(maxsize=1024) def rules(self, node): if node != self.root: parent = self.parent(node) @@ -177,7 +173,7 @@ def attribute(self, node): def splitting_attribute(self, node): return self._tree.feature[node] - @lru_cache(maxsize=1024) + @memoize_method(maxsize=1024) def leaves(self, node): start, stop = self._subnode_range(node) if start == stop: From d860fcd4142b8bd5867043abf4eccb077e17cac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Tue, 12 Jul 2016 10:39:19 +0200 Subject: [PATCH 03/12] Pythagorean tree: Make pylint happier --- Orange/widgets/visualize/owpythagorastree.py | 6 +-- .../widgets/visualize/pythagorastreeviewer.py | 4 +- .../visualize/widgetutils/common/view.py | 44 ++++++++-------- .../visualize/widgetutils/tree/rules.py | 52 ++++++++++--------- .../widgetutils/tree/skltreeadapter.py | 4 +- .../widgetutils/tree/tests/test_rules.py | 38 +++++++------- 6 files changed, 74 insertions(+), 74 deletions(-) diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index 2ce4972e22a..4da683dcdca 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -294,10 +294,8 @@ def _update_legend_visibility(self): def _update_log_scale_slider(self): """On calc method combo box changed.""" - if self.SIZE_CALCULATION[self.size_calc_idx][0] == 'Logarithmic': - self.log_scale_box.parent().setEnabled(True) - else: - self.log_scale_box.parent().setEnabled(False) + self.log_scale_box.parent().setEnabled( + self.SIZE_CALCULATION[self.size_calc_idx][0] == 'Logarithmic') # MODEL REMOVED CONTROL ELEMENTS CLEAR METHODS def _clear_info_box(self): diff --git a/Orange/widgets/visualize/pythagorastreeviewer.py b/Orange/widgets/visualize/pythagorastreeviewer.py index 42e79119ac2..6fa6ef0f2d8 100644 --- a/Orange/widgets/visualize/pythagorastreeviewer.py +++ b/Orange/widgets/visualize/pythagorastreeviewer.py @@ -442,7 +442,7 @@ def __init__(self, tree_node, parent=None, **kwargs): self.timer.setSingleShot(True) - def hoverEnterEvent(self, ev): + def hoverEnterEvent(self, event): self.timer.stop() def fnc(graphics_item): @@ -466,7 +466,7 @@ def other_fnc(graphics_item): self._propagate_z_values(self, fnc, other_fnc) - def hoverLeaveEvent(self, ev): + def hoverLeaveEvent(self, event): def fnc(graphics_item): # No need to set opacity in this branch since it was just selected diff --git a/Orange/widgets/visualize/widgetutils/common/view.py b/Orange/widgets/visualize/widgetutils/common/view.py index e080828f757..1803c9512f2 100644 --- a/Orange/widgets/visualize/widgetutils/common/view.py +++ b/Orange/widgets/visualize/widgetutils/common/view.py @@ -44,35 +44,35 @@ def __init__(self, scene, padding=(0, 0), **kwargs): super().__init__(scene, **kwargs) - def resizeEvent(self, ev): - super().resizeEvent(ev) + def resizeEvent(self, event): + super().resizeEvent(event) self.__needs_to_recalculate_initial = True - def wheelEvent(self, ev): - self.__handle_zoom(ev.delta()) - super().wheelEvent(ev) + def wheelEvent(self, event): + self.__handle_zoom(event.delta()) + super().wheelEvent(event) - def mousePressEvent(self, ev): + def mousePressEvent(self, event): # right click resets the zoom factor - if ev.button() == Qt.RightButton: + if event.button() == Qt.RightButton: self.reset_zoom() - super().mousePressEvent(ev) + super().mousePressEvent(event) - def keyPressEvent(self, ev): - if ev.key() == Qt.Key_Plus: + def keyPressEvent(self, event): + if event.key() == Qt.Key_Plus: self.__handle_zoom(1) - elif ev.key() == Qt.Key_Minus: + elif event.key() == Qt.Key_Minus: self.__handle_zoom(-1) - super().keyPressEvent(ev) + super().keyPressEvent(event) def __set_padding(self, padding): # Allow for multiple formats of padding for convenience if isinstance(padding, int): - padding = list(repeat(padding, 4)) + padding = tuple(repeat(padding, 4)) elif isinstance(padding, list) or isinstance(padding, tuple): if len(padding) == 2: - padding = (*padding, *padding) + padding = tuple(padding * 2) else: padding = 0, 0, 0, 0 @@ -99,8 +99,8 @@ def __handle_zoom(self, direction): def __zooming_out(direction): return direction < 0 - def __zooming_in(self, ev): - return not self.__zooming_out(ev) + def __zooming_in(self, event): + return not self.__zooming_out(event) def __reset_zoomout_limit(self): self.__zoomout_limit_reached = False @@ -166,15 +166,15 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.setDragMode(QtGui.QGraphicsView.ScrollHandDrag) - def enterEvent(self, ev): + def enterEvent(self, event): self.viewport().setCursor(Qt.ArrowCursor) - super().enterEvent(ev) + super().enterEvent(event) - def mouseReleaseEvent(self, ev): - super().mouseReleaseEvent(ev) + def mouseReleaseEvent(self, event): + super().mouseReleaseEvent(event) self.viewport().setCursor(Qt.ArrowCursor) class PreventDefaultWheelEvent(QtGui.QGraphicsView): - def wheelEvent(self, ev): - ev.accept() + def wheelEvent(self, event): + event.accept() diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py index c113e096f55..587226a23d1 100644 --- a/Orange/widgets/visualize/widgetutils/tree/rules.py +++ b/Orange/widgets/visualize/widgetutils/tree/rules.py @@ -1,3 +1,6 @@ +import warnings + + class Rule: """The base Rule class for tree rules.""" @@ -13,7 +16,7 @@ def merge_with(self, rule): Rule """ - raise NotImplemented() + raise NotImplementedError() class DiscreteRule(Rule): @@ -30,32 +33,31 @@ class DiscreteRule(Rule): -------- Age = 30 >>> rule = DiscreteRule('age', True, 30) + Name ≠ John >>> rule = DiscreteRule('name', False, 'John') Notes ----- - - Merging discrete rules is currently not implemented, the new rule is - simply returned and a warning is printed to stderr. + .. Note:: Merging discrete rules is currently not implemented, the new rule + is simply returned and a warning is issued. """ def __init__(self, attr_name, eq, value): self.attr_name = attr_name - self.sign = eq + self.eq = eq self.value = value def merge_with(self, rule): # It does not make sense to merge discrete rules, since they can only # be eq or not eq. - from sys import stderr - print('WARNING: Merged two discrete rules `%s` and `%s`' - % (self, rule), file=stderr) + warnings.warn('Merged two discrete rules `%s` and `%s`' % (self, rule)) return rule def __str__(self): return '{} {} {}'.format( - self.attr_name, '=' if self.sign else '≠', self.value) + self.attr_name, '=' if self.eq else '≠', self.value) class ContinuousRule(Rule): @@ -75,45 +77,45 @@ class ContinuousRule(Rule): -------- x ≤ 30 >>> rule = ContinuousRule('age', False, 30, inclusive=True) + x > 30 >>> rule = ContinuousRule('age', True, 30) Notes ----- - - Continuous rules can currently only be merged with other continuous - rules. + .. Note:: Continuous rules can currently only be merged with other + continuous rules. """ def __init__(self, attr_name, gt, value, inclusive=False): self.attr_name = attr_name - self.sign = gt + self.gt = gt self.value = value self.inclusive = inclusive def merge_with(self, rule): if not isinstance(rule, ContinuousRule): - raise NotImplemented('Continuous rules can currently only be ' - 'merged with other continuous rules') + raise NotImplementedError('Continuous rules can currently only be ' + 'merged with other continuous rules') # Handle when both have same sign - if self.sign == rule.sign: + if self.gt == rule.gt: # When both are GT - if self.sign is True: - larger = self.value if self.value > rule.value else rule.value - return ContinuousRule(self.attr_name, self.sign, larger) + if self.gt is True: + larger = max(self.value, rule.value) + return ContinuousRule(self.attr_name, self.gt, larger) # When both are LT else: smaller = self.value if self.value < rule.value else rule.value - return ContinuousRule(self.attr_name, self.sign, smaller) + return ContinuousRule(self.attr_name, self.gt, smaller) # When they have different signs we need to return an interval rule else: - lt_rule = self if self.sign is False else rule - gt_rule = self if lt_rule != self else rule + lt_rule, gt_rule = (rule, self) if self.gt else (self, rule) return IntervalRule(self.attr_name, gt_rule, lt_rule) def __str__(self): return '%s %s %.3f' % ( - self.attr_name, '>' if self.sign else '≤', self.value) + self.attr_name, '>' if self.gt else '≤', self.value) class IntervalRule(Rule): @@ -136,9 +138,9 @@ class IntervalRule(Rule): Notes ----- - - Currently, only cases which appear in classification and regression - trees are implemented. An interval can not be made up of two parts - (e.g. (-∞, -1) ∪ (1, ∞)). + .. Note:: Currently, only cases which appear in classification and + regression trees are implemented. An interval can not be made up of two + parts (e.g. (-∞, -1) ∪ (1, ∞)). """ @@ -158,7 +160,7 @@ def __init__(self, attr_name, left_rule, right_rule): def merge_with(self, rule): if isinstance(rule, ContinuousRule): - if rule.sign: + if rule.gt: return IntervalRule( self.attr_name, self.left_rule.merge_with(rule), self.right_rule) diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py index 307fad83d1c..759542cea18 100644 --- a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py +++ b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py @@ -204,7 +204,7 @@ def find_largest_idx(n): l_node = find_largest_idx(self._tree.children_left[n]) r_node = find_largest_idx(self._tree.children_right[n]) - return l_node if l_node > r_node else r_node + return max(l_node, r_node) right = left = node if self._tree.children_left[left] == self.NO_CHILD: @@ -286,7 +286,7 @@ def get_instances_in_nodes(self, dataset, nodes): all_leaves = self.leaves(self.root) - indices = np.searchsorted(all_leaves, node_leaves, side='left') + indices = np.searchsorted(all_leaves, node_leaves) # all the leaf samples for each leaf leaf_samples = self.get_samples_in_leaves(dataset.X) # filter out the leaf samples array that are not selected diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py index a6a4a86e0e4..0a8633b26f6 100644 --- a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py +++ b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py @@ -42,8 +42,8 @@ def test_merging_lt_with_gt_continuous_rules(self): rule2 = ContinuousRule('Rule', False, 2) new_rule = rule1.merge_with(rule2) self.assertIsInstance(new_rule, IntervalRule) - self.assertEquals(new_rule.left_rule, rule1) - self.assertEquals(new_rule.right_rule, rule2) + self.assertEqual(new_rule.left_rule, rule1) + self.assertEqual(new_rule.right_rule, rule2) # INTERVAL RULES def test_merging_interval_rule_with_smaller_continuous_rule(self): @@ -54,7 +54,7 @@ def test_merging_interval_rule_with_smaller_continuous_rule(self): rule2 = ContinuousRule('Rule', False, 2) new_rule = rule1.merge_with(rule2) self.assertIsInstance(new_rule, IntervalRule) - self.assertEquals(new_rule.right_rule.value, 2) + self.assertEqual(new_rule.right_rule.value, 2) def test_merging_interval_rule_with_larger_continuous_rule(self): """Merging `1 < x < 2` and `x < 3` should produce `1 < x < 2`.""" @@ -64,7 +64,7 @@ def test_merging_interval_rule_with_larger_continuous_rule(self): rule2 = ContinuousRule('Rule', False, 3) new_rule = rule1.merge_with(rule2) self.assertIsInstance(new_rule, IntervalRule) - self.assertEquals(new_rule.left_rule.value, 1) + self.assertEqual(new_rule.left_rule.value, 1) def test_merging_interval_rule_with_larger_lt_continuous_rule(self): """Merging `0 < x < 3` and `x > 1` should produce `1 < x < 3`.""" @@ -74,7 +74,7 @@ def test_merging_interval_rule_with_larger_lt_continuous_rule(self): rule2 = ContinuousRule('Rule', True, 1) new_rule = rule1.merge_with(rule2) self.assertIsInstance(new_rule, IntervalRule) - self.assertEquals(new_rule.left_rule.value, 1) + self.assertEqual(new_rule.left_rule.value, 1) def test_merging_interval_rule_with_smaller_gt_continuous_rule(self): """Merging `0 < x < 3` and `x < 2` should produce `0 < x < 2`.""" @@ -84,7 +84,7 @@ def test_merging_interval_rule_with_smaller_gt_continuous_rule(self): rule2 = ContinuousRule('Rule', False, 2) new_rule = rule1.merge_with(rule2) self.assertIsInstance(new_rule, IntervalRule) - self.assertEquals(new_rule.right_rule.value, 2) + self.assertEqual(new_rule.right_rule.value, 2) def test_merging_interval_rules_with_smaller_lt_component(self): """Merging `1 < x < 2` and `0 < x < 2` should produce `1 < x < 2`.""" @@ -95,8 +95,8 @@ def test_merging_interval_rules_with_smaller_lt_component(self): ContinuousRule('Rule', True, 0), ContinuousRule('Rule', False, 2)) new_rule = rule1.merge_with(rule2) - self.assertEquals(new_rule.left_rule.value, 1) - self.assertEquals(new_rule.right_rule.value, 2) + self.assertEqual(new_rule.left_rule.value, 1) + self.assertEqual(new_rule.right_rule.value, 2) def test_merging_interval_rules_with_larger_lt_component(self): """Merging `0 < x < 4` and `1 < x < 4` should produce `1 < x < 4`.""" @@ -107,8 +107,8 @@ def test_merging_interval_rules_with_larger_lt_component(self): ContinuousRule('Rule', True, 1), ContinuousRule('Rule', False, 4)) new_rule = rule1.merge_with(rule2) - self.assertEquals(new_rule.left_rule.value, 1) - self.assertEquals(new_rule.right_rule.value, 4) + self.assertEqual(new_rule.left_rule.value, 1) + self.assertEqual(new_rule.right_rule.value, 4) def test_merging_interval_rules_generally(self): """Merging `0 < x < 4` and `2 < x < 6` should produce `2 < x < 4`.""" @@ -119,8 +119,8 @@ def test_merging_interval_rules_generally(self): ContinuousRule('Rule', True, 2), ContinuousRule('Rule', False, 6)) new_rule = rule1.merge_with(rule2) - self.assertEquals(new_rule.left_rule.value, 2) - self.assertEquals(new_rule.right_rule.value, 4) + self.assertEqual(new_rule.left_rule.value, 2) + self.assertEqual(new_rule.right_rule.value, 4) # ALL RULES def test_merge_commutativity_on_continuous_rules(self): @@ -139,19 +139,19 @@ def test_merge_commutativity_on_interval_rules(self): ContinuousRule('Rule', False, 6)) new_rule1 = rule1.merge_with(rule2) new_rule2 = rule2.merge_with(rule1) - self.assertEquals(new_rule1.left_rule.value, - new_rule2.left_rule.value) - self.assertEquals(new_rule1.right_rule.value, - new_rule2.right_rule.value) + self.assertEqual(new_rule1.left_rule.value, + new_rule2.left_rule.value) + self.assertEqual(new_rule1.right_rule.value, + new_rule2.right_rule.value) - def test_merge_keeps_sign_on_continuous_rules(self): + def test_merge_keeps_gt_on_continuous_rules(self): rule1 = ContinuousRule('Rule1', True, 1) rule2 = ContinuousRule('Rule1', True, 2) new_rule = rule1.merge_with(rule2) - self.assertEquals(new_rule.sign, True) + self.assertEqual(new_rule.gt, True) def test_merge_keeps_attr_name_on_continuous_rules(self): rule1 = ContinuousRule('Rule1', True, 1) rule2 = ContinuousRule('Rule1', True, 2) new_rule = rule1.merge_with(rule2) - self.assertEquals(new_rule.attr_name, 'Rule1') + self.assertEqual(new_rule.attr_name, 'Rule1') From 7ef0d70540c76b62b62e3da27d5fea06794792c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Tue, 12 Jul 2016 12:07:13 +0200 Subject: [PATCH 04/12] Pythagorean tree: Add docstrings and reduce pylint errors --- Orange/misc/cache.py | 1 + Orange/widgets/visualize/owpythagorastree.py | 18 +-- .../widgets/visualize/owpythagoreanforest.py | 10 ++ .../widgets/visualize/pythagorastreeviewer.py | 13 ++- .../visualize/tests/test_owpythagorastree.py | 17 ++- .../visualize/widgetutils/common/owgrid.py | 13 ++- .../visualize/widgetutils/common/owlegend.py | 109 +++++++++++------- .../visualize/widgetutils/common/scene.py | 5 +- .../visualize/widgetutils/common/view.py | 28 ++++- .../visualize/widgetutils/tree/rules.py | 47 +++++--- .../widgetutils/tree/skltreeadapter.py | 1 + .../widgetutils/tree/tests/test_rules.py | 12 ++ .../visualize/widgetutils/tree/treeadapter.py | 1 + 13 files changed, 198 insertions(+), 77 deletions(-) diff --git a/Orange/misc/cache.py b/Orange/misc/cache.py index 2dc9f39751d..6d2d3f3baa7 100644 --- a/Orange/misc/cache.py +++ b/Orange/misc/cache.py @@ -1,3 +1,4 @@ +"""Common caching methods, using `lru_cahce` sometimes has its downsides.""" from functools import wraps, lru_cache import weakref diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index 4da683dcdca..0ba5b48306c 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -397,8 +397,7 @@ def _classification_update_legend_colors(self): else: items = ( (self.target_class_combo.itemText(self.target_class_index), - self.color_palette[self.target_class_index - 1] - ), + self.color_palette[self.target_class_index - 1]), ('other', QtGui.QColor('#ffffff')) ) self.legend = OWDiscreteLegend(items=items, **self.LEGEND_OPTIONS) @@ -450,12 +449,12 @@ def _classification_get_tooltip(self, node): return '

' \ + text \ + '{}/{} samples ({:2.3f}%)'.format( - int(samples), total, ratio * 100) \ + int(samples), total, ratio * 100) \ + '


' \ + ('Split by ' + splitting_attr.name - if not self.tree_adapter.is_leaf(node.label) else '') \ + if not self.tree_adapter.is_leaf(node.label) else '') \ + ('

' if len(rules) and not self.tree_adapter.is_leaf( - node.label) else '') \ + node.label) else '') \ + rules_str \ + '

' @@ -556,12 +555,12 @@ def _regression_get_tooltip(self, node): return '

Mean: {:2.3f}'.format(mean) \ + '
Standard deviation: {:2.3f}'.format(std) \ + '
{}/{} samples ({:2.3f}%)'.format( - int(samples), total, ratio * 100) \ + int(samples), total, ratio * 100) \ + '


' \ + ('Split by ' + splitting_attr.name - if not self.tree_adapter.is_leaf(node.label) else '') \ + if not self.tree_adapter.is_leaf(node.label) else '') \ + ('

' if len(rules) and not self.tree_adapter.is_leaf( - node.label) else '') \ + node.label) else '') \ + rules_str \ + '

' @@ -572,10 +571,13 @@ class TreeGraphicsView( AnchorableGraphicsView, PreventDefaultWheelEvent ): + """QGraphicsView that contains all functionality we will use to display + tree.""" pass class TreeGraphicsScene(UpdateItemsOnSelectGraphicsScene): + """QGraphicsScene that the tree uses.""" pass diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py index d007aa61eaf..e53b9c25c6c 100644 --- a/Orange/widgets/visualize/owpythagoreanforest.py +++ b/Orange/widgets/visualize/owpythagoreanforest.py @@ -1,3 +1,4 @@ +"""Pythagorean forest widget for visualizing random forests.""" from math import log, sqrt import numpy as np @@ -163,14 +164,17 @@ def clear(self): # CONTROL AREA CALLBACKS def max_depth_changed(self): + """When the max depth slider is changed.""" for tree in self.ptrees: tree.set_depth_limit(self.depth_limit) def target_colors_changed(self): + """When the target class or coloring method is changed.""" for tree in self.ptrees: tree.target_class_has_changed() def size_calc_changed(self): + """When the size calculation of the trees is changed.""" if self.model is not None: self.forest_adapter = self._get_forest_adapter(self.model) self.grid.clear() @@ -181,6 +185,7 @@ def size_calc_changed(self): self.max_depth_changed() def zoom_changed(self): + """When we update the "Zoom" slider.""" for item in self.grid_items: item.set_max_size(self._calculate_zoom(self.zoom)) @@ -385,10 +390,13 @@ def _color_stddev(self, adapter, tree_node): class GridItem(SelectableGridItem, ZoomableGridItem): + """The grid item we will use in our grid.""" pass class SklRandomForestAdapter: + """Take a `RandomForest` and wrap all the trees into the `TreeAdapter` + instances that Pythagorean trees use.""" def __init__(self, model, domain, adjust_weight=lambda x: x): self._adapters = [] @@ -399,6 +407,7 @@ def __init__(self, model, domain, adjust_weight=lambda x: x): self._adjust_weight = adjust_weight def get_trees(self): + """Get the tree adapters in the random forest.""" if len(self._adapters) > 0: return self._adapters if len(self._trees) < 1: @@ -412,4 +421,5 @@ def get_trees(self): @property def domain(self): + """Get the domain.""" return self._domain diff --git a/Orange/widgets/visualize/pythagorastreeviewer.py b/Orange/widgets/visualize/pythagorastreeviewer.py index 6fa6ef0f2d8..41d4ac12d3e 100644 --- a/Orange/widgets/visualize/pythagorastreeviewer.py +++ b/Orange/widgets/visualize/pythagorastreeviewer.py @@ -21,8 +21,6 @@ from PyQt4 import QtCore, QtGui from PyQt4.QtCore import Qt -from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter - # z index range, increase if needed Z_STEP = 5000000 @@ -38,6 +36,9 @@ class PythagorasTreeViewer(QtGui.QGraphicsWidget): Examples -------- + >>> from Orange.widgets.visualize.widgetutils.tree.treeadapter import ( + >>> TreeAdapter + >>> ) Pass tree through constructor. >>> tree_view = PythagorasTreeViewer(parent=scene, adapter=tree_adapter) @@ -64,7 +65,7 @@ class PythagorasTreeViewer(QtGui.QGraphicsWidget): Notes ----- - .. Note:: The class contains two clear methods: `clear` and `clear_tree`. + .. note:: The class contains two clear methods: `clear` and `clear_tree`. Each has their own use. `clear_tree` will clear out the tree and remove any graphics items. `clear` will, on the other hand, clear everything, all settings @@ -195,10 +196,12 @@ def _get_tooltip(self, *args): return 'Tooltip' def target_class_has_changed(self): + """When the target class has changed, perform appropriate updates.""" self._update_node_colors() self._update_node_tooltips() def tooltip_has_changed(self): + """When the tooltip should change, perform appropriate updates.""" self._update_node_tooltips() def _update_node_colors(self): @@ -512,7 +515,7 @@ def _propagate_to_parents(self, graphics_item, fnc, other_fnc): self._propagate_to_parents(parent, fnc, other_fnc) def selection_changed(self): - # Handle selection changed + """Handle selection changed.""" self.any_selected = len(self.scene().selectedItems()) > 0 if self.any_selected: if self.isSelected(): @@ -554,7 +557,7 @@ class TreeNode: parent : TreeNode or object The parent of the current node. In the case of root, an object containing the root label of the tree adapter should be passed. - children : tuple of TreeNode, optional + children : tuple of TreeNode, optional, default is empty tuple All the children that belong to this node. """ diff --git a/Orange/widgets/visualize/tests/test_owpythagorastree.py b/Orange/widgets/visualize/tests/test_owpythagorastree.py index c0f7d05ed40..4015bef0a3e 100644 --- a/Orange/widgets/visualize/tests/test_owpythagorastree.py +++ b/Orange/widgets/visualize/tests/test_owpythagorastree.py @@ -1,6 +1,6 @@ +"""Tests for the Pythagorean tree widget and associated classes.""" import math import unittest -import Orange.widgets from Orange.widgets.visualize.pythagorastreeviewer import ( PythagorasTree, @@ -10,10 +10,17 @@ class TestPythagorasTree(unittest.TestCase): + """Pythagorean tree testing, make sure calculating square positions works + properly. + + Most of the data is non trivial since the rotations and translations don't + generally produce trivial results. + """ def setUp(self): self.builder = PythagorasTree() def test_get_point_on_square_edge_with_no_angle(self): + """Get central point on square edge that is not angled.""" point = self.builder._get_point_on_square_edge( center=Point(0, 0), length=2, angle=0 ) @@ -22,6 +29,7 @@ def test_get_point_on_square_edge_with_no_angle(self): self.assertAlmostEqual(point.y, expected_point.y, places=1) def test_get_point_on_square_edge_with_non_zero_angle(self): + """Get central point on square edge that has angle""" point = self.builder._get_point_on_square_edge( center=Point(2.7, 2.77), length=1.65, angle=math.radians(20.97) ) @@ -30,6 +38,8 @@ def test_get_point_on_square_edge_with_non_zero_angle(self): self.assertAlmostEqual(point.y, expected_point.y, places=1) def test_compute_center_with_simple_square_angle(self): + """Compute the center of the square in the next step given a right + angle.""" initial_square = Square(Point(0, 0), length=2, angle=math.pi / 2) point = self.builder._compute_center( initial_square, length=1.13, alpha=math.radians(68.57)) @@ -38,6 +48,8 @@ def test_compute_center_with_simple_square_angle(self): self.assertAlmostEqual(point.y, expected_point.y, places=1) def test_compute_center_with_complex_square_angle(self): + """Compute the center of the square in the next step given a more + complex angle.""" initial_square = Square( Point(1.5, 1.5), length=2.24, angle=math.radians(63.43) ) @@ -48,6 +60,9 @@ def test_compute_center_with_complex_square_angle(self): self.assertAlmostEqual(point.y, expected_point.y, places=1) def test_compute_center_with_complex_square_angle_with_base_angle(self): + """Compute the center of the square in the next step when there is a + base angle - when the square does not touch the base square on the left + edge.""" initial_square = Square( Point(1.5, 1.5), length=2.24, angle=math.radians(63.43) ) diff --git a/Orange/widgets/visualize/widgetutils/common/owgrid.py b/Orange/widgets/visualize/widgetutils/common/owgrid.py index 6e091288177..e6656e9ff40 100644 --- a/Orange/widgets/visualize/widgetutils/common/owgrid.py +++ b/Orange/widgets/visualize/widgetutils/common/owgrid.py @@ -1,3 +1,9 @@ +"""Grid widget. + +Positions items into a grid. This has been tested with widgets that have their +`boundingBox` and `sizeHint` methods properly defined. + +""" from itertools import zip_longest from PyQt4 import QtGui, QtCore @@ -84,11 +90,14 @@ def paint(self, painter, options, widget=None): class ZoomableGridItem(GridItem): """Makes a grid item "zoomable" through the `set_max_size` method. + "Zoomable" here means there is a `Zoom` slider through which the grid items + can be made larger and smaller in the grid. + Notes ----- - .. Note:: This grid item will override any bounding box or size hint + .. note:: This grid item will override any bounding box or size hint defined in the class hierarchy with its own. - .. Note:: This makes the grid item square. + .. note:: This makes the grid item square. Parameters ---------- diff --git a/Orange/widgets/visualize/widgetutils/common/owlegend.py b/Orange/widgets/visualize/widgetutils/common/owlegend.py index 1047d7a43b0..47f90f27cea 100644 --- a/Orange/widgets/visualize/widgetutils/common/owlegend.py +++ b/Orange/widgets/visualize/widgetutils/common/owlegend.py @@ -1,12 +1,24 @@ -""" -Legend classes to use with `QGraphicsScene` objects. -""" +"""Legend classes to use with `QGraphicsScene` objects.""" import numpy as np from PyQt4 import QtGui, QtCore from PyQt4.QtCore import Qt class Anchorable(QtGui.QGraphicsWidget): + """Anchorable base class. + + Subclassing the `Anchorable` class will anchor the given + `QGraphicsWidget` to a position on the viewport. This does require you to + use the `AnchorableGraphicsView` class, it is made to be composable, so + that should not be a problem. + + Notes + ----- + .. note:: Subclassing this class will not make your widget movable, you + have to do that yourself. If you do make your widget movable, this will + handle any further positioning when the widget is moved. + + """ __corners = ['topLeft', 'topRight', 'bottomLeft', 'bottomRight'] TOP_LEFT, TOP_RIGHT, BOTTOM_LEFT, BOTTOM_RIGHT = __corners @@ -27,30 +39,30 @@ def __init__(self, parent=None, corner='bottomRight', offset=(10, 10)): elif isinstance(offset, QtCore.QPoint): self.__offset = offset - def moveEvent(self, ev): - super().moveEvent(ev) + def moveEvent(self, event): + super().moveEvent(event) # This check is needed because simply resizing the window will cause # the item to move and trigger a `moveEvent` therefore we need to check # that the movement was done intentionally by the user using the mouse if QtGui.QApplication.mouseButtons() == Qt.LeftButton: self.recalculate_offset() - def resizeEvent(self, ev): + def resizeEvent(self, event): # When the item is first shown, we need to update its position - super().resizeEvent(ev) + super().resizeEvent(event) if not self.__has_been_drawn: self.__offset = self.__calculate_actual_offset(self.__offset) self.update_pos() self.__has_been_drawn = True - def showEvent(self, ev): + def showEvent(self, event): # When the item is first shown, we need to update its position - super().showEvent(ev) + super().showEvent(event) self.update_pos() def recalculate_offset(self): - # This is called whenever the item is being moved and needs to - # recalculate its offset + """This is called whenever the item is being moved and needs to + recalculate its offset.""" view = self.__get_view() # Get the view box and position of legend relative to the view, # not the scene @@ -63,9 +75,11 @@ def recalculate_offset(self): self.__offset = viewbox_corner - pos def update_pos(self): - # This is called whenever something happened with the view that caused - # this item to move from its anchored position, so we have to adjust - # the position to maintain the effect of being anchored + """Update the widget position relative to the viewport. + + This is called whenever something happened with the view that caused + this item to move from its anchored position, so we have to adjust the + position to maintain the effect of being anchored.""" view = self.__get_view() if self.__corner_str and view is not None: box = self.__usable_viewbox() @@ -125,45 +139,48 @@ def __usable_viewbox(self): view = self.__get_view() if view.horizontalScrollBar().isVisible(): - h = view.horizontalScrollBar().size().height() + height = view.horizontalScrollBar().size().height() else: - h = 0 + height = 0 if view.verticalScrollBar().isVisible(): - w = view.verticalScrollBar().size().width() + width = view.verticalScrollBar().size().width() else: - w = 0 + width = 0 - size = view.size() - QtCore.QSize(w, h) + size = view.size() - QtCore.QSize(width, height) return QtCore.QRect(QtCore.QPoint(0, 0), size) class AnchorableGraphicsView(QtGui.QGraphicsView): + """Subclass when wanting to use Anchorable items in your view.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Handle scroll bar hiding or showing self.horizontalScrollBar().valueChanged.connect( self.update_anchored_items) self.verticalScrollBar().valueChanged.connect( self.update_anchored_items) - def resizeEvent(self, ev): - super().resizeEvent(ev) + def resizeEvent(self, event): + super().resizeEvent(event) self.update_anchored_items() - def mousePressEvent(self, ev): - super().mousePressEvent(ev) + def mousePressEvent(self, event): + super().mousePressEvent(event) self.update_anchored_items() - def wheelEvent(self, ev): - super().wheelEvent(ev) + def wheelEvent(self, event): + super().wheelEvent(event) self.update_anchored_items() - def mouseMoveEvent(self, ev): - super().mouseMoveEvent(ev) + def mouseMoveEvent(self, event): + super().mouseMoveEvent(event) self.update_anchored_items() def update_anchored_items(self): + """Update all the items that subclass the `Anchorable` class.""" for item in self.__anchorable_items(): item.update_pos() @@ -172,6 +189,9 @@ def __anchorable_items(self): class ColorIndicator(QtGui.QGraphicsWidget): + """Base class for an item indicator. + + Usually the little square or circle in the legend in front of the text.""" pass @@ -320,7 +340,7 @@ class LegendGradient(QtGui.QGraphicsWidget): Notes ----- - .. Note:: While the gradient does support any number of colors, any more + .. note:: While the gradient does support any number of colors, any more than 3 is not very readable. This should not be a problem, since Orange only implements 2 or 3 colors. @@ -348,11 +368,11 @@ def __init__(self, palette, parent, orientation): # Get the appropriate rectangle dimensions based on orientation if orientation == Qt.Vertical: - w, h = self.GRADIENT_WIDTH, self.GRADIENT_HEIGHT + width, height = self.GRADIENT_WIDTH, self.GRADIENT_HEIGHT elif orientation == Qt.Horizontal: - w, h = self.GRADIENT_HEIGHT, self.GRADIENT_WIDTH + width, height = self.GRADIENT_HEIGHT, self.GRADIENT_WIDTH - self.__rect_item = QtGui.QGraphicsRectItem(0, 0, w, h, self) + self.__rect_item = QtGui.QGraphicsRectItem(0, 0, width, height, self) self.__rect_item.setPen(QtGui.QPen(QtGui.QColor(0, 0, 0, 0))) self.__rect_item.setBrush(QtGui.QBrush(self.__gradient)) @@ -418,13 +438,10 @@ def _format_values(values): class Legend(Anchorable): """Base legend class. - This class provides common attributes for any legend derivates: + This class provides common attributes for any legend subclasses: - Behaviour on `QGraphicsScene` - Appearance of legend - If you have access to the `domain` property, the `LegendBuilder` class - can be used to automatically build a legend for you. - Parameters ---------- parent : QtGui.QGraphicsItem, optional @@ -447,7 +464,7 @@ class Legend(Anchorable): Notes ----- - .. Warning:: If the domain parameter is supplied, the items parameter will + .. warning:: If the domain parameter is supplied, the items parameter will be ignored. """ @@ -457,6 +474,7 @@ def __init__(self, parent=None, orientation=Qt.Vertical, domain=None, font=None, color_indicator_cls=LegendItemSquare, **kwargs): super().__init__(parent, **kwargs) + self._layout = None self.orientation = orientation self.bg_color = QtGui.QBrush(bg_color) self.color_indicator_cls = color_indicator_cls @@ -510,7 +528,7 @@ def set_domain(self, domain): If the domain does not contain the correct type of class variable. """ - raise NotImplemented() + raise NotImplementedError() def set_items(self, values): """Handle receiving an array of items. @@ -523,7 +541,7 @@ def set_items(self, values): ------- """ - raise NotImplemented() + raise NotImplementedError() @staticmethod def _convert_to_color(obj): @@ -587,7 +605,7 @@ class OWContinuousLegend(Legend): OWDiscreteLegend """ - + def __init__(self, *args, **kwargs): # Variables used in the `set_` methods must be set before calling super self.__range = kwargs.get('range', ()) @@ -637,6 +655,19 @@ def set_items(self, values): class OWBinnedContinuousLegend(Legend): + """Binned continuous legend in case you don't like gradients. + + This is not implemented yet, but in case it ever needs to be, the stub is + available. + + See Also + -------- + Legend + OWDiscreteLegend + OWContinuousLegend + + """ + def set_domain(self, domain): pass diff --git a/Orange/widgets/visualize/widgetutils/common/scene.py b/Orange/widgets/visualize/widgetutils/common/scene.py index ca10a1993fe..7c68f146e83 100644 --- a/Orange/widgets/visualize/widgetutils/common/scene.py +++ b/Orange/widgets/visualize/widgetutils/common/scene.py @@ -1,3 +1,4 @@ +"""Common QGraphicsScene components that can be composed when needed.""" from PyQt4 import QtGui @@ -9,8 +10,8 @@ class UpdateItemsOnSelectGraphicsScene(QtGui.QGraphicsScene): Notes ----- - ..Note:: I suspect this is completely unncessary, but have not been able to - find a reasonable way to keep the selection logic inside the actual + .. note:: I suspect this is completely unncessary, but have not been able + to find a reasonable way to keep the selection logic inside the actual `QGraphicsItem` objects """ diff --git a/Orange/widgets/visualize/widgetutils/common/view.py b/Orange/widgets/visualize/widgetutils/common/view.py index 1803c9512f2..e318e510bdb 100644 --- a/Orange/widgets/visualize/widgetutils/common/view.py +++ b/Orange/widgets/visualize/widgetutils/common/view.py @@ -1,3 +1,5 @@ +"""Common useful `QGraphicsView` classes that can be composed to achieve +desired functionality.""" from itertools import repeat import numpy as np @@ -25,7 +27,10 @@ class ZoomableGraphicsView(QtGui.QGraphicsView): Notes ----- - - This view will consume wheel scrolling and right mouse click events. + .. note:: This view will NOT consume the wheel event, so it would be wise + to use this component in conjuction with the `PreventDefaultWheelEvent` + in most cases. + .. note:: This view does however consume the right mouse click event. """ @@ -76,8 +81,8 @@ def __set_padding(self, padding): else: padding = 0, 0, 0, 0 - l, t, r, b = padding - self.__padding = -l, -t, r, b + left, top, right, bottom = padding + self.__padding = -left, -top, right, bottom def __handle_zoom(self, direction): """Handle zoom event, direction is positive if zooming in, otherwise @@ -106,6 +111,16 @@ def __reset_zoomout_limit(self): self.__zoomout_limit_reached = False def set_central_widget(self, widget): + """Set the central widget in the view. + + This means that the initial zoom will fit the central widget, and may + cut out any other widgets. + + Parameters + ---------- + widget : QGraphicsWidget + + """ self.__central_widget = widget def central_widget_rect(self): @@ -176,5 +191,12 @@ def mouseReleaseEvent(self, event): class PreventDefaultWheelEvent(QtGui.QGraphicsView): + """Prevent the default wheel event. + + The default wheel event pans the view around, if using the + `ZoomableGraphicsView`, this will prevent that behaviour. + + """ + def wheelEvent(self, event): event.accept() diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py index 587226a23d1..5be8a6262d3 100644 --- a/Orange/widgets/visualize/widgetutils/tree/rules.py +++ b/Orange/widgets/visualize/widgetutils/tree/rules.py @@ -1,3 +1,16 @@ +"""Rules for classification and regression trees. + +Tree visualisations usually need to show the rules of nodes, these classes make +merging these rules simple (otherwise you have repeating rules e.g. `age < 3` +and `age < 2` which can be merged into `age < 2`. + +Subclasses of the `Rule` class should provide a nice interface to merge rules +together through the `merge_with` method. Of course, this should not be forced +where it doesn't make sense e.g. merging a discrete rule (e.g. +:math:`x \in \{red, blue, green\}`) and a continuous rule (e.g. +:math:`x \leq 5`). + +""" import warnings @@ -25,7 +38,7 @@ class DiscreteRule(Rule): Parameters ---------- attr_name : str - eq : bool + equals : bool Should indicate whether or not the rule equals the value or not. value : object @@ -39,14 +52,14 @@ class DiscreteRule(Rule): Notes ----- - .. Note:: Merging discrete rules is currently not implemented, the new rule + .. note:: Merging discrete rules is currently not implemented, the new rule is simply returned and a warning is issued. """ - def __init__(self, attr_name, eq, value): + def __init__(self, attr_name, equals, value): self.attr_name = attr_name - self.eq = eq + self.equals = equals self.value = value def merge_with(self, rule): @@ -57,7 +70,7 @@ def merge_with(self, rule): def __str__(self): return '{} {} {}'.format( - self.attr_name, '=' if self.eq else '≠', self.value) + self.attr_name, '=' if self.equals else '≠', self.value) class ContinuousRule(Rule): @@ -66,7 +79,7 @@ class ContinuousRule(Rule): Parameters ---------- attr_name : str - gt : bool + greater : bool Should indicate whether the variable must be greater than the value. value : int inclusive : bool, optional @@ -83,14 +96,14 @@ class ContinuousRule(Rule): Notes ----- - .. Note:: Continuous rules can currently only be merged with other + .. note:: Continuous rules can currently only be merged with other continuous rules. """ - def __init__(self, attr_name, gt, value, inclusive=False): + def __init__(self, attr_name, greater, value, inclusive=False): self.attr_name = attr_name - self.gt = gt + self.greater = greater self.value = value self.inclusive = inclusive @@ -99,23 +112,23 @@ def merge_with(self, rule): raise NotImplementedError('Continuous rules can currently only be ' 'merged with other continuous rules') # Handle when both have same sign - if self.gt == rule.gt: + if self.greater == rule.greater: # When both are GT - if self.gt is True: + if self.greater is True: larger = max(self.value, rule.value) - return ContinuousRule(self.attr_name, self.gt, larger) + return ContinuousRule(self.attr_name, self.greater, larger) # When both are LT else: smaller = self.value if self.value < rule.value else rule.value - return ContinuousRule(self.attr_name, self.gt, smaller) + return ContinuousRule(self.attr_name, self.greater, smaller) # When they have different signs we need to return an interval rule else: - lt_rule, gt_rule = (rule, self) if self.gt else (self, rule) + lt_rule, gt_rule = (rule, self) if self.greater else (self, rule) return IntervalRule(self.attr_name, gt_rule, lt_rule) def __str__(self): return '%s %s %.3f' % ( - self.attr_name, '>' if self.gt else '≤', self.value) + self.attr_name, '>' if self.greater else '≤', self.value) class IntervalRule(Rule): @@ -138,7 +151,7 @@ class IntervalRule(Rule): Notes ----- - .. Note:: Currently, only cases which appear in classification and + .. note:: Currently, only cases which appear in classification and regression trees are implemented. An interval can not be made up of two parts (e.g. (-∞, -1) ∪ (1, ∞)). @@ -160,7 +173,7 @@ def __init__(self, attr_name, left_rule, right_rule): def merge_with(self, rule): if isinstance(rule, ContinuousRule): - if rule.gt: + if rule.greater: return IntervalRule( self.attr_name, self.left_rule.merge_with(rule), self.right_rule) diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py index 759542cea18..980811decf3 100644 --- a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py +++ b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py @@ -1,3 +1,4 @@ +"""Tree adapter class for sklearn trees.""" from collections import OrderedDict import numpy as np diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py index 0a8633b26f6..abf64c727de 100644 --- a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py +++ b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py @@ -1,3 +1,4 @@ +"""Test rules for classification and regression trees.""" import unittest from Orange.widgets.visualize.widgetutils.tree.rules import ( @@ -7,6 +8,13 @@ class TestRules(unittest.TestCase): + """Rules for classification and regression trees. + + See Also + -------- + Orange.widgets.visualize.widgetutils.tree.rules + + """ # CONTINUOUS RULES def test_merging_two_gt_continuous_rules(self): """Merging `x > 1` and `x > 2` should produce `x > 2`.""" @@ -124,6 +132,7 @@ def test_merging_interval_rules_generally(self): # ALL RULES def test_merge_commutativity_on_continuous_rules(self): + """Continuous rule merging should be commutative.""" rule1 = ContinuousRule('Rule1', True, 1) rule2 = ContinuousRule('Rule1', True, 2) new_rule1 = rule1.merge_with(rule2) @@ -131,6 +140,7 @@ def test_merge_commutativity_on_continuous_rules(self): self.assertEqual(new_rule1.value, new_rule2.value) def test_merge_commutativity_on_interval_rules(self): + """Interval rule merging should be commutative.""" rule1 = IntervalRule('Rule', ContinuousRule('Rule', True, 0), ContinuousRule('Rule', False, 4)) @@ -145,12 +155,14 @@ def test_merge_commutativity_on_interval_rules(self): new_rule2.right_rule.value) def test_merge_keeps_gt_on_continuous_rules(self): + """Merging ccontinuous rules should keep GT property.""" rule1 = ContinuousRule('Rule1', True, 1) rule2 = ContinuousRule('Rule1', True, 2) new_rule = rule1.merge_with(rule2) self.assertEqual(new_rule.gt, True) def test_merge_keeps_attr_name_on_continuous_rules(self): + """Merging continuous rules should keep the name of the rule.""" rule1 = ContinuousRule('Rule1', True, 1) rule2 = ContinuousRule('Rule1', True, 2) new_rule = rule1.merge_with(rule2) diff --git a/Orange/widgets/visualize/widgetutils/tree/treeadapter.py b/Orange/widgets/visualize/widgetutils/tree/treeadapter.py index 7bf07368c5f..8d78b392a27 100644 --- a/Orange/widgets/visualize/widgetutils/tree/treeadapter.py +++ b/Orange/widgets/visualize/widgetutils/tree/treeadapter.py @@ -1,3 +1,4 @@ +"""Base tree adapter class with common methods needed for visualisations.""" from abc import ABCMeta, abstractmethod From 0732d822487add9f675c4c073a9edd46c4cdf185 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Tue, 12 Jul 2016 12:24:39 +0200 Subject: [PATCH 05/12] Pythagorean tree: Add more pylint fixes --- Orange/misc/cache.py | 12 +++++----- Orange/widgets/visualize/owpythagorastree.py | 22 ++++++++++++------- .../widgets/visualize/owpythagoreanforest.py | 1 + .../visualize/tests/test_owpythagorastree.py | 1 + .../visualize/widgetutils/common/owlegend.py | 14 +++++++----- .../visualize/widgetutils/tree/rules.py | 2 +- 6 files changed, 32 insertions(+), 20 deletions(-) diff --git a/Orange/misc/cache.py b/Orange/misc/cache.py index 6d2d3f3baa7..adc10f5c5b2 100644 --- a/Orange/misc/cache.py +++ b/Orange/misc/cache.py @@ -1,24 +1,26 @@ -"""Common caching methods, using `lru_cahce` sometimes has its downsides.""" +"""Common caching methods, using `lru_cache` sometimes has its downsides.""" from functools import wraps, lru_cache import weakref -def single_cache(f): +def single_cache(func): + """Cache with size 1.""" last_args = () last_kwargs = set() last_result = None - def cached(*args, **kwargs): + @wraps(func) + def _cached(*args, **kwargs): nonlocal last_args, last_kwargs, last_result if len(last_args) != len(args) or \ not all(x is y for x, y in zip(args, last_args)) or \ last_kwargs != set(kwargs) or \ any(last_kwargs[k] != kwargs[k] for k in last_kwargs): - last_result = f(*args, **kwargs) + last_result = func(*args, **kwargs) last_args, last_kwargs = args, kwargs return last_result - return cached + return _cached def memoize_method(*lru_args, **lru_kwargs): diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index 0ba5b48306c..934a0808904 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -108,7 +108,7 @@ def __init__(self): # CONTROL AREA # Tree info area box_info = gui.widgetBox(self.controlArea, 'Tree Info') - self.info = gui.widgetLabel(box_info, label='') + self.info = gui.widgetLabel(box_info) # Display settings area box_display = gui.widgetBox(self.controlArea, 'Display Settings') @@ -249,10 +249,12 @@ def update_depth(self): self.ptree.set_depth_limit(self.depth_limit) def update_colors(self): + """When the target class / node coloring needs to be updated.""" self.ptree.target_class_has_changed() self._tree_specific('_update_legend_colors')() def update_size_calc(self): + """When the tree size calculation is updated.""" self._update_log_scale_slider() self.invalidate_tree() @@ -265,6 +267,7 @@ def invalidate_tree(self): self._update_main_area() def update_tooltip_enabled(self): + """When the tooltip visibility is changed and need to be updated.""" if self.tooltips_enabled: self.ptree.set_tooltip_func( self._tree_specific('_get_tooltip') @@ -274,6 +277,7 @@ def update_tooltip_enabled(self): self.ptree.tooltip_has_changed() def update_show_legend(self): + """When the legend visibility needs to be updated.""" self._update_legend_visibility() # MODEL CHANGED CONTROL ELEMENTS UPDATE METHODS @@ -350,6 +354,7 @@ def commit(self): self.send('Selected Data', data) def send_report(self): + """Send report.""" self.report_plot() def _tree_specific(self, method): @@ -449,12 +454,13 @@ def _classification_get_tooltip(self, node): return '

' \ + text \ + '{}/{} samples ({:2.3f}%)'.format( - int(samples), total, ratio * 100) \ + int(samples), total, ratio * 100) \ + '


' \ + ('Split by ' + splitting_attr.name if not self.tree_adapter.is_leaf(node.label) else '') \ - + ('

' if len(rules) and not self.tree_adapter.is_leaf( - node.label) else '') \ + + ('

' + if len(rules) and not self.tree_adapter.is_leaf(node.label) + else '') \ + rules_str \ + '

' @@ -566,10 +572,10 @@ def _regression_get_tooltip(self, node): class TreeGraphicsView( - PannableGraphicsView, - ZoomableGraphicsView, - AnchorableGraphicsView, - PreventDefaultWheelEvent + PannableGraphicsView, + ZoomableGraphicsView, + AnchorableGraphicsView, + PreventDefaultWheelEvent ): """QGraphicsView that contains all functionality we will use to display tree.""" diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py index e53b9c25c6c..08bec0c1ea5 100644 --- a/Orange/widgets/visualize/owpythagoreanforest.py +++ b/Orange/widgets/visualize/owpythagoreanforest.py @@ -290,6 +290,7 @@ def commit(self): self.send('Tree', obj) def send_report(self): + """Send report.""" self.report_plot() def _update_scene_rect(self): diff --git a/Orange/widgets/visualize/tests/test_owpythagorastree.py b/Orange/widgets/visualize/tests/test_owpythagorastree.py index 4015bef0a3e..555ee70d0ed 100644 --- a/Orange/widgets/visualize/tests/test_owpythagorastree.py +++ b/Orange/widgets/visualize/tests/test_owpythagorastree.py @@ -9,6 +9,7 @@ ) +# pylint: disable=protected-access class TestPythagorasTree(unittest.TestCase): """Pythagorean tree testing, make sure calculating square positions works properly. diff --git a/Orange/widgets/visualize/widgetutils/common/owlegend.py b/Orange/widgets/visualize/widgetutils/common/owlegend.py index 47f90f27cea..38c4a005822 100644 --- a/Orange/widgets/visualize/widgetutils/common/owlegend.py +++ b/Orange/widgets/visualize/widgetutils/common/owlegend.py @@ -92,15 +92,17 @@ def __calculate_actual_offset(self, offset): actual offset from the top left corner of the item so positioning can be done correctly.""" off_x, off_y = offset.x(), offset.y() - w, h = self.boundingRect().width(), self.boundingRect().height() + width = self.boundingRect().width() + height = self.boundingRect().height() + if self.__corner_str == self.TOP_LEFT: return QtCore.QPoint(-off_x, -off_y) elif self.__corner_str == self.TOP_RIGHT: - return QtCore.QPoint(off_x + w, -off_y) + return QtCore.QPoint(off_x + width, -off_y) elif self.__corner_str == self.BOTTOM_RIGHT: - return QtCore.QPoint(off_x + w, off_y + h) + return QtCore.QPoint(off_x + width, off_y + height) elif self.__corner_str == self.BOTTOM_LEFT: - return QtCore.QPoint(-off_x, off_y + h) + return QtCore.QPoint(-off_x, off_y + height) def __get_closest_corner(self): view = self.__get_view() @@ -110,9 +112,9 @@ def __get_closest_corner(self): legend_box = QtCore.QRect(pos, self.size().toSize()) view_box = QtCore.QRect(QtCore.QPoint(0, 0), view.size()) - def distance(t1, t2): + def distance(pt1, pt2): # 2d euclidean distance - return np.sqrt((t1.x() - t2.x()) ** 2 + (t1.y() - t2.y()) ** 2) + return np.sqrt((pt1.x() - pt2.x()) ** 2 + (pt1.y() - pt2.y()) ** 2) distances = [ (distance(getattr(view_box, corner)(), diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py index 5be8a6262d3..a384bccaed4 100644 --- a/Orange/widgets/visualize/widgetutils/tree/rules.py +++ b/Orange/widgets/visualize/widgetutils/tree/rules.py @@ -1,4 +1,4 @@ -"""Rules for classification and regression trees. +r"""Rules for classification and regression trees. Tree visualisations usually need to show the rules of nodes, these classes make merging these rules simple (otherwise you have repeating rules e.g. `age < 3` From d98c37d7946f2947fa769f5f747ec45fc314ea55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Tue, 12 Jul 2016 12:29:01 +0200 Subject: [PATCH 06/12] Pythagorean tree: Fix tests that failed due to pylint induced renaming of variables --- Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py index abf64c727de..a098b60af41 100644 --- a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py +++ b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py @@ -159,7 +159,7 @@ def test_merge_keeps_gt_on_continuous_rules(self): rule1 = ContinuousRule('Rule1', True, 1) rule2 = ContinuousRule('Rule1', True, 2) new_rule = rule1.merge_with(rule2) - self.assertEqual(new_rule.gt, True) + self.assertEqual(new_rule.greater, True) def test_merge_keeps_attr_name_on_continuous_rules(self): """Merging continuous rules should keep the name of the rule.""" From 266c8805972263a2c341a0b2c6bfb6ec6fd0fd64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Wed, 13 Jul 2016 12:33:53 +0200 Subject: [PATCH 07/12] Pythagorean tree: Add repr to tree rules and fix docstring example ordering --- .../visualize/widgetutils/tree/rules.py | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py index a384bccaed4..9a1f84fad75 100644 --- a/Orange/widgets/visualize/widgetutils/tree/rules.py +++ b/Orange/widgets/visualize/widgetutils/tree/rules.py @@ -44,11 +44,11 @@ class DiscreteRule(Rule): Examples -------- - Age = 30 - >>> rule = DiscreteRule('age', True, 30) + >>> DiscreteRule('age', True, 30) + age = 30 - Name ≠ John - >>> rule = DiscreteRule('name', False, 'John') + >>> DiscreteRule('name', False, 'John') + name ≠ John Notes ----- @@ -72,6 +72,8 @@ def __str__(self): return '{} {} {}'.format( self.attr_name, '=' if self.equals else '≠', self.value) + __repr__ = __str__ + class ContinuousRule(Rule): """Continuous rule class for handling numeric rules. @@ -88,11 +90,11 @@ class ContinuousRule(Rule): Examples -------- - x ≤ 30 - >>> rule = ContinuousRule('age', False, 30, inclusive=True) + >>> ContinuousRule('age', False, 30, inclusive=True) + age ≤ 30.000 - x > 30 - >>> rule = ContinuousRule('age', True, 30) + >>> ContinuousRule('age', True, 30) + age > 30.000 Notes ----- @@ -130,6 +132,8 @@ def __str__(self): return '%s %s %.3f' % ( self.attr_name, '>' if self.greater else '≤', self.value) + __repr__ = __str__ + class IntervalRule(Rule): """Interval rule class for ranges of continuous values. @@ -144,10 +148,10 @@ class IntervalRule(Rule): Examples -------- - 1 ≤ x < 3 - >>> rule = IntervalRule('Rule', - >>> ContinuousRule('Rule', True, 1, inclusive=True), - >>> ContinuousRule('Rule', False, 3)) + >>> IntervalRule('Rule', + >>> ContinuousRule('Rule', True, 1, inclusive=True), + >>> ContinuousRule('Rule', False, 3)) + Rule ∈ [1.000, 3.000) Notes ----- @@ -189,8 +193,12 @@ def merge_with(self, rule): self.right_rule.merge_with(rule.right_rule)) def __str__(self): - return '{} ∈ {}{:.3}, {:.3}{}'.format( + return '%s ∈ %s%.3f, %.3f%s' % ( self.attr_name, - '[' if self.left_rule.inclusive else '(', self.left_rule.value, - self.right_rule.value, ']' if self.right_rule.inclusive else ')' + '[' if self.left_rule.inclusive else '(', + self.left_rule.value, + self.right_rule.value, + ']' if self.right_rule.inclusive else ')' ) + + __repr__ = __str__ From 04ec528215da48723512a4f188512fb56aa7b543 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Fri, 15 Jul 2016 11:42:19 +0200 Subject: [PATCH 08/12] Misc: Add tests for cache --- Orange/tests/test_misc.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 Orange/tests/test_misc.py diff --git a/Orange/tests/test_misc.py b/Orange/tests/test_misc.py new file mode 100644 index 00000000000..bf5b59f20da --- /dev/null +++ b/Orange/tests/test_misc.py @@ -0,0 +1,30 @@ +import unittest + +from Orange.misc.cache import memoize_method, single_cache + + +class Calculator: + @memoize_method() + def my_sum(self, *nums): + return sum(nums) + + +@single_cache +def my_sum(*nums): + return sum(nums) + + +class TestCache(unittest.TestCase): + + def test_single_cache(self): + self.assertEqual(my_sum(1, 2, 3, 4, 5), 15) + self.assertEqual(my_sum(1, 2, 3, 4, 5), 15) + # Make sure different args produce different results + self.assertEqual(my_sum(1, 2, 3, 4), 10) + + def test_memoize_method(self): + calc = Calculator() + self.assertEqual(calc.my_sum(1, 2, 3, 4, 5), 15) + self.assertEqual(calc.my_sum(1, 2, 3, 4, 5), 15) + # Make sure different args produce different results + self.assertEqual(calc.my_sum(1, 2, 3, 4), 10) From e03d4e8bf2a13737b69c5d6a1a9b699a2b165801 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Fri, 15 Jul 2016 11:53:08 +0200 Subject: [PATCH 09/12] Pythagorean tree: Move util classes to util folder and remove widgetutils --- Orange/widgets/visualize/owpythagorastree.py | 10 +++++----- Orange/widgets/visualize/owpythagoreanforest.py | 4 ++-- Orange/widgets/visualize/pythagorastreeviewer.py | 2 +- .../widgets/visualize/{utils.py => utils/__init__.py} | 0 .../visualize/{widgetutils/common => utils}/owgrid.py | 0 .../{widgetutils/common => utils}/owlegend.py | 0 .../visualize/{widgetutils/common => utils}/scene.py | 0 .../visualize/{widgetutils => utils/tree}/__init__.py | 0 .../visualize/{widgetutils => utils}/tree/rules.py | 0 .../{widgetutils => utils}/tree/skltreeadapter.py | 5 ++--- .../common => utils/tree/tests}/__init__.py | 0 .../{widgetutils => utils}/tree/tests/test_rules.py | 2 +- .../{widgetutils => utils}/tree/treeadapter.py | 0 .../visualize/{widgetutils/common => utils}/view.py | 0 Orange/widgets/visualize/widgetutils/tree/__init__.py | 0 .../visualize/widgetutils/tree/tests/__init__.py | 0 16 files changed, 11 insertions(+), 12 deletions(-) rename Orange/widgets/visualize/{utils.py => utils/__init__.py} (100%) rename Orange/widgets/visualize/{widgetutils/common => utils}/owgrid.py (100%) rename Orange/widgets/visualize/{widgetutils/common => utils}/owlegend.py (100%) rename Orange/widgets/visualize/{widgetutils/common => utils}/scene.py (100%) rename Orange/widgets/visualize/{widgetutils => utils/tree}/__init__.py (100%) rename Orange/widgets/visualize/{widgetutils => utils}/tree/rules.py (100%) rename Orange/widgets/visualize/{widgetutils => utils}/tree/skltreeadapter.py (98%) rename Orange/widgets/visualize/{widgetutils/common => utils/tree/tests}/__init__.py (100%) rename Orange/widgets/visualize/{widgetutils => utils}/tree/tests/test_rules.py (99%) rename Orange/widgets/visualize/{widgetutils => utils}/tree/treeadapter.py (100%) rename Orange/widgets/visualize/{widgetutils/common => utils}/view.py (100%) delete mode 100644 Orange/widgets/visualize/widgetutils/tree/__init__.py delete mode 100644 Orange/widgets/visualize/widgetutils/tree/tests/__init__.py diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index 934a0808904..23619d62af0 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -19,15 +19,13 @@ from math import sqrt, log import numpy as np -from Orange.widgets.visualize.widgetutils.common.scene import \ +from Orange.widgets.visualize.utils.scene import \ UpdateItemsOnSelectGraphicsScene -from Orange.widgets.visualize.widgetutils.common.view import ( +from Orange.widgets.visualize.utils.view import ( PannableGraphicsView, ZoomableGraphicsView, PreventDefaultWheelEvent ) -from Orange.widgets.visualize.widgetutils.tree.skltreeadapter import \ - SklTreeAdapter from PyQt4 import QtGui from Orange.base import Tree @@ -40,12 +38,14 @@ PythagorasTreeViewer, SquareGraphicsItem ) -from Orange.widgets.visualize.widgetutils.common.owlegend import ( +from Orange.widgets.visualize.utils.owlegend import ( AnchorableGraphicsView, Anchorable, OWDiscreteLegend, OWContinuousLegend ) +from Orange.widgets.visualize.utils.tree.skltreeadapter import \ + SklTreeAdapter from Orange.widgets.widget import OWWidget diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py index 08bec0c1ea5..6cd51f5b26d 100644 --- a/Orange/widgets/visualize/owpythagoreanforest.py +++ b/Orange/widgets/visualize/owpythagoreanforest.py @@ -14,12 +14,12 @@ from Orange.widgets import gui, settings from Orange.widgets.utils.colorpalette import ContinuousPaletteGenerator from Orange.widgets.visualize.pythagorastreeviewer import PythagorasTreeViewer -from Orange.widgets.visualize.widgetutils.common.owgrid import ( +from Orange.widgets.visualize.utils.owgrid import ( OWGrid, SelectableGridItem, ZoomableGridItem ) -from Orange.widgets.visualize.widgetutils.tree.skltreeadapter import \ +from Orange.widgets.visualize.utils.tree.skltreeadapter import \ SklTreeAdapter from Orange.widgets.widget import OWWidget diff --git a/Orange/widgets/visualize/pythagorastreeviewer.py b/Orange/widgets/visualize/pythagorastreeviewer.py index 41d4ac12d3e..25de8a588f7 100644 --- a/Orange/widgets/visualize/pythagorastreeviewer.py +++ b/Orange/widgets/visualize/pythagorastreeviewer.py @@ -36,7 +36,7 @@ class PythagorasTreeViewer(QtGui.QGraphicsWidget): Examples -------- - >>> from Orange.widgets.visualize.widgetutils.tree.treeadapter import ( + >>> from Orange.widgets.visualize.utils.tree.treeadapter import ( >>> TreeAdapter >>> ) Pass tree through constructor. diff --git a/Orange/widgets/visualize/utils.py b/Orange/widgets/visualize/utils/__init__.py similarity index 100% rename from Orange/widgets/visualize/utils.py rename to Orange/widgets/visualize/utils/__init__.py diff --git a/Orange/widgets/visualize/widgetutils/common/owgrid.py b/Orange/widgets/visualize/utils/owgrid.py similarity index 100% rename from Orange/widgets/visualize/widgetutils/common/owgrid.py rename to Orange/widgets/visualize/utils/owgrid.py diff --git a/Orange/widgets/visualize/widgetutils/common/owlegend.py b/Orange/widgets/visualize/utils/owlegend.py similarity index 100% rename from Orange/widgets/visualize/widgetutils/common/owlegend.py rename to Orange/widgets/visualize/utils/owlegend.py diff --git a/Orange/widgets/visualize/widgetutils/common/scene.py b/Orange/widgets/visualize/utils/scene.py similarity index 100% rename from Orange/widgets/visualize/widgetutils/common/scene.py rename to Orange/widgets/visualize/utils/scene.py diff --git a/Orange/widgets/visualize/widgetutils/__init__.py b/Orange/widgets/visualize/utils/tree/__init__.py similarity index 100% rename from Orange/widgets/visualize/widgetutils/__init__.py rename to Orange/widgets/visualize/utils/tree/__init__.py diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/utils/tree/rules.py similarity index 100% rename from Orange/widgets/visualize/widgetutils/tree/rules.py rename to Orange/widgets/visualize/utils/tree/rules.py diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/utils/tree/skltreeadapter.py similarity index 98% rename from Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py rename to Orange/widgets/visualize/utils/tree/skltreeadapter.py index 980811decf3..3aed7f88680 100644 --- a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py +++ b/Orange/widgets/visualize/utils/tree/skltreeadapter.py @@ -2,12 +2,11 @@ from collections import OrderedDict import numpy as np +from Orange.widgets.visualize.utils.tree.treeadapter import TreeAdapter from Orange.misc.cache import memoize_method -from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter - from Orange.preprocess.transformation import Indicator -from Orange.widgets.visualize.widgetutils.tree.rules import ( +from Orange.widgets.visualize.utils.tree.rules import ( DiscreteRule, ContinuousRule ) diff --git a/Orange/widgets/visualize/widgetutils/common/__init__.py b/Orange/widgets/visualize/utils/tree/tests/__init__.py similarity index 100% rename from Orange/widgets/visualize/widgetutils/common/__init__.py rename to Orange/widgets/visualize/utils/tree/tests/__init__.py diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/utils/tree/tests/test_rules.py similarity index 99% rename from Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py rename to Orange/widgets/visualize/utils/tree/tests/test_rules.py index a098b60af41..562d926eae4 100644 --- a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py +++ b/Orange/widgets/visualize/utils/tree/tests/test_rules.py @@ -1,7 +1,7 @@ """Test rules for classification and regression trees.""" import unittest -from Orange.widgets.visualize.widgetutils.tree.rules import ( +from Orange.widgets.visualize.utils.tree.rules import ( ContinuousRule, IntervalRule, ) diff --git a/Orange/widgets/visualize/widgetutils/tree/treeadapter.py b/Orange/widgets/visualize/utils/tree/treeadapter.py similarity index 100% rename from Orange/widgets/visualize/widgetutils/tree/treeadapter.py rename to Orange/widgets/visualize/utils/tree/treeadapter.py diff --git a/Orange/widgets/visualize/widgetutils/common/view.py b/Orange/widgets/visualize/utils/view.py similarity index 100% rename from Orange/widgets/visualize/widgetutils/common/view.py rename to Orange/widgets/visualize/utils/view.py diff --git a/Orange/widgets/visualize/widgetutils/tree/__init__.py b/Orange/widgets/visualize/widgetutils/tree/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/__init__.py b/Orange/widgets/visualize/widgetutils/tree/tests/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 From d72ebceb951e910d288e54a41dae562fba0b4c27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Fri, 15 Jul 2016 15:08:27 +0200 Subject: [PATCH 10/12] Pythagorean tree: Added repr for rules and added prints to docstring --- Orange/widgets/visualize/utils/tree/rules.py | 27 ++++++++++++-------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/Orange/widgets/visualize/utils/tree/rules.py b/Orange/widgets/visualize/utils/tree/rules.py index 9a1f84fad75..75aaac946ab 100644 --- a/Orange/widgets/visualize/utils/tree/rules.py +++ b/Orange/widgets/visualize/utils/tree/rules.py @@ -44,10 +44,10 @@ class DiscreteRule(Rule): Examples -------- - >>> DiscreteRule('age', True, 30) + >>> print(DiscreteRule('age', True, 30)) age = 30 - >>> DiscreteRule('name', False, 'John') + >>> print(DiscreteRule('name', False, 'John')) name ≠ John Notes @@ -72,7 +72,9 @@ def __str__(self): return '{} {} {}'.format( self.attr_name, '=' if self.equals else '≠', self.value) - __repr__ = __str__ + def __repr__(self): + return "DiscreteRule(attr_name='%s', equals=%s, value=%s)" % ( + self.attr_name, self.equals, self.value) class ContinuousRule(Rule): @@ -90,10 +92,10 @@ class ContinuousRule(Rule): Examples -------- - >>> ContinuousRule('age', False, 30, inclusive=True) + >>> print(ContinuousRule('age', False, 30, inclusive=True)) age ≤ 30.000 - >>> ContinuousRule('age', True, 30) + >>> print(ContinuousRule('age', True, 30)) age > 30.000 Notes @@ -132,7 +134,10 @@ def __str__(self): return '%s %s %.3f' % ( self.attr_name, '>' if self.greater else '≤', self.value) - __repr__ = __str__ + def __repr__(self): + return "ContinuousRule(attr_name='%s', greater=%s, value=%s, " \ + "inclusive=%s)" % (self.attr_name, self.greater, self.value, + self.inclusive) class IntervalRule(Rule): @@ -148,9 +153,9 @@ class IntervalRule(Rule): Examples -------- - >>> IntervalRule('Rule', - >>> ContinuousRule('Rule', True, 1, inclusive=True), - >>> ContinuousRule('Rule', False, 3)) + >>> print(IntervalRule('Rule', + >>> ContinuousRule('Rule', True, 1, inclusive=True), + >>> ContinuousRule('Rule', False, 3))) Rule ∈ [1.000, 3.000) Notes @@ -201,4 +206,6 @@ def __str__(self): ']' if self.right_rule.inclusive else ')' ) - __repr__ = __str__ + def __repr__(self): + return "IntervalRule(attr_name='%s', left_rule=%s, right_rule=%s)" % ( + self.attr_name, repr(self.left_rule), repr(self.right_rule)) From 3102f7f25e84360e4fc2b977ac7caee5b4b71ce6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Mon, 18 Jul 2016 09:44:52 +0200 Subject: [PATCH 11/12] Pythagorean tree: Add cosmetic fix --- Orange/widgets/visualize/utils/tree/rules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Orange/widgets/visualize/utils/tree/rules.py b/Orange/widgets/visualize/utils/tree/rules.py index 75aaac946ab..a491da2893e 100644 --- a/Orange/widgets/visualize/utils/tree/rules.py +++ b/Orange/widgets/visualize/utils/tree/rules.py @@ -123,7 +123,7 @@ def merge_with(self, rule): return ContinuousRule(self.attr_name, self.greater, larger) # When both are LT else: - smaller = self.value if self.value < rule.value else rule.value + smaller = min(self.value, rule.value) return ContinuousRule(self.attr_name, self.greater, smaller) # When they have different signs we need to return an interval rule else: From 596a4526865c8c1a0195cda955ad64ff8ad37c65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?= Date: Tue, 19 Jul 2016 14:48:50 +0200 Subject: [PATCH 12/12] Pythagorean tree: Change text in target class with regression, make forest name camel case --- Orange/widgets/visualize/owpythagorastree.py | 8 ++++++++ Orange/widgets/visualize/owpythagoreanforest.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index 23619d62af0..ed7e1e43e6b 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -387,6 +387,10 @@ def _tree_specific(self, method): # CLASSIFICATION TREE SPECIFIC METHODS def _classification_update_target_class_combo(self): self._clear_target_class_combo() + list(filter( + lambda x: isinstance(x, QtGui.QLabel), + self.target_class_combo.parent().children() + ))[0].setText('Target class') self.target_class_combo.addItem('None') values = [c.title() for c in self.tree_adapter.domain.class_vars[0].values] @@ -467,6 +471,10 @@ def _classification_get_tooltip(self, node): # REGRESSION TREE SPECIFIC METHODS def _regression_update_target_class_combo(self): self._clear_target_class_combo() + list(filter( + lambda x: isinstance(x, QtGui.QLabel), + self.target_class_combo.parent().children() + ))[0].setText('Node color') self.target_class_combo.addItems( list(zip(*self.REGRESSION_COLOR_CALC))[0]) self.target_class_combo.setCurrentIndex(self.target_class_index) diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py index 6cd51f5b26d..f37ac8210c5 100644 --- a/Orange/widgets/visualize/owpythagoreanforest.py +++ b/Orange/widgets/visualize/owpythagoreanforest.py @@ -25,7 +25,7 @@ class OWPythagoreanForest(OWWidget): - name = 'Pythagorean forest' + name = 'Pythagorean Forest' description = 'Pythagorean forest for visualising random forests.' icon = 'icons/PythagoreanForest.svg'