' \
+ + ('Split by ' + splitting_attr.name
+ if not self.tree_adapter.is_leaf(node.label) else '') \
+ + ('
' if len(rules) and not self.tree_adapter.is_leaf(
+ node.label) else '') \
+ + rules_str \
+ + ''
+
+
+class TreeGraphicsView(
+ PannableGraphicsView,
+ ZoomableGraphicsView,
+ AnchorableGraphicsView,
+ PreventDefaultWheelEvent
+):
+ pass
+
+
+class TreeGraphicsScene(UpdateItemsOnSelectGraphicsScene):
+ pass
+
+
+def main():
+ import sys
+ import Orange
+ from Orange.classification.tree import TreeLearner
+
+ argv = sys.argv
+ if len(argv) > 1:
+ filename = argv[1]
+ else:
+ filename = 'iris'
+
+ app = QtGui.QApplication(argv)
+ ow = OWPythagorasTree()
+ data = Orange.data.Table(filename)
+ clf = TreeLearner(max_depth=1000)(data)
+ ow.set_tree(clf)
+
+ ow.show()
+ ow.raise_()
+ ow.handleNewSignals()
+ app.exec_()
+
+ sys.exit(0)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py
new file mode 100644
index 00000000000..d007aa61eaf
--- /dev/null
+++ b/Orange/widgets/visualize/owpythagoreanforest.py
@@ -0,0 +1,415 @@
+from math import log, sqrt
+
+import numpy as np
+from PyQt4 import QtGui
+from PyQt4.QtCore import Qt
+
+from Orange.base import RandomForest, Tree
+from Orange.classification.random_forest import RandomForestClassifier
+from Orange.classification.tree import TreeClassifier
+from Orange.data import Table
+from Orange.regression.random_forest import RandomForestRegressor
+from Orange.regression.tree import TreeRegressor
+from Orange.widgets import gui, settings
+from Orange.widgets.utils.colorpalette import ContinuousPaletteGenerator
+from Orange.widgets.visualize.pythagorastreeviewer import PythagorasTreeViewer
+from Orange.widgets.visualize.widgetutils.common.owgrid import (
+ OWGrid,
+ SelectableGridItem,
+ ZoomableGridItem
+)
+from Orange.widgets.visualize.widgetutils.tree.skltreeadapter import \
+ SklTreeAdapter
+from Orange.widgets.widget import OWWidget
+
+
+class OWPythagoreanForest(OWWidget):
+ name = 'Pythagorean forest'
+ description = 'Pythagorean forest for visualising random forests.'
+ icon = 'icons/PythagoreanForest.svg'
+
+ priority = 620
+
+ inputs = [('Random forest', RandomForest, 'set_rf')]
+ outputs = [('Tree', Tree)]
+
+ # Enable the save as feature
+ graph_name = 'scene'
+
+ # Settings
+ depth_limit = settings.ContextSetting(10)
+ target_class_index = settings.ContextSetting(0)
+ size_calc_idx = settings.Setting(0)
+ size_log_scale = settings.Setting(2)
+ zoom = settings.Setting(50)
+ selected_tree_index = settings.ContextSetting(-1)
+
+ CLASSIFICATION, REGRESSION = range(2)
+
+ def __init__(self):
+ super().__init__()
+ # Instance variables
+ self.forest_type = self.CLASSIFICATION
+ self.model = None
+ self.forest_adapter = None
+ self.dataset = None
+ self.clf_dataset = None
+ # We need to store refernces to the trees and grid items
+ self.grid_items, self.ptrees = [], []
+
+ self.color_palette = None
+
+ # Different methods to calculate the size of squares
+ self.SIZE_CALCULATION = [
+ ('Normal', lambda x: x),
+ ('Square root', lambda x: sqrt(x)),
+ ('Logarithmic', lambda x: log(x * self.size_log_scale)),
+ ]
+
+ self.REGRESSION_COLOR_CALC = [
+ ('None', lambda _, __: QtGui.QColor(255, 255, 255)),
+ ('Class mean', self._color_class_mean),
+ ('Standard deviation', self._color_stddev),
+ ]
+
+ # CONTROL AREA
+ # Tree info area
+ box_info = gui.widgetBox(self.controlArea, 'Forest')
+ self.ui_info = gui.widgetLabel(box_info, label='')
+
+ # Display controls area
+ box_display = gui.widgetBox(self.controlArea, 'Display')
+ self.ui_depth_slider = gui.hSlider(
+ box_display, self, 'depth_limit', label='Depth', ticks=False,
+ callback=self.max_depth_changed)
+ self.ui_target_class_combo = gui.comboBox(
+ box_display, self, 'target_class_index', label='Target class',
+ orientation='horizontal', items=[], contentsLength=8,
+ callback=self.target_colors_changed)
+ self.ui_size_calc_combo = gui.comboBox(
+ box_display, self, 'size_calc_idx', label='Size',
+ orientation='horizontal',
+ items=list(zip(*self.SIZE_CALCULATION))[0], contentsLength=8,
+ callback=self.size_calc_changed)
+ self.ui_zoom_slider = gui.hSlider(
+ box_display, self, 'zoom', label='Zoom', ticks=False, minValue=20,
+ maxValue=150, callback=self.zoom_changed, createLabel=False)
+
+ # Stretch to fit the rest of the unsused area
+ gui.rubber(self.controlArea)
+
+ self.controlArea.setSizePolicy(
+ QtGui.QSizePolicy.Preferred, QtGui.QSizePolicy.Expanding)
+
+ # MAIN AREA
+ self.scene = QtGui.QGraphicsScene(self)
+ self.scene.selectionChanged.connect(self.commit)
+ self.grid = OWGrid()
+ self.grid.geometryChanged.connect(self._update_scene_rect)
+ self.scene.addItem(self.grid)
+
+ self.view = QtGui.QGraphicsView(self.scene)
+ self.view.setRenderHint(QtGui.QPainter.Antialiasing, True)
+ self.view.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
+ self.mainArea.layout().addWidget(self.view)
+
+ self.resize(800, 500)
+
+ self.clear()
+
+ def set_rf(self, model=None):
+ """When a different forest is given."""
+ self.clear()
+ self.model = model
+
+ if model is not None:
+ if isinstance(model, RandomForestClassifier):
+ self.forest_type = self.CLASSIFICATION
+ elif isinstance(model, RandomForestRegressor):
+ self.forest_type = self.REGRESSION
+ else:
+ raise RuntimeError('Invalid type of forest.')
+
+ self.forest_adapter = self._get_forest_adapter(self.model)
+ self.color_palette = self._type_specific('_get_color_palette')()
+ self._draw_trees()
+
+ self.dataset = model.instances
+ # this bit is important for the regression classifier
+ if self.dataset is not None and \
+ self.dataset.domain != model.domain:
+ self.clf_dataset = Table.from_table(
+ self.model.domain, self.dataset)
+ else:
+ self.clf_dataset = self.dataset
+
+ self._update_info_box()
+ self._type_specific('_update_target_class_combo')()
+ self._update_depth_slider()
+
+ self.selected_tree_index = -1
+
+ def clear(self):
+ """Clear all relevant data from the widget."""
+ self.model = None
+ self.forest_adapter = None
+ self.ptrees = []
+ self.grid_items = []
+ self.grid.clear()
+
+ self._clear_info_box()
+ self._clear_target_class_combo()
+ self._clear_depth_slider()
+
+ # CONTROL AREA CALLBACKS
+ def max_depth_changed(self):
+ for tree in self.ptrees:
+ tree.set_depth_limit(self.depth_limit)
+
+ def target_colors_changed(self):
+ for tree in self.ptrees:
+ tree.target_class_has_changed()
+
+ def size_calc_changed(self):
+ if self.model is not None:
+ self.forest_adapter = self._get_forest_adapter(self.model)
+ self.grid.clear()
+ self._draw_trees()
+ # Keep the selected item
+ if self.selected_tree_index != -1:
+ self.grid_items[self.selected_tree_index].setSelected(True)
+ self.max_depth_changed()
+
+ def zoom_changed(self):
+ for item in self.grid_items:
+ item.set_max_size(self._calculate_zoom(self.zoom))
+
+ width = (self.view.width() - self.view.verticalScrollBar().width())
+ self.grid.reflow(width)
+ self.grid.setPreferredWidth(width)
+
+ # MODEL CHANGED METHODS
+ def _update_info_box(self):
+ self.ui_info.setText(
+ 'Trees: {}'.format(len(self.forest_adapter.get_trees()))
+ )
+
+ def _update_depth_slider(self):
+ self.depth_limit = self._get_max_depth()
+
+ self.ui_depth_slider.parent().setEnabled(True)
+ self.ui_depth_slider.setMaximum(self.depth_limit)
+ self.ui_depth_slider.setValue(self.depth_limit)
+
+ # MODEL CLEARED METHODS
+ def _clear_info_box(self):
+ self.ui_info.setText('No forest on input.')
+
+ def _clear_target_class_combo(self):
+ self.ui_target_class_combo.clear()
+ self.target_class_index = 0
+ self.ui_target_class_combo.setCurrentIndex(self.target_class_index)
+
+ def _clear_depth_slider(self):
+ self.ui_depth_slider.parent().setEnabled(False)
+ self.ui_depth_slider.setMaximum(0)
+
+ # HELPFUL METHODS
+ def _get_max_depth(self):
+ return max([tree.tree_adapter.max_depth for tree in self.ptrees])
+
+ def _get_forest_adapter(self, model):
+ return SklRandomForestAdapter(
+ model,
+ model.domain,
+ adjust_weight=self.SIZE_CALCULATION[self.size_calc_idx][1],
+ )
+
+ def _draw_trees(self):
+ self.grid_items, self.ptrees = [], []
+
+ with self.progressBar(len(self.forest_adapter.get_trees())) as prg:
+ for tree in self.forest_adapter.get_trees():
+ ptree = PythagorasTreeViewer(
+ None, tree,
+ node_color_func=self._type_specific('_get_node_color'),
+ interactive=False, padding=100)
+ self.grid_items.append(GridItem(
+ ptree, self.grid, max_size=self._calculate_zoom(self.zoom)
+ ))
+ self.ptrees.append(ptree)
+ prg.advance()
+ self.grid.set_items(self.grid_items)
+ # This is necessary when adding items for the first time
+ if self.grid:
+ width = (self.view.width() -
+ self.view.verticalScrollBar().width())
+ self.grid.reflow(width)
+ self.grid.setPreferredWidth(width)
+
+ @staticmethod
+ def _calculate_zoom(zoom_level):
+ """Calculate the max size for grid items from zoom level setting."""
+ return zoom_level * 5
+
+ def onDeleteWidget(self):
+ """When deleting the widget."""
+ super().onDeleteWidget()
+ self.clear()
+
+ def commit(self):
+ """Commit the selected tree to output."""
+ if len(self.scene.selectedItems()) == 0:
+ self.send('Tree', None)
+ # The selected tree index should only reset when model changes
+ if self.model is None:
+ self.selected_tree_index = -1
+ return
+
+ selected_item = self.scene.selectedItems()[0]
+ self.selected_tree_index = self.grid_items.index(selected_item)
+ tree = self.model.skl_model.estimators_[self.selected_tree_index]
+
+ if self.forest_type == self.CLASSIFICATION:
+ obj = TreeClassifier(tree)
+ else:
+ obj = TreeRegressor(tree)
+ obj.domain = self.model.domain
+ obj.instances = self.model.instances
+
+ obj.meta_target_class_index = self.target_class_index
+ obj.meta_size_calc_idx = self.size_calc_idx
+ obj.meta_size_log_scale = self.size_log_scale
+ obj.meta_depth_limit = self.depth_limit
+
+ self.send('Tree', obj)
+
+ def send_report(self):
+ self.report_plot()
+
+ def _update_scene_rect(self):
+ self.scene.setSceneRect(self.scene.itemsBoundingRect())
+
+ def resizeEvent(self, ev):
+ width = (self.view.width() - self.view.verticalScrollBar().width())
+ self.grid.reflow(width)
+ self.grid.setPreferredWidth(width)
+
+ super().resizeEvent(ev)
+
+ def _type_specific(self, method):
+ """A best effort method getter that somewhat separates logic specific
+ to classification and regression trees.
+ This relies on conventional naming of specific methods, e.g.
+ a method name _get_tooltip would need to be defined like so:
+ _classification_get_tooltip and _regression_get_tooltip, since they are
+ both specific.
+
+ Parameters
+ ----------
+ method : str
+ Method name that we would like to call.
+
+ Returns
+ -------
+ callable or None
+
+ """
+ if self.forest_type == self.CLASSIFICATION:
+ return getattr(self, '_classification' + method)
+ elif self.forest_type == self.REGRESSION:
+ return getattr(self, '_regression' + method)
+ else:
+ return None
+
+ # CLASSIFICATION FOREST SPECIFIC METHODS
+ def _classification_update_target_class_combo(self):
+ self._clear_target_class_combo()
+ self.ui_target_class_combo.addItem('None')
+ values = [c.title() for c in
+ self.model.domain.class_vars[0].values]
+ self.ui_target_class_combo.addItems(values)
+
+ def _classification_get_color_palette(self):
+ return [QtGui.QColor(*c) for c in self.model.domain.class_var.colors]
+
+ def _classification_get_node_color(self, adapter, tree_node):
+ # this is taken almost directly from the existing classification tree
+ # viewer
+ colors = self.color_palette
+ distribution = adapter.get_distribution(tree_node.label)[0]
+ total = np.sum(distribution)
+
+ if self.target_class_index:
+ p = distribution[self.target_class_index - 1] / total
+ color = colors[self.target_class_index - 1].light(200 - 100 * p)
+ else:
+ modus = np.argmax(distribution)
+ p = distribution[modus] / (total or 1)
+ color = colors[int(modus)].light(400 - 300 * p)
+ return color
+
+ # REGRESSION FOREST SPECIFIC METHODS
+ def _regression_update_target_class_combo(self):
+ self._clear_target_class_combo()
+ self.ui_target_class_combo.addItems(
+ list(zip(*self.REGRESSION_COLOR_CALC))[0])
+ self.ui_target_class_combo.setCurrentIndex(self.target_class_index)
+
+ def _regression_get_color_palette(self):
+ return ContinuousPaletteGenerator(
+ *self.forest_adapter.domain.class_var.colors)
+
+ def _regression_get_node_color(self, adapter, tree_node):
+ return self.REGRESSION_COLOR_CALC[self.target_class_index][1](
+ adapter, tree_node
+ )
+
+ def _color_class_mean(self, adapter, tree_node):
+ # calculate node colors relative to the mean of the node samples
+ min_mean = np.min(self.clf_dataset.Y)
+ max_mean = np.max(self.clf_dataset.Y)
+ instances = adapter.get_instances_in_nodes(self.clf_dataset, tree_node)
+ mean = np.mean(instances.Y)
+
+ return self.color_palette[(mean - min_mean) / (max_mean - min_mean)]
+
+ def _color_stddev(self, adapter, tree_node):
+ # calculate node colors relative to the standard deviation in the node
+ # samples
+ min_mean, max_mean = 0, np.std(self.clf_dataset.Y)
+ instances = adapter.get_instances_in_nodes(self.clf_dataset, tree_node)
+ std = np.std(instances.Y)
+
+ return self.color_palette[(std - min_mean) / (max_mean - min_mean)]
+
+
+class GridItem(SelectableGridItem, ZoomableGridItem):
+ pass
+
+
+class SklRandomForestAdapter:
+ def __init__(self, model, domain, adjust_weight=lambda x: x):
+ self._adapters = []
+
+ self._domain = domain
+
+ self._trees = model.skl_model.estimators_
+ self._domain = model.domain
+ self._adjust_weight = adjust_weight
+
+ def get_trees(self):
+ if len(self._adapters) > 0:
+ return self._adapters
+ if len(self._trees) < 1:
+ return self._adapters
+
+ self._adapters = [
+ SklTreeAdapter(tree.tree_, self._domain, self._adjust_weight)
+ for tree in self._trees
+ ]
+ return self._adapters
+
+ @property
+ def domain(self):
+ return self._domain
diff --git a/Orange/widgets/visualize/pythagorastreeviewer.py b/Orange/widgets/visualize/pythagorastreeviewer.py
new file mode 100644
index 00000000000..42e79119ac2
--- /dev/null
+++ b/Orange/widgets/visualize/pythagorastreeviewer.py
@@ -0,0 +1,756 @@
+"""
+Pythagoras tree viewer for visualizing tree structures.
+
+The pythagoras tree viewer widget is a widget that can be plugged into any
+existing widget given a tree adapter instance. It is simply a canvas that takes
+and input tree adapter and takes care of all the drawing.
+
+Types
+-----
+Square : namedtuple (center, length, angle)
+ Since Pythagoras trees deal only with squares (they also deal with
+ rectangles in the generalized form, but are completely unreadable), this
+ is what all the squares are stored as.
+Point : namedtuple (x, y)
+ Self exaplanatory.
+
+"""
+from collections import namedtuple, defaultdict, deque
+from math import pi, sqrt, cos, sin, degrees
+
+from PyQt4 import QtCore, QtGui
+from PyQt4.QtCore import Qt
+
+from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter
+
+# z index range, increase if needed
+Z_STEP = 5000000
+
+Square = namedtuple('Square', ['center', 'length', 'angle'])
+Point = namedtuple('Point', ['x', 'y'])
+
+
+class PythagorasTreeViewer(QtGui.QGraphicsWidget):
+ """Pythagoras tree viewer graphics widget.
+
+ Simply pass in a tree adapter instance and a valid scene object, and the
+ pythagoras tree will be added.
+
+ Examples
+ --------
+ Pass tree through constructor.
+ >>> tree_view = PythagorasTreeViewer(parent=scene, adapter=tree_adapter)
+
+ Pass tree later through method.
+ >>> tree_adapter = TreeAdapter()
+ >>> scene = QtGui.QGraphicsScene()
+ This is where the magic happens
+ >>> tree_view = PythagorasTreeViewer(parent=scene)
+ >>> tree_view.set_tree(tree_adapter)
+
+ Both these examples set the appropriate tree and add all the squares to the
+ widget instance.
+
+ Parameters
+ ----------
+ parent : QGraphicsItem, optional
+ The parent object that the graphics widget belongs to. Should be a
+ scene.
+ adapter : TreeAdapter, optional
+ Any valid tree adapter instance.
+ interacitive : bool, optional,
+ Specify whether the widget should have an interactive display. This
+ means special hover effects, selectable boxes. Default is true.
+
+ Notes
+ -----
+ .. Note:: The class contains two clear methods: `clear` and `clear_tree`.
+ Each has their own use.
+ `clear_tree` will clear out the tree and remove any graphics items.
+ `clear` will, on the other hand, clear everything, all settings
+ (tooltip and color calculation functions.
+
+ This is useful because when we want to change the size calculation of
+ the Pythagora tree, we just want to clear the scene and it would be
+ inconvenient to have to set color and tooltip functions again.
+ On the other hand, when we want to draw a brand new tree, it is best
+ to clear all settings to avoid any strange bugs - we start with a blank
+ slate.
+
+ """
+
+ def __init__(self, parent=None, adapter=None, depth_limit=0, padding=0,
+ **kwargs):
+ super().__init__(parent)
+
+ # Instance variables
+ # The tree adapter parameter will be handled at the end of init
+ self.tree_adapter = None
+ # The root tree node instance which is calculated inside the class
+ self._tree = None
+ self._padding = padding
+
+ self.setSizePolicy(QtGui.QSizePolicy.Expanding,
+ QtGui.QSizePolicy.Expanding)
+
+ # Necessary settings that need to be set from the outside
+ self._depth_limit = depth_limit
+ # Provide a nice green default in case no color function is provided
+ self.__calc_node_color_func = kwargs.get('node_color_func')
+ self.__get_tooltip_func = kwargs.get('tooltip_func')
+ self._interactive = kwargs.get('interactive', True)
+
+ self._square_objects = {}
+ self._drawn_nodes = deque()
+ self._frontier = deque()
+
+ # If a tree adapter was passed, set and draw the tree
+ if adapter is not None:
+ self.set_tree(adapter)
+
+ def set_tree(self, tree_adapter):
+ """Pass in a new tree adapter instance and perform updates to canvas.
+
+ Parameters
+ ----------
+ tree_adapter : TreeAdapter
+ The new tree adapter that is to be used.
+
+ Returns
+ -------
+
+ """
+ self.clear_tree()
+ self.tree_adapter = tree_adapter
+
+ if self.tree_adapter is not None:
+ self._tree = self._calculate_tree(self.tree_adapter)
+ self.set_depth_limit(tree_adapter.max_depth)
+ self._draw_tree(self._tree)
+
+ def set_depth_limit(self, depth):
+ """Update the drawing depth limit.
+
+ The drawing stops when the depth is GT the limit. This means that at
+ depth 0, the root node will be drawn.
+
+ Parameters
+ ----------
+ depth : int
+ The maximum depth at which the nodes can still be drawn.
+
+ Returns
+ -------
+
+ """
+ self._depth_limit = depth
+ self._draw_tree(self._tree)
+
+ def set_node_color_func(self, func):
+ """Set the function that will be used to calculate the node colors.
+
+ The function must accept one parameter that represents the label of a
+ given node and return the appropriate QColor object that should be used
+ for the node.
+
+ Parameters
+ ----------
+ func : Callable
+ func :: label -> QtGui.QColor
+
+ Returns
+ -------
+
+ """
+ if func != self._calc_node_color:
+ self.__calc_node_color_func = func
+ self._update_node_colors()
+
+ def _calc_node_color(self, *args):
+ """Get the node color with a nice default fallback."""
+ if self.__calc_node_color_func is not None:
+ return self.__calc_node_color_func(*args)
+ return QtGui.QColor('#297A1F')
+
+ def set_tooltip_func(self, func):
+ """Set the function that will be used the get the node tooltips.
+
+ Parameters
+ ----------
+ func : Callable
+ func :: label -> str
+
+ Returns
+ -------
+
+ """
+ if func != self._get_tooltip:
+ self.__get_tooltip_func = func
+ self._update_node_tooltips()
+
+ def _get_tooltip(self, *args):
+ """Get the node tooltip with a nice default fallback."""
+ if self.__get_tooltip_func is not None:
+ return self.__get_tooltip_func(*args)
+ return 'Tooltip'
+
+ def target_class_has_changed(self):
+ self._update_node_colors()
+ self._update_node_tooltips()
+
+ def tooltip_has_changed(self):
+ self._update_node_tooltips()
+
+ def _update_node_colors(self):
+ """Update all the node colors.
+
+ Should be called when the color method is changed and the nodes need to
+ be drawn with the new colors.
+
+ Returns
+ -------
+
+ """
+ for square in self._squares():
+ square.setBrush(self._calc_node_color(self.tree_adapter,
+ square.tree_node))
+
+ def _update_node_tooltips(self):
+ """Update all the tooltips for the squares."""
+ for square in self._squares():
+ square.setToolTip(self._get_tooltip(square.tree_node))
+
+ def clear(self):
+ """Clear the entire widget state."""
+ self.__calc_node_color_func = None
+ self.__get_tooltip_func = None
+ self.clear_tree()
+
+ def clear_tree(self):
+ """Clear only the tree, keeping tooltip and color functions."""
+ self.tree_adapter = None
+ self._tree = None
+ self._clear_scene()
+
+ @staticmethod
+ def _calculate_tree(tree_adapter):
+ """Actually calculate the tree squares"""
+ tree_builder = PythagorasTree()
+ return tree_builder.pythagoras_tree(
+ tree_adapter, tree_adapter.root, Square(Point(0, 0), 200, -pi / 2)
+ )
+
+ def _draw_tree(self, root):
+ """Efficiently draw the tree with regards to the depth.
+
+ If using a recursive approach, the tree had to be redrawn every time
+ the depth was changed, which was very impractical for larger trees,
+ since everything got very slow, very fast.
+
+ In this approach, we use two queues to represent the tree frontier and
+ the nodes that have already been drawn. We also store the depth. This
+ way, when the max depth is increased, we do not redraw the whole tree
+ but only iterate throught the frontier and draw those nodes, and update
+ the frontier accordingly.
+ When decreasing the max depth, we reverse the process, we clear the
+ frontier, and remove nodes from the drawn nodes, and append those with
+ depth max_depth + 1 to the frontier, so the frontier doesn't get
+ cluttered.
+
+ Parameters
+ ----------
+ root : TreeNode
+ The root tree node.
+
+ Returns
+ -------
+
+ """
+ if self._tree is None:
+ return
+ # if this is the first time drawing the tree begin with root
+ if not self._drawn_nodes:
+ self._frontier.appendleft((0, root))
+ # if the depth was decreased, we can clear the frontier, otherwise
+ # frontier gets cluttered with non-frontier nodes
+ was_decreased = self._depth_was_decreased()
+ if was_decreased:
+ self._frontier.clear()
+ # remove nodes from drawn and add to frontier if limit is decreased
+ while self._drawn_nodes:
+ depth, node = self._drawn_nodes.pop()
+ # check if the node is in the allowed limit
+ if depth <= self._depth_limit:
+ self._drawn_nodes.append((depth, node))
+ break
+ if depth == self._depth_limit + 1:
+ self._frontier.appendleft((depth, node))
+
+ if node.label in self._square_objects:
+ self._square_objects[node.label].hide()
+
+ # add nodes to drawn and remove from frontier if limit is increased
+ while self._frontier:
+ depth, node = self._frontier.popleft()
+ # check if the depth of the node is outside the allowed limit
+ if depth > self._depth_limit:
+ self._frontier.appendleft((depth, node))
+ break
+ self._drawn_nodes.append((depth, node))
+ self._frontier.extend((depth + 1, c) for c in node.children)
+
+ if node.label in self._square_objects:
+ self._square_objects[node.label].show()
+ else:
+ square_obj = InteractiveSquareGraphicsItem \
+ if self._interactive else SquareGraphicsItem
+ self._square_objects[node.label] = square_obj(
+ node,
+ parent=self,
+ brush=QtGui.QBrush(
+ self._calc_node_color(self.tree_adapter, node)
+ ),
+ tooltip=self._get_tooltip(node),
+ zvalue=depth,
+ )
+
+ def _depth_was_decreased(self):
+ if not self._drawn_nodes:
+ return False
+ # checks if the max depth was increased from the last change
+ depth, node = self._drawn_nodes.pop()
+ self._drawn_nodes.append((depth, node))
+ # if the right most node in drawn nodes has appropriate depth, it must
+ # have been increased
+ return depth > self._depth_limit
+
+ def _squares(self):
+ return [node.graphics_item for _, node in self._drawn_nodes]
+
+ def _clear_scene(self):
+ for square in self._squares():
+ self.scene().removeItem(square)
+ self._frontier.clear()
+ self._drawn_nodes.clear()
+ self._square_objects.clear()
+
+ def boundingRect(self):
+ return self.childrenBoundingRect().adjusted(
+ -self._padding, -self._padding, self._padding, self._padding)
+
+ def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs):
+ return self.boundingRect().size() + \
+ QtCore.QSizeF(self._padding, self._padding)
+
+
+class SquareGraphicsItem(QtGui.QGraphicsRectItem):
+ """Square Graphics Item.
+
+ Square component to draw as components for the non-interactive Pythagoras
+ tree.
+
+ Parameters
+ ----------
+ tree_node : TreeNode
+ The tree node the square represents.
+ brush : QColor, optional
+ The brush to be used as the backgound brush.
+ pen : QPen, optional
+ The pen to be used for the border.
+
+ """
+
+ def __init__(self, tree_node, parent=None, **kwargs):
+ self.tree_node = tree_node
+ self.tree_node.graphics_item = self
+
+ center, length, angle = tree_node.square
+ self._center_point = center
+ self.center = QtCore.QPointF(*center)
+ self.length = length
+ self.angle = angle
+ super().__init__(self._get_rect_attributes(), parent)
+ self.setTransformOriginPoint(self.boundingRect().center())
+ self.setRotation(degrees(angle))
+
+ self.setBrush(kwargs.get('brush', QtGui.QColor('#297A1F')))
+ self.setPen(kwargs.get('pen', QtGui.QPen(QtGui.QColor('#000'))))
+
+ self.setAcceptHoverEvents(True)
+ self.setZValue(kwargs.get('zvalue', 0))
+ self.z_step = Z_STEP
+
+ # calculate the correct z values based on the parent
+ if self.tree_node.parent != -1:
+ p = self.tree_node.parent
+ # override root z step
+ num_children = len(p.children)
+ own_index = [1 if c.label == self.tree_node.label else 0
+ for c in p.children].index(1)
+
+ self.z_step = int(p.graphics_item.z_step / num_children)
+ base_z = p.graphics_item.zValue()
+
+ self.setZValue(base_z + own_index * self.z_step)
+
+ def _get_rect_attributes(self):
+ """Get the rectangle attributes requrired to draw item.
+
+ Compute the QRectF that a QGraphicsRect needs to be rendered with the
+ data passed down in the constructor.
+
+ """
+ height = width = self.length
+ x = self.center.x() - self.length / 2
+ y = self.center.y() - self.length / 2
+ return QtCore.QRectF(x, y, height, width)
+
+
+class InteractiveSquareGraphicsItem(SquareGraphicsItem):
+ """Interactive square graphics items.
+
+ This is different from the base square graphics item so that it is
+ selectable, and it can handle and react to hover events (highlight and
+ focus own branch).
+
+ Parameters
+ ----------
+ tree_node : TreeNode
+ The tree node the square represents.
+ brush : QColor, optional
+ The brush to be used as the backgound brush.
+ pen : QPen, optional
+ The pen to be used for the border.
+
+ """
+
+ timer = QtCore.QTimer()
+
+ MAX_OPACITY = 1.
+ SELECTION_OPACITY = .5
+ HOVER_OPACITY = .1
+
+ def __init__(self, tree_node, parent=None, **kwargs):
+ super().__init__(tree_node, parent, **kwargs)
+ self.setFlag(QtGui.QGraphicsItem.ItemIsSelectable, True)
+
+ self.initial_zvalue = self.zValue()
+ # The max z value changes if any item is selected
+ self.any_selected = False
+
+ self.setToolTip(kwargs.get('tooltip', 'Tooltip'))
+
+ self.timer.setSingleShot(True)
+
+ def hoverEnterEvent(self, ev):
+ self.timer.stop()
+
+ def fnc(graphics_item):
+ graphics_item.setZValue(Z_STEP)
+ if self.any_selected:
+ if graphics_item.isSelected():
+ opacity = self.MAX_OPACITY
+ else:
+ opacity = self.SELECTION_OPACITY
+ else:
+ opacity = self.MAX_OPACITY
+ graphics_item.setOpacity(opacity)
+
+ def other_fnc(graphics_item):
+ if graphics_item.isSelected():
+ opacity = self.MAX_OPACITY
+ else:
+ opacity = self.HOVER_OPACITY
+ graphics_item.setOpacity(opacity)
+ graphics_item.setZValue(self.initial_zvalue)
+
+ self._propagate_z_values(self, fnc, other_fnc)
+
+ def hoverLeaveEvent(self, ev):
+
+ def fnc(graphics_item):
+ # No need to set opacity in this branch since it was just selected
+ # and had the max value
+ graphics_item.setZValue(self.initial_zvalue)
+
+ def other_fnc(graphics_item):
+ if self.any_selected:
+ if graphics_item.isSelected():
+ opacity = self.MAX_OPACITY
+ else:
+ opacity = self.SELECTION_OPACITY
+ else:
+ opacity = self.MAX_OPACITY
+ graphics_item.setOpacity(opacity)
+
+ self.timer.timeout.connect(
+ lambda: self._propagate_z_values(self, fnc, other_fnc))
+
+ self.timer.start(250)
+
+ def _propagate_z_values(self, graphics_item, fnc, other_fnc):
+ self._propagate_to_children(graphics_item, fnc)
+ self._propagate_to_parents(graphics_item, fnc, other_fnc)
+
+ def _propagate_to_children(self, graphics_item, fnc):
+ # propagate function that handles graphics item to appropriate children
+ fnc(graphics_item)
+ for c in graphics_item.tree_node.children:
+ self._propagate_to_children(c.graphics_item, fnc)
+
+ def _propagate_to_parents(self, graphics_item, fnc, other_fnc):
+ # propagate function that handles graphics item to appropriate parents
+ if graphics_item.tree_node.parent != -1:
+ parent = graphics_item.tree_node.parent.graphics_item
+ # handle the non relevant children nodes
+ for c in parent.tree_node.children:
+ if c != graphics_item.tree_node:
+ self._propagate_to_children(c.graphics_item, other_fnc)
+ # handle the parent node
+ fnc(parent)
+ # propagate up the tree
+ self._propagate_to_parents(parent, fnc, other_fnc)
+
+ def selection_changed(self):
+ # Handle selection changed
+ self.any_selected = len(self.scene().selectedItems()) > 0
+ if self.any_selected:
+ if self.isSelected():
+ self.setOpacity(self.MAX_OPACITY)
+ else:
+ if self.opacity() != self.HOVER_OPACITY:
+ self.setOpacity(self.SELECTION_OPACITY)
+ else:
+ self.setGraphicsEffect(None)
+ self.setOpacity(self.MAX_OPACITY)
+
+ def paint(self, painter, option, widget=None):
+ # Override the default selected appearance
+ if self.isSelected():
+ option.state ^= QtGui.QStyle.State_Selected
+ rect = self.rect()
+ # this must render before overlay due to order in which it's drawn
+ super().paint(painter, option, widget)
+ painter.save()
+ pen = QtGui.QPen(QtGui.QColor(Qt.black))
+ pen.setWidth(4)
+ pen.setJoinStyle(Qt.MiterJoin)
+ painter.setPen(pen)
+ painter.drawRect(rect.adjusted(2, 2, -2, -2))
+ painter.restore()
+ else:
+ super().paint(painter, option, widget)
+
+
+class TreeNode:
+ """A node in the tree structure used to represent the tree adapter
+
+ Parameters
+ ----------
+ label : int
+ The label of the tree node, can be looked up in the original tree.
+ square : Square
+ The square the represents the tree node.
+ parent : TreeNode or object
+ The parent of the current node. In the case of root, an object
+ containing the root label of the tree adapter should be passed.
+ children : tuple of TreeNode, optional
+ All the children that belong to this node.
+
+ """
+
+ def __init__(self, label, square, parent, children=()):
+ self.label = label
+ self.square = square
+ self.parent = parent
+ self.children = children
+ self.graphics_item = None
+
+ def __str__(self):
+ return '({}) -> [{}]'.format(self.parent, self.label)
+
+
+class PythagorasTree:
+ """Pythagoras tree.
+
+ Contains all the logic that converts a given tree adapter to a tree
+ consisting of node classes.
+
+ """
+
+ def __init__(self):
+ # store the previous angles of each square children so that slopes can
+ # be computed
+ self._slopes = defaultdict(list)
+
+ def pythagoras_tree(self, tree, node, square):
+ """Get the Pythagoras tree representation in a graph like view.
+
+ Constructs a graph using TreeNode into a tree structure. Each node in
+ graph contains the information required to plot the the tree.
+
+ Parameters
+ ----------
+ tree : TreeAdapter
+ A tree adapter instance where the original tree is stored.
+ node : int
+ The node label, the root node is denoted with 0.
+ square : Square
+ The initial square which will represent the root of the tree.
+
+ Returns
+ -------
+ TreeNode
+ The root node which contains the rest of the tree.
+
+ """
+ # make sure to clear out any old slopes if we are drawing a new tree
+ if node == tree.root:
+ self._slopes.clear()
+
+ children = tuple(
+ self._compute_child(tree, square, child)
+ for child in tree.children(node)
+ )
+ # make sure to pass a reference to parent to each child
+ obj = TreeNode(node, square, tree.parent(node), children)
+ # mutate the existing data stored in the created tree node
+ for c in children:
+ c.parent = obj
+ return obj
+
+ def _compute_child(self, tree, parent_square, node):
+ """Compute all the properties for a single child.
+
+ Parameters
+ ----------
+ tree : TreeAdapter
+ A tree adapter instance where the original tree is stored.
+ parent_square : Square
+ The parent square of the given child.
+ node : int
+ The node label of the child.
+
+ Returns
+ -------
+ TreeNode
+ The tree node representation of the given child with the computed
+ subtree.
+
+ """
+ weight = tree.weight(node)
+ # the angle of the child from its parent
+ alpha = weight * pi
+ # the child side length
+ length = parent_square.length * sin(alpha / 2)
+ # the sum of the previous anlges
+ prev_angles = sum(self._slopes[parent_square])
+
+ center = self._compute_center(
+ parent_square, length, alpha, prev_angles
+ )
+ # the angle of the square is dependent on the parent, the current
+ # angle and the previous angles. Subtract PI/2 so it starts drawing at
+ # 0rads.
+ angle = parent_square.angle - pi / 2 + prev_angles + alpha / 2
+ square = Square(center, length, angle)
+
+ self._slopes[parent_square].append(alpha)
+
+ return self.pythagoras_tree(tree, node, square)
+
+ def _compute_center(self, initial_square, length, alpha, base_angle=0):
+ """Compute the central point of a child square.
+
+ Parameters
+ ----------
+ initial_square : Square
+ The parent square representation where we will be drawing from.
+ length : float
+ The length of the side of the new square (the one we are computing
+ the center for).
+ alpha : float
+ The angle that defines the size of our new square (in radians).
+ base_angle : float, optional
+ If the square we want to find the center for is not the first child
+ i.e. its edges does not touch the base square, then we need the
+ initial angle that will act as the starting point for the new
+ square.
+
+ Returns
+ -------
+ Point
+ The central point to the new square.
+
+ """
+ parent_center, parent_length, parent_angle = initial_square
+ # get the point on the square side that will be the rotation origin
+ t0 = self._get_point_on_square_edge(
+ parent_center, parent_length, parent_angle)
+ # get the edge point that we will rotate around t0
+ square_diagonal_length = sqrt(2 * parent_length ** 2)
+ edge = self._get_point_on_square_edge(
+ parent_center, square_diagonal_length, parent_angle - pi / 4)
+ # if the new square is not the first child, we need to rotate the edge
+ if base_angle != 0:
+ edge = self._rotate_point(edge, t0, base_angle)
+
+ # rotate the edge point to the correct spot
+ t1 = self._rotate_point(edge, t0, alpha)
+
+ # calculate the middle point between the rotated point and edge
+ t2 = Point((t1.x + edge.x) / 2, (t1.y + edge.y) / 2)
+ # calculate the slope of the new square
+ slope = parent_angle - pi / 2 + alpha / 2
+ # using this data, we can compute the square center
+ return self._get_point_on_square_edge(t2, length, slope + base_angle)
+
+ @staticmethod
+ def _rotate_point(point, around, alpha):
+ """Rotate a point around another point by some angle.
+
+ Parameters
+ ----------
+ point : Point
+ The point to rotate.
+ around : Point
+ The point to perform rotation around.
+ alpha : float
+ The angle to rotate by (in radians).
+
+ Returns
+ -------
+ Point:
+ The rotated point.
+
+ """
+ temp = Point(point.x - around.x, point.y - around.y)
+ temp = Point(
+ temp.x * cos(alpha) - temp.y * sin(alpha),
+ temp.x * sin(alpha) + temp.y * cos(alpha)
+ )
+ return Point(temp.x + around.x, temp.y + around.y)
+
+ @staticmethod
+ def _get_point_on_square_edge(center, length, angle):
+ """Calculate the central point on the drawing edge of the given square.
+
+ Parameters
+ ----------
+ center : Point
+ The square center point.
+ length : float
+ The square side length.
+ angle : float
+ The angle of the square.
+
+ Returns
+ -------
+ Point
+ A point on the center of the drawing edge of the given square.
+
+ """
+ return Point(
+ center.x + length / 2 * cos(angle),
+ center.y + length / 2 * sin(angle)
+ )
diff --git a/Orange/widgets/visualize/tests/test_owpythagorastree.py b/Orange/widgets/visualize/tests/test_owpythagorastree.py
new file mode 100644
index 00000000000..c0f7d05ed40
--- /dev/null
+++ b/Orange/widgets/visualize/tests/test_owpythagorastree.py
@@ -0,0 +1,59 @@
+import math
+import unittest
+import Orange.widgets
+
+from Orange.widgets.visualize.pythagorastreeviewer import (
+ PythagorasTree,
+ Point,
+ Square,
+)
+
+
+class TestPythagorasTree(unittest.TestCase):
+ def setUp(self):
+ self.builder = PythagorasTree()
+
+ def test_get_point_on_square_edge_with_no_angle(self):
+ point = self.builder._get_point_on_square_edge(
+ center=Point(0, 0), length=2, angle=0
+ )
+ expected_point = Point(1, 0)
+ self.assertAlmostEqual(point.x, expected_point.x, places=1)
+ self.assertAlmostEqual(point.y, expected_point.y, places=1)
+
+ def test_get_point_on_square_edge_with_non_zero_angle(self):
+ point = self.builder._get_point_on_square_edge(
+ center=Point(2.7, 2.77), length=1.65, angle=math.radians(20.97)
+ )
+ expected_point = Point(3.48, 3.07)
+ self.assertAlmostEqual(point.x, expected_point.x, places=1)
+ self.assertAlmostEqual(point.y, expected_point.y, places=1)
+
+ def test_compute_center_with_simple_square_angle(self):
+ initial_square = Square(Point(0, 0), length=2, angle=math.pi / 2)
+ point = self.builder._compute_center(
+ initial_square, length=1.13, alpha=math.radians(68.57))
+ expected_point = Point(1.15, 1.78)
+ self.assertAlmostEqual(point.x, expected_point.x, places=1)
+ self.assertAlmostEqual(point.y, expected_point.y, places=1)
+
+ def test_compute_center_with_complex_square_angle(self):
+ initial_square = Square(
+ Point(1.5, 1.5), length=2.24, angle=math.radians(63.43)
+ )
+ point = self.builder._compute_center(
+ initial_square, length=1.65, alpha=math.radians(95.06))
+ expected_point = Point(3.48, 3.07)
+ self.assertAlmostEqual(point.x, expected_point.x, places=1)
+ self.assertAlmostEqual(point.y, expected_point.y, places=1)
+
+ def test_compute_center_with_complex_square_angle_with_base_angle(self):
+ initial_square = Square(
+ Point(1.5, 1.5), length=2.24, angle=math.radians(63.43)
+ )
+ point = self.builder._compute_center(
+ initial_square, length=1.51, alpha=math.radians(180 - 95.06),
+ base_angle=math.radians(95.06))
+ expected_point = Point(1.43, 3.98)
+ self.assertAlmostEqual(point.x, expected_point.x, places=1)
+ self.assertAlmostEqual(point.y, expected_point.y, places=1)
diff --git a/Orange/widgets/visualize/widgetutils/__init__.py b/Orange/widgets/visualize/widgetutils/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/Orange/widgets/visualize/widgetutils/common/__init__.py b/Orange/widgets/visualize/widgetutils/common/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/Orange/widgets/visualize/widgetutils/common/owgrid.py b/Orange/widgets/visualize/widgetutils/common/owgrid.py
new file mode 100644
index 00000000000..6e091288177
--- /dev/null
+++ b/Orange/widgets/visualize/widgetutils/common/owgrid.py
@@ -0,0 +1,280 @@
+from itertools import zip_longest
+
+from PyQt4 import QtGui, QtCore
+from PyQt4.QtCore import Qt
+
+
+class GridItem(QtGui.QGraphicsWidget):
+ """The base class for grid items, takes care of positioning in grid.
+
+ Parameters
+ ----------
+ widget : QtGui.QGraphicsWidget
+ parent : QtGui.QGraphicsWidget
+
+ See Also
+ --------
+ OWGrid
+ SelectableGridItem
+ ZoomableGridItem
+
+ """
+
+ def __init__(self, widget, parent=None, **_):
+ super().__init__(parent)
+ # For some reason, the super constructor is not setting the parent
+ self.setParent(parent)
+
+ self.widget = widget
+ if hasattr(self.widget, 'setParent'):
+ self.widget.setParentItem(self)
+ self.widget.setParent(self)
+
+ # Move the child widget to (0, 0) so that bounding rects match up
+ # This is needed because the bounding rect is caluclated with the size
+ # hint from (0, 0), regardless of any method override
+ rect = self.widget.boundingRect()
+ self.widget.moveBy(-rect.topLeft().x(), -rect.topLeft().y())
+
+ def boundingRect(self):
+ return QtCore.QRectF(QtCore.QPointF(0, 0),
+ self.widget.boundingRectoundingRect().size())
+
+ def sizeHint(self, size_hint, size_constraint=None, **kwargs):
+ return self.widget.sizeHint(size_hint, size_constraint, **kwargs)
+
+
+class SelectableGridItem(GridItem):
+ """Makes a grid item selectable.
+
+ Parameters
+ ----------
+ widget : QtGui.QGraphicsWidget
+ parent : QtGui.QgraphicsWidget
+
+ See Also
+ --------
+ OWGrid
+ GridItem
+ ZoomableGridItem
+
+ """
+
+ def __init__(self, widget, parent=None, **kwargs):
+ super().__init__(widget, parent, **kwargs)
+
+ self.setFlags(QtGui.QGraphicsWidget.ItemIsSelectable)
+
+ def paint(self, painter, options, widget=None):
+ super().paint(painter, options, widget)
+ rect = self.boundingRect()
+ painter.save()
+ if self.isSelected():
+ painter.setPen(QtGui.QPen(QtGui.QColor(125, 162, 206, 192)))
+ painter.setBrush(QtGui.QBrush(QtGui.QColor(217, 232, 252, 192)))
+ painter.drawRoundedRect(QtCore.QRectF(
+ rect.topLeft(), self.geometry().size()), 3, 3)
+ else:
+ painter.setPen(QtGui.QPen(QtGui.QColor('#ebebeb')))
+ painter.drawRoundedRect(QtCore.QRectF(
+ rect.topLeft(), self.geometry().size()), 3, 3)
+ painter.restore()
+
+
+class ZoomableGridItem(GridItem):
+ """Makes a grid item "zoomable" through the `set_max_size` method.
+
+ Notes
+ -----
+ .. Note:: This grid item will override any bounding box or size hint
+ defined in the class hierarchy with its own.
+ .. Note:: This makes the grid item square.
+
+ Parameters
+ ----------
+ widget : QtGui.QGraphicsWidget
+ parent : QtGui.QGraphicsWidget
+ max_size : int
+ The maximum size of the grid item.
+
+ See Also
+ --------
+ OWGrid
+ GridItem
+ SelectableGridItem
+
+ """
+
+ def __init__(self, widget, parent=None, max_size=150, **kwargs):
+ self._max_size = QtCore.QSizeF(max_size, max_size)
+ # We store the offsets from the top left corner to move widget properly
+ self.__offset_x = self.__offset_y = 0
+
+ super().__init__(widget, parent, **kwargs)
+
+ self._resize_widget()
+
+ def set_max_size(self, max_size):
+ self.widget.resetTransform()
+ self._max_size = QtCore.QSizeF(max_size, max_size)
+ self._resize_widget()
+
+ def _resize_widget(self):
+ w = self.widget
+ own_hint = self.sizeHint(Qt.PreferredSize)
+
+ scale_w = own_hint.width() / w.boundingRect().width()
+ scale_h = own_hint.height() / w.boundingRect().height()
+ scale = scale_w if scale_w < scale_h else scale_h
+
+ # Move the widget back to origin then perfom transformations
+ self.widget.moveBy(-self.__offset_x, -self.__offset_y)
+ # Move the tranform origin to top left, so it stays in place when
+ # scaling
+ w.setTransformOriginPoint(w.boundingRect().topLeft())
+ w.setScale(scale)
+ # Then, move the scaled widget to the center of the bounding box
+ own_rect = self.boundingRect()
+ self.__offset_x = (own_rect.width() - w.boundingRect().width() *
+ scale) / 2
+ self.__offset_y = (own_rect.height() - w.boundingRect().height() *
+ scale) / 2
+ self.widget.moveBy(self.__offset_x, self.__offset_y)
+ # Finally, tell the world you've changed
+ self.updateGeometry()
+
+ def boundingRect(self):
+ return QtCore.QRectF(QtCore.QPointF(0, 0), self._max_size)
+
+ def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs):
+ return self._max_size
+
+
+class OWGrid(QtGui.QGraphicsWidget):
+ """Responsive grid layout widget.
+
+ Manages grid items for various window sizes.
+
+ Accepts grid items as items.
+
+ Parameters
+ ----------
+ parent : QtGui.QGraphicsWidget
+
+ Examples
+ --------
+ >>> grid = OWGrid()
+
+ It's a good idea to define what you want your grid items to do. For this
+ example, we will make them selectable and zoomable, so we define a class
+ that inherits from both:
+ >>> class MyGridItem(SelectableGridItem, ZoomableGridItem):
+ >>> pass
+
+ We then take a list of items and wrap them into our new `MyGridItem`
+ instances.
+ >>> items = [QtGui.QGraphicsRectItem(0, 0, 10, 10)]
+ >>> grid_items = [MyGridItem(i, grid) for i in items]
+
+ We can then set the items to be displayed
+ >>> grid.set_items(grid_items)
+
+ """
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+
+ self.setSizePolicy(QtGui.QSizePolicy.Maximum,
+ QtGui.QSizePolicy.Maximum)
+ self.setContentsMargins(10, 10, 10, 10)
+
+ self.__layout = QtGui.QGraphicsGridLayout()
+ self.__layout.setContentsMargins(0, 0, 0, 0)
+ self.__layout.setSpacing(10)
+ self.setLayout(self.__layout)
+
+ def set_items(self, items):
+ for i, item in enumerate(items):
+ # Place the items in some arbitrary order - they will be rearranged
+ # before user sees this ordering
+ self.__layout.addItem(item, i, 0)
+
+ def setGeometry(self, rect):
+ super().setGeometry(rect)
+ self.reflow(self.size().width())
+
+ def reflow(self, width):
+ """Recalculate the layout and reposition the elements so they fit.
+
+ Parameters
+ ----------
+ width : int
+ The maximum width of the grid.
+
+ Returns
+ -------
+
+ """
+ # When setting the geometry when opened, the layout doesn't yet exist
+ if self.layout() is None:
+ return
+
+ grid = self.__layout
+
+ left, right, *_ = self.getContentsMargins()
+ width -= left + right
+
+ # Get size hints with 32 as the minimum size for each cell
+ widths = [max(64, h.width()) for h in self._hints(Qt.PreferredSize)]
+ ncol = self._fit_n_cols(widths, grid.horizontalSpacing(), width)
+
+ # The number of columns is already optimal
+ if ncol == grid.columnCount():
+ return
+
+ # remove all items from the layout, then re-add them back in updated
+ # positions
+ items = self._items()
+
+ for item in items:
+ grid.removeItem(item)
+
+ for i, item in enumerate(items):
+ grid.addItem(item, i // ncol, i % ncol)
+ grid.setAlignment(item, Qt.AlignCenter)
+
+ def clear(self):
+ for item in self._items():
+ self.__layout.removeItem(item)
+ item.setParent(None)
+
+ @staticmethod
+ def _fit_n_cols(widths, spacing, constraint):
+
+ def sliced(seq, n_col):
+ """Slice the widths into n lists that contain their respective
+ widths. E.g. [5, 5, 5], 2 => [[5, 5], [5]]"""
+ return [seq[i:i + n_col] for i in range(0, len(seq), n_col)]
+
+ def flow_width(widths, spacing, ncol):
+ w = sliced(widths, ncol)
+ col_widths = map(max, zip_longest(*w, fillvalue=0))
+ return sum(col_widths) + (ncol - 1) * spacing
+
+ ncol_best = 1
+ for ncol in range(2, len(widths) + 1):
+ width = flow_width(widths, spacing, ncol)
+ if width <= constraint:
+ ncol_best = ncol
+ else:
+ break
+
+ return ncol_best
+
+ def _items(self):
+ if not self.__layout:
+ return []
+ return [self.__layout.itemAt(i) for i in range(self.__layout.count())]
+
+ def _hints(self, which):
+ return [item.sizeHint(which) for item in self._items()]
diff --git a/Orange/widgets/visualize/widgetutils/common/owlegend.py b/Orange/widgets/visualize/widgetutils/common/owlegend.py
new file mode 100644
index 00000000000..1047d7a43b0
--- /dev/null
+++ b/Orange/widgets/visualize/widgetutils/common/owlegend.py
@@ -0,0 +1,644 @@
+"""
+Legend classes to use with `QGraphicsScene` objects.
+"""
+import numpy as np
+from PyQt4 import QtGui, QtCore
+from PyQt4.QtCore import Qt
+
+
+class Anchorable(QtGui.QGraphicsWidget):
+
+ __corners = ['topLeft', 'topRight', 'bottomLeft', 'bottomRight']
+ TOP_LEFT, TOP_RIGHT, BOTTOM_LEFT, BOTTOM_RIGHT = __corners
+
+ def __init__(self, parent=None, corner='bottomRight', offset=(10, 10)):
+ super().__init__(parent)
+
+ self.__corner_str = corner if corner in self.__corners else None
+ # The flag indicates whether or not the item has been drawn on yet.
+ # This is useful for determining the initial offset, due to the fact
+ # that dimensions are available in the resize event, which can occur
+ # multiple times.
+ self.__has_been_drawn = False
+
+ if isinstance(offset, tuple) or isinstance(offset, list):
+ assert len(offset) == 2
+ self.__offset = QtCore.QPoint(*offset)
+ elif isinstance(offset, QtCore.QPoint):
+ self.__offset = offset
+
+ def moveEvent(self, ev):
+ super().moveEvent(ev)
+ # This check is needed because simply resizing the window will cause
+ # the item to move and trigger a `moveEvent` therefore we need to check
+ # that the movement was done intentionally by the user using the mouse
+ if QtGui.QApplication.mouseButtons() == Qt.LeftButton:
+ self.recalculate_offset()
+
+ def resizeEvent(self, ev):
+ # When the item is first shown, we need to update its position
+ super().resizeEvent(ev)
+ if not self.__has_been_drawn:
+ self.__offset = self.__calculate_actual_offset(self.__offset)
+ self.update_pos()
+ self.__has_been_drawn = True
+
+ def showEvent(self, ev):
+ # When the item is first shown, we need to update its position
+ super().showEvent(ev)
+ self.update_pos()
+
+ def recalculate_offset(self):
+ # This is called whenever the item is being moved and needs to
+ # recalculate its offset
+ view = self.__get_view()
+ # Get the view box and position of legend relative to the view,
+ # not the scene
+ pos = view.mapFromScene(self.pos())
+ view_box = self.__usable_viewbox()
+
+ self.__corner_str = self.__get_closest_corner()
+ viewbox_corner = getattr(view_box, self.__corner_str)()
+
+ self.__offset = viewbox_corner - pos
+
+ def update_pos(self):
+ # This is called whenever something happened with the view that caused
+ # this item to move from its anchored position, so we have to adjust
+ # the position to maintain the effect of being anchored
+ view = self.__get_view()
+ if self.__corner_str and view is not None:
+ box = self.__usable_viewbox()
+ corner = getattr(box, self.__corner_str)()
+ new_pos = corner - self.__offset
+ self.setPos(view.mapToScene(new_pos))
+
+ def __calculate_actual_offset(self, offset):
+ """Take the offset specified in the constructor and calculate the
+ actual offset from the top left corner of the item so positioning can
+ be done correctly."""
+ off_x, off_y = offset.x(), offset.y()
+ w, h = self.boundingRect().width(), self.boundingRect().height()
+ if self.__corner_str == self.TOP_LEFT:
+ return QtCore.QPoint(-off_x, -off_y)
+ elif self.__corner_str == self.TOP_RIGHT:
+ return QtCore.QPoint(off_x + w, -off_y)
+ elif self.__corner_str == self.BOTTOM_RIGHT:
+ return QtCore.QPoint(off_x + w, off_y + h)
+ elif self.__corner_str == self.BOTTOM_LEFT:
+ return QtCore.QPoint(-off_x, off_y + h)
+
+ def __get_closest_corner(self):
+ view = self.__get_view()
+ # Get the view box and position of legend relative to the view,
+ # not the scene
+ pos = view.mapFromScene(self.pos())
+ legend_box = QtCore.QRect(pos, self.size().toSize())
+ view_box = QtCore.QRect(QtCore.QPoint(0, 0), view.size())
+
+ def distance(t1, t2):
+ # 2d euclidean distance
+ return np.sqrt((t1.x() - t2.x()) ** 2 + (t1.y() - t2.y()) ** 2)
+
+ distances = [
+ (distance(getattr(view_box, corner)(),
+ getattr(legend_box, corner)()), corner)
+ for corner in self.__corners
+ ]
+ _, corner = min(distances)
+ return corner
+
+ def __get_own_corner(self):
+ view = self.__get_view()
+ pos = view.mapFromScene(self.pos())
+ legend_box = QtCore.QRect(pos, self.size().toSize())
+ return getattr(legend_box, self.__corner_str)()
+
+ def __get_view(self):
+ if self.scene() is not None:
+ view, = self.scene().views()
+ return view
+ else:
+ return None
+
+ def __usable_viewbox(self):
+ view = self.__get_view()
+
+ if view.horizontalScrollBar().isVisible():
+ h = view.horizontalScrollBar().size().height()
+ else:
+ h = 0
+
+ if view.verticalScrollBar().isVisible():
+ w = view.verticalScrollBar().size().width()
+ else:
+ w = 0
+
+ size = view.size() - QtCore.QSize(w, h)
+ return QtCore.QRect(QtCore.QPoint(0, 0), size)
+
+
+class AnchorableGraphicsView(QtGui.QGraphicsView):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.horizontalScrollBar().valueChanged.connect(
+ self.update_anchored_items)
+ self.verticalScrollBar().valueChanged.connect(
+ self.update_anchored_items)
+
+ def resizeEvent(self, ev):
+ super().resizeEvent(ev)
+ self.update_anchored_items()
+
+ def mousePressEvent(self, ev):
+ super().mousePressEvent(ev)
+ self.update_anchored_items()
+
+ def wheelEvent(self, ev):
+ super().wheelEvent(ev)
+ self.update_anchored_items()
+
+ def mouseMoveEvent(self, ev):
+ super().mouseMoveEvent(ev)
+ self.update_anchored_items()
+
+ def update_anchored_items(self):
+ for item in self.__anchorable_items():
+ item.update_pos()
+
+ def __anchorable_items(self):
+ return [i for i in self.scene().items() if isinstance(i, Anchorable)]
+
+
+class ColorIndicator(QtGui.QGraphicsWidget):
+ pass
+
+
+class LegendItemSquare(ColorIndicator):
+ """Legend square item.
+
+ The legend square item is a small colored square image that can be plugged
+ into the legend in front of the text object.
+
+ This should only really be used in conjunction with ˙LegendItem˙.
+
+ Parameters
+ ----------
+ color : QtGui.QColor
+ The color of the square.
+ parent : QtGui.QGraphicsItem
+
+ See Also
+ --------
+ LegendItemCircle
+
+ """
+
+ SIZE = QtCore.QSizeF(12, 12)
+
+ def __init__(self, color, parent):
+ super().__init__(parent)
+
+ height, width = self.SIZE.height(), self.SIZE.width()
+ self.__square = QtGui.QGraphicsRectItem(0, 0, height, width)
+ self.__square.setBrush(QtGui.QBrush(color))
+ self.__square.setPen(QtGui.QPen(QtGui.QColor(0, 0, 0, 0)))
+ self.__square.setParentItem(self)
+
+ def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs):
+ return QtCore.QSizeF(self.__square.boundingRect().size())
+
+
+class LegendItemCircle(ColorIndicator):
+ """Legend circle item.
+
+ The legend circle item is a small colored circle image that can be plugged
+ into the legend in front of the text object.
+
+ This should only really be used in conjunction with ˙LegendItem˙.
+
+ Parameters
+ ----------
+ color : QtGui.QColor
+ The color of the square.
+ parent : QtGui.QGraphicsItem
+
+ See Also
+ --------
+ LegendItemSquare
+
+ """
+
+ SIZE = QtCore.QSizeF(12, 12)
+
+ def __init__(self, color, parent):
+ super().__init__(parent)
+
+ height, width = self.SIZE.height(), self.SIZE.width()
+ self.__circle = QtGui.QGraphicsEllipseItem(0, 0, height, width)
+ self.__circle.setBrush(QtGui.QBrush(color))
+ self.__circle.setPen(QtGui.QPen(QtGui.QColor(0, 0, 0, 0)))
+ self.__circle.setParentItem(self)
+
+ def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs):
+ return QtCore.QSizeF(self.__circle.boundingRect().size())
+
+
+class LegendItemTitle(QtGui.QGraphicsWidget):
+ """Legend item title - the text displayed in the legend.
+
+ This should only really be used in conjunction with ˙LegendItem˙.
+
+ Parameters
+ ----------
+ text : str
+ parent : QtGui.QGraphicsItem
+ font : QtGui.QFont
+ This
+
+ """
+
+ def __init__(self, text, parent, font):
+ super().__init__(parent)
+
+ self.__text = QtGui.QGraphicsTextItem(text.title())
+ self.__text.setParentItem(self)
+ self.__text.setFont(font)
+
+ def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs):
+ return QtCore.QSizeF(self.__text.boundingRect().size())
+
+
+class LegendItem(QtGui.QGraphicsLinearLayout):
+ """Legend item - one entry in the legend.
+
+ This represents one entry in the legend i.e. a color indicator and the text
+ beside it.
+
+ Parameters
+ ----------
+ color : QtGui.QColor
+ The color that the entry will represent.
+ title : str
+ The text that will be displayed for the color.
+ parent : QtGui.QGraphicsItem
+ color_indicator_cls : ColorIndicator
+ The type of `ColorIndicator` that will be used for the color.
+ font : QtGui.QFont, optional
+
+ """
+
+ def __init__(self, color, title, parent, color_indicator_cls, font=None):
+ super().__init__()
+
+ self.__parent = parent
+ self.__color_indicator = color_indicator_cls(color, parent)
+ self.__title_label = LegendItemTitle(title, parent, font=font)
+
+ self.addItem(self.__color_indicator)
+ self.addItem(self.__title_label)
+
+ # Make sure items are aligned properly, since the color box and text
+ # won't be the same height.
+ self.setAlignment(self.__color_indicator, Qt.AlignCenter)
+ self.setAlignment(self.__title_label, Qt.AlignCenter)
+ self.setContentsMargins(0, 0, 0, 0)
+ self.setSpacing(5)
+
+
+class LegendGradient(QtGui.QGraphicsWidget):
+ """Gradient widget.
+
+ A gradient square bar that can be used to display continuous values.
+
+ Parameters
+ ----------
+ palette : iterable[QtGui.QColor]
+ parent : QtGui.QGraphicsWidget
+ orientation : Qt.Orientation
+
+ Notes
+ -----
+ .. Note:: While the gradient does support any number of colors, any more
+ than 3 is not very readable. This should not be a problem, since Orange
+ only implements 2 or 3 colors.
+
+ """
+
+ # Default sizes (assume gradient is vertical by default)
+ GRADIENT_WIDTH = 20
+ GRADIENT_HEIGHT = 150
+
+ def __init__(self, palette, parent, orientation):
+ super().__init__(parent)
+
+ self.__gradient = QtGui.QLinearGradient()
+ num_colors = len(palette)
+ for idx, stop in enumerate(palette):
+ self.__gradient.setColorAt(idx * (1. / (num_colors - 1)), stop)
+
+ # We need to tell the gradient where it's start and stop points are
+ self.__gradient.setStart(QtCore.QPointF(0, 0))
+ if orientation == Qt.Vertical:
+ final_stop = QtCore.QPointF(0, self.GRADIENT_HEIGHT)
+ else:
+ final_stop = QtCore.QPointF(self.GRADIENT_HEIGHT, 0)
+ self.__gradient.setFinalStop(final_stop)
+
+ # Get the appropriate rectangle dimensions based on orientation
+ if orientation == Qt.Vertical:
+ w, h = self.GRADIENT_WIDTH, self.GRADIENT_HEIGHT
+ elif orientation == Qt.Horizontal:
+ w, h = self.GRADIENT_HEIGHT, self.GRADIENT_WIDTH
+
+ self.__rect_item = QtGui.QGraphicsRectItem(0, 0, w, h, self)
+ self.__rect_item.setPen(QtGui.QPen(QtGui.QColor(0, 0, 0, 0)))
+ self.__rect_item.setBrush(QtGui.QBrush(self.__gradient))
+
+ def sizeHint(self, size_hint, size_constraint=None, *args, **kwargs):
+ return QtCore.QSizeF(self.__rect_item.boundingRect().size())
+
+
+class ContinuousLegendItem(QtGui.QGraphicsLinearLayout):
+ """Continuous legend item.
+
+ Contains a gradient bar with the color ranges, as well as two labels - one
+ on each side of the gradient bar.
+
+ Parameters
+ ----------
+ palette : iterable[QtGui.QColor]
+ values : iterable[float...]
+ The number of values must match the number of colors in passed in the
+ color palette.
+ parent : QtGui.QGraphicsWidget
+ font : QtGui.QFont
+ orientation : Qt.Orientation
+
+ """
+
+ def __init__(self, palette, values, parent, font=None,
+ orientation=Qt.Vertical):
+ if orientation == Qt.Vertical:
+ super().__init__(Qt.Horizontal)
+ else:
+ super().__init__(Qt.Vertical)
+
+ self.__parent = parent
+ self.__palette = palette
+ self.__values = values
+
+ self.__gradient = LegendGradient(palette, parent, orientation)
+ self.__labels_layout = QtGui.QGraphicsLinearLayout(orientation)
+
+ str_vals = self._format_values(values)
+
+ self.__start_label = LegendItemTitle(str_vals[0], parent, font=font)
+ self.__end_label = LegendItemTitle(str_vals[1], parent, font=font)
+ self.__labels_layout.addItem(self.__start_label)
+ self.__labels_layout.addStretch(1)
+ self.__labels_layout.addItem(self.__end_label)
+
+ # Gradient should be to the left, then labels on the right if vertical
+ if orientation == Qt.Vertical:
+ self.addItem(self.__gradient)
+ self.addItem(self.__labels_layout)
+ # Gradient should be on the bottom, labels on top if horizontal
+ elif orientation == Qt.Horizontal:
+ self.addItem(self.__labels_layout)
+ self.addItem(self.__gradient)
+
+ @staticmethod
+ def _format_values(values):
+ """Get the formatted values to output."""
+ return ['{:.3f}'.format(v) for v in values]
+
+
+class Legend(Anchorable):
+ """Base legend class.
+
+ This class provides common attributes for any legend derivates:
+ - Behaviour on `QGraphicsScene`
+ - Appearance of legend
+
+ If you have access to the `domain` property, the `LegendBuilder` class
+ can be used to automatically build a legend for you.
+
+ Parameters
+ ----------
+ parent : QtGui.QGraphicsItem, optional
+ orientation : Qt.Orientation, optional
+ The default orientation is vertical
+ domain : Orange.data.domain.Domain, optional
+ This field is left optional as in some cases, we may want to simply
+ pass in a list that represents the legend.
+ items : Iterable[QtGui.QColor, str]
+ bg_color : QtGui.QColor, optional
+ font : QtGui.QFont, optional
+ color_indicator_cls : ColorIndicator
+ The color indicator class that will be used to render the indicators.
+
+ See Also
+ --------
+ OWDiscreteLegend
+ OWContinuousLegend
+ OWContinuousLegend
+
+ Notes
+ -----
+ .. Warning:: If the domain parameter is supplied, the items parameter will
+ be ignored.
+
+ """
+
+ def __init__(self, parent=None, orientation=Qt.Vertical, domain=None,
+ items=None, bg_color=QtGui.QColor(232, 232, 232, 196),
+ font=None, color_indicator_cls=LegendItemSquare, **kwargs):
+ super().__init__(parent, **kwargs)
+
+ self.orientation = orientation
+ self.bg_color = QtGui.QBrush(bg_color)
+ self.color_indicator_cls = color_indicator_cls
+
+ # Set default font if none is given
+ if font is None:
+ self.font = QtGui.QFont()
+ self.font.setPointSize(10)
+ else:
+ self.font = font
+
+ self.setFlags(QtGui.QGraphicsWidget.ItemIsMovable |
+ QtGui.QGraphicsItem.ItemIgnoresTransformations)
+
+ if domain is not None:
+ self.set_domain(domain)
+ elif items is not None:
+ self.set_items(items)
+
+ def _clear_layout(self):
+ self._layout = None
+ for child in self.children():
+ child.setParent(None)
+
+ def _setup_layout(self):
+ self._clear_layout()
+
+ self._layout = QtGui.QGraphicsLinearLayout(self.orientation)
+ self._layout.setContentsMargins(10, 5, 10, 5)
+ # If horizontal, there needs to be horizontal space between the items
+ if self.orientation == Qt.Horizontal:
+ self._layout.setSpacing(10)
+ # If vertical spacing, vertical space is provided by child layouts
+ else:
+ self._layout.setSpacing(0)
+ self.setLayout(self._layout)
+
+ def set_domain(self, domain):
+ """Handle receiving the domain object.
+
+ Parameters
+ ----------
+ domain : Orange.data.domain.Domain
+
+ Returns
+ -------
+
+ Raises
+ ------
+ AttributeError
+ If the domain does not contain the correct type of class variable.
+
+ """
+ raise NotImplemented()
+
+ def set_items(self, values):
+ """Handle receiving an array of items.
+
+ Parameters
+ ----------
+ values : iterable[object, QtGui.QColor]
+
+ Returns
+ -------
+
+ """
+ raise NotImplemented()
+
+ @staticmethod
+ def _convert_to_color(obj):
+ if isinstance(obj, QtGui.QColor):
+ return obj
+ elif isinstance(obj, tuple) or isinstance(obj, list):
+ assert len(obj) in (3, 4)
+ return QtGui.QColor(*obj)
+ else:
+ return QtGui.QColor(obj)
+
+ def paint(self, painter, options, widget=None):
+ painter.save()
+ pen = QtGui.QPen(QtGui.QColor(196, 197, 193, 200), 1)
+ brush = QtGui.QBrush(QtGui.QColor(self.bg_color))
+
+ painter.setPen(pen)
+ painter.setBrush(brush)
+ painter.drawRect(self.contentsRect())
+ painter.restore()
+
+
+class OWDiscreteLegend(Legend):
+ """Discrete legend.
+
+ See Also
+ --------
+ Legend
+ OWContinuousLegend
+
+ """
+
+ def set_domain(self, domain):
+ class_var = domain.class_var
+
+ if not class_var.is_discrete:
+ raise AttributeError('[OWDiscreteLegend] The class var provided '
+ 'was not discrete.')
+
+ self.set_items(zip(class_var.values, class_var.colors.tolist()))
+
+ def set_items(self, values):
+ self._setup_layout()
+ for class_name, color in values:
+ legend_item = LegendItem(
+ color=self._convert_to_color(color),
+ title=class_name,
+ parent=self,
+ color_indicator_cls=self.color_indicator_cls,
+ font=self.font
+ )
+ self._layout.addItem(legend_item)
+
+
+class OWContinuousLegend(Legend):
+ """Continuous legend.
+
+ See Also
+ --------
+ Legend
+ OWDiscreteLegend
+
+ """
+
+ def __init__(self, *args, **kwargs):
+ # Variables used in the `set_` methods must be set before calling super
+ self.__range = kwargs.get('range', ())
+
+ super().__init__(*args, **kwargs)
+
+ self._layout.setContentsMargins(10, 10, 10, 10)
+
+ def set_domain(self, domain):
+ class_var = domain.class_var
+
+ if not class_var.is_continuous:
+ raise AttributeError('[OWContinuousLegend] The class var provided '
+ 'was not continuous.')
+
+ # The first and last values must represent the range, the rest should
+ # be dummy variables, as they are not shown anywhere
+ values = self.__range
+
+ start, end, pass_through_black = class_var.colors
+ # If pass through black, push black in between and add index to vals
+ if pass_through_black:
+ colors = [self._convert_to_color(c) for c
+ in [start, '#000000', end]]
+ values.insert(1, -1)
+ else:
+ colors = [self._convert_to_color(c) for c in [start, end]]
+
+ self.set_items(list(zip(values, colors)))
+
+ def set_items(self, values):
+ vals, colors = list(zip(*values))
+
+ # If the orientation is vertical, it makes more sense for the smaller
+ # value to be shown on the bottom
+ if self.orientation == Qt.Vertical and vals[0] < vals[len(vals) - 1]:
+ colors, vals = list(reversed(colors)), list(reversed(vals))
+
+ self._setup_layout()
+ self._layout.addItem(ContinuousLegendItem(
+ palette=colors,
+ values=vals,
+ parent=self,
+ font=self.font,
+ orientation=self.orientation
+ ))
+
+
+class OWBinnedContinuousLegend(Legend):
+ def set_domain(self, domain):
+ pass
+
+ def set_items(self, values):
+ pass
diff --git a/Orange/widgets/visualize/widgetutils/common/scene.py b/Orange/widgets/visualize/widgetutils/common/scene.py
new file mode 100644
index 00000000000..ca10a1993fe
--- /dev/null
+++ b/Orange/widgets/visualize/widgetutils/common/scene.py
@@ -0,0 +1,25 @@
+from PyQt4 import QtGui
+
+
+class UpdateItemsOnSelectGraphicsScene(QtGui.QGraphicsScene):
+ """Calls the selection_changed method on items.
+
+ Whenever the scene selection changes, this view will call the
+ ˙selection_changed˙ method on any item on the scene.
+
+ Notes
+ -----
+ ..Note:: I suspect this is completely unncessary, but have not been able to
+ find a reasonable way to keep the selection logic inside the actual
+ `QGraphicsItem` objects
+
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.selectionChanged.connect(self.__handle_selection)
+
+ def __handle_selection(self):
+ for item in self.items():
+ if hasattr(item, 'selection_changed'):
+ item.selection_changed()
diff --git a/Orange/widgets/visualize/widgetutils/common/view.py b/Orange/widgets/visualize/widgetutils/common/view.py
new file mode 100644
index 00000000000..e080828f757
--- /dev/null
+++ b/Orange/widgets/visualize/widgetutils/common/view.py
@@ -0,0 +1,180 @@
+from itertools import repeat
+
+import numpy as np
+from PyQt4 import QtGui
+from PyQt4.QtCore import Qt
+
+
+class ZoomableGraphicsView(QtGui.QGraphicsView):
+ """Zoomable graphics view.
+
+ Composable graphics view that adds zoom functionality.
+
+ It also handles automatic resizing of content whenever the window is
+ resized.
+
+ Right click will reset the zoom to a factor where the entire scene is
+ visible.
+
+ Parameters
+ ----------
+ scene : QtGui.QGraphicsScene
+ padding : int or tuple, optional
+ Specify the padding around the drawn widgets. Can be an int, or tuple,
+ the tuple can contain either 2 or 4 elements.
+
+ Notes
+ -----
+ - This view will consume wheel scrolling and right mouse click events.
+
+ """
+
+ def __init__(self, scene, padding=(0, 0), **kwargs):
+ self.zoom = 1
+ self.scale_factor = 1 / 16
+ # zoomout limit prevents the zoom factor to become negative, which
+ # results in the canvas being flipped over the x axis
+ self.__zoomout_limit_reached = False
+ # Does the view need to recalculate the initial scale factor
+ self.__needs_to_recalculate_initial = True
+ self.__initial_zoom = -1
+
+ self.__central_widget = None
+ self.__set_padding(padding)
+
+ super().__init__(scene, **kwargs)
+
+ def resizeEvent(self, ev):
+ super().resizeEvent(ev)
+ self.__needs_to_recalculate_initial = True
+
+ def wheelEvent(self, ev):
+ self.__handle_zoom(ev.delta())
+ super().wheelEvent(ev)
+
+ def mousePressEvent(self, ev):
+ # right click resets the zoom factor
+ if ev.button() == Qt.RightButton:
+ self.reset_zoom()
+ super().mousePressEvent(ev)
+
+ def keyPressEvent(self, ev):
+ if ev.key() == Qt.Key_Plus:
+ self.__handle_zoom(1)
+ elif ev.key() == Qt.Key_Minus:
+ self.__handle_zoom(-1)
+
+ super().keyPressEvent(ev)
+
+ def __set_padding(self, padding):
+ # Allow for multiple formats of padding for convenience
+ if isinstance(padding, int):
+ padding = list(repeat(padding, 4))
+ elif isinstance(padding, list) or isinstance(padding, tuple):
+ if len(padding) == 2:
+ padding = (*padding, *padding)
+ else:
+ padding = 0, 0, 0, 0
+
+ l, t, r, b = padding
+ self.__padding = -l, -t, r, b
+
+ def __handle_zoom(self, direction):
+ """Handle zoom event, direction is positive if zooming in, otherwise
+ negative."""
+ if self.__zooming_in(direction):
+ self.__reset_zoomout_limit()
+ if self.__zoomout_limit_reached and self.__zooming_out(direction):
+ return
+
+ self.zoom += np.sign(direction) * self.scale_factor
+ if self.zoom <= 0:
+ self.__zoomout_limit_reached = True
+ self.zoom += self.scale_factor
+ else:
+ self.setTransformationAnchor(self.AnchorUnderMouse)
+ self.setTransform(QtGui.QTransform().scale(self.zoom, self.zoom))
+
+ @staticmethod
+ def __zooming_out(direction):
+ return direction < 0
+
+ def __zooming_in(self, ev):
+ return not self.__zooming_out(ev)
+
+ def __reset_zoomout_limit(self):
+ self.__zoomout_limit_reached = False
+
+ def set_central_widget(self, widget):
+ self.__central_widget = widget
+
+ def central_widget_rect(self):
+ """Get the bounding box of the central widget.
+
+ If a central widget and padding are set, this method calculates the
+ rect containing both of them. This is useful because if the padding was
+ added directly onto the widget, the padding would be rescaled as well.
+
+ If the central widget is not set, return the scene rect instead.
+
+ Returns
+ -------
+ QtCore.QRectF
+
+ """
+ if self.__central_widget is None:
+ return self.scene().itemsBoundingRect().adjusted(*self.__padding)
+ return self.__central_widget.boundingRect().adjusted(*self.__padding)
+
+ def recalculate_and_fit(self):
+ """Recalculate the optimal zoom and fits the content into view.
+
+ Should be called if the scene contents change, so that the optimal zoom
+ can be recalculated.
+
+ Returns
+ -------
+
+ """
+ if self.__central_widget is not None:
+ self.fitInView(self.central_widget_rect(), Qt.KeepAspectRatio)
+ else:
+ self.fitInView(self.scene().sceneRect(), Qt.KeepAspectRatio)
+
+ self.__initial_zoom = self.matrix().m11()
+ self.zoom = self.__initial_zoom
+
+ def reset_zoom(self):
+ """Reset the zoom to the optimal factor."""
+ self.zoom = self.__initial_zoom
+ self.__zoomout_limit_reached = False
+
+ if self.__needs_to_recalculate_initial:
+ self.recalculate_and_fit()
+ else:
+ self.setTransform(QtGui.QTransform().scale(self.zoom, self.zoom))
+
+
+class PannableGraphicsView(QtGui.QGraphicsView):
+ """Pannable graphics view.
+
+ Enables panning the graphics view.
+
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setDragMode(QtGui.QGraphicsView.ScrollHandDrag)
+
+ def enterEvent(self, ev):
+ self.viewport().setCursor(Qt.ArrowCursor)
+ super().enterEvent(ev)
+
+ def mouseReleaseEvent(self, ev):
+ super().mouseReleaseEvent(ev)
+ self.viewport().setCursor(Qt.ArrowCursor)
+
+
+class PreventDefaultWheelEvent(QtGui.QGraphicsView):
+ def wheelEvent(self, ev):
+ ev.accept()
diff --git a/Orange/widgets/visualize/widgetutils/tree/__init__.py b/Orange/widgets/visualize/widgetutils/tree/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py
new file mode 100644
index 00000000000..c113e096f55
--- /dev/null
+++ b/Orange/widgets/visualize/widgetutils/tree/rules.py
@@ -0,0 +1,181 @@
+class Rule:
+ """The base Rule class for tree rules."""
+
+ def merge_with(self, rule):
+ """Merge the current rule with the given rule.
+
+ Parameters
+ ----------
+ rule : Rule
+
+ Returns
+ -------
+ Rule
+
+ """
+ raise NotImplemented()
+
+
+class DiscreteRule(Rule):
+ """Discrete rule class for handling Indicator rules.
+
+ Parameters
+ ----------
+ attr_name : str
+ eq : bool
+ Should indicate whether or not the rule equals the value or not.
+ value : object
+
+ Examples
+ --------
+ Age = 30
+ >>> rule = DiscreteRule('age', True, 30)
+ Name ≠ John
+ >>> rule = DiscreteRule('name', False, 'John')
+
+ Notes
+ -----
+ - Merging discrete rules is currently not implemented, the new rule is
+ simply returned and a warning is printed to stderr.
+
+ """
+
+ def __init__(self, attr_name, eq, value):
+ self.attr_name = attr_name
+ self.sign = eq
+ self.value = value
+
+ def merge_with(self, rule):
+ # It does not make sense to merge discrete rules, since they can only
+ # be eq or not eq.
+ from sys import stderr
+ print('WARNING: Merged two discrete rules `%s` and `%s`'
+ % (self, rule), file=stderr)
+ return rule
+
+ def __str__(self):
+ return '{} {} {}'.format(
+ self.attr_name, '=' if self.sign else '≠', self.value)
+
+
+class ContinuousRule(Rule):
+ """Continuous rule class for handling numeric rules.
+
+ Parameters
+ ----------
+ attr_name : str
+ gt : bool
+ Should indicate whether the variable must be greater than the value.
+ value : int
+ inclusive : bool, optional
+ Should the variable range include the value or not
+ (LT <> LTE | GT <> GTE). Default is False.
+
+ Examples
+ --------
+ x ≤ 30
+ >>> rule = ContinuousRule('age', False, 30, inclusive=True)
+ x > 30
+ >>> rule = ContinuousRule('age', True, 30)
+
+ Notes
+ -----
+ - Continuous rules can currently only be merged with other continuous
+ rules.
+
+ """
+
+ def __init__(self, attr_name, gt, value, inclusive=False):
+ self.attr_name = attr_name
+ self.sign = gt
+ self.value = value
+ self.inclusive = inclusive
+
+ def merge_with(self, rule):
+ if not isinstance(rule, ContinuousRule):
+ raise NotImplemented('Continuous rules can currently only be '
+ 'merged with other continuous rules')
+ # Handle when both have same sign
+ if self.sign == rule.sign:
+ # When both are GT
+ if self.sign is True:
+ larger = self.value if self.value > rule.value else rule.value
+ return ContinuousRule(self.attr_name, self.sign, larger)
+ # When both are LT
+ else:
+ smaller = self.value if self.value < rule.value else rule.value
+ return ContinuousRule(self.attr_name, self.sign, smaller)
+ # When they have different signs we need to return an interval rule
+ else:
+ lt_rule = self if self.sign is False else rule
+ gt_rule = self if lt_rule != self else rule
+ return IntervalRule(self.attr_name, gt_rule, lt_rule)
+
+ def __str__(self):
+ return '%s %s %.3f' % (
+ self.attr_name, '>' if self.sign else '≤', self.value)
+
+
+class IntervalRule(Rule):
+ """Interval rule class for ranges of continuous values.
+
+ Parameters
+ ----------
+ attr_name : str
+ left_rule : ContinuousRule
+ The smaller (left) part of the interval.
+ right_rule : ContinuousRule
+ The larger (right) part of the interval.
+
+ Examples
+ --------
+ 1 ≤ x < 3
+ >>> rule = IntervalRule('Rule',
+ >>> ContinuousRule('Rule', True, 1, inclusive=True),
+ >>> ContinuousRule('Rule', False, 3))
+
+ Notes
+ -----
+ - Currently, only cases which appear in classification and regression
+ trees are implemented. An interval can not be made up of two parts
+ (e.g. (-∞, -1) ∪ (1, ∞)).
+
+ """
+
+ def __init__(self, attr_name, left_rule, right_rule):
+ if not isinstance(left_rule, ContinuousRule):
+ raise AttributeError(
+ 'The left rule must be an instance of the `ContinuousRule` '
+ 'class.')
+ if not isinstance(right_rule, ContinuousRule):
+ raise AttributeError(
+ 'The right rule must be an instance of the `ContinuousRule` '
+ 'class.')
+
+ self.attr_name = attr_name
+ self.left_rule = left_rule
+ self.right_rule = right_rule
+
+ def merge_with(self, rule):
+ if isinstance(rule, ContinuousRule):
+ if rule.sign:
+ return IntervalRule(
+ self.attr_name, self.left_rule.merge_with(rule),
+ self.right_rule)
+ else:
+ return IntervalRule(
+ self.attr_name, self.left_rule,
+ self.right_rule.merge_with(rule))
+
+ elif isinstance(rule, IntervalRule):
+ return IntervalRule(
+ self.attr_name,
+ self.left_rule.merge_with(rule.left_rule),
+ self.right_rule.merge_with(rule.right_rule))
+
+ def __str__(self):
+ return '{} ∈ {}{:.3}, {:.3}{}'.format(
+ self.attr_name,
+ '[' if self.left_rule.inclusive else '(', self.left_rule.value,
+ self.right_rule.value, ']' if self.right_rule.inclusive else ')'
+ )
diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
new file mode 100644
index 00000000000..61886eab7ea
--- /dev/null
+++ b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
@@ -0,0 +1,302 @@
+from collections import OrderedDict
+from functools import lru_cache
+
+import numpy as np
+from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter
+
+from Orange.preprocess.transformation import Indicator
+from Orange.widgets.visualize.widgetutils.tree.rules import (
+ DiscreteRule,
+ ContinuousRule
+)
+
+
+class SklTreeAdapter(TreeAdapter):
+ """Sklear Tree Adapter.
+
+ An abstraction on top of the scikit learn classification tree.
+
+ Parameters
+ ----------
+ tree : sklearn.tree._tree.Tree
+ The raw sklearn classification tree.
+ domain : Orange.base.domain
+ The Orange domain that comes with the model.
+ adjust_weight : Callable, optional
+ If you want finer control over the weights of individual nodes you can
+ pass in a function that takes the existsing weight and modifies it.
+ The given function must have signture :: Number -> Number
+
+ """
+
+ ROOT_PARENT = -1
+ NO_CHILD = -1
+ FEATURE_UNDEFINED = -2
+
+ def __init__(self, tree, domain, adjust_weight=lambda x: x):
+ self._tree = tree
+ self._domain = domain
+ self._adjust_weight = adjust_weight
+
+ # clear memoized functions
+ self.weight.cache_clear()
+ self._adjusted_child_weight.cache_clear()
+ self.parent.cache_clear()
+
+ self._all_leaves = None
+
+ @lru_cache(maxsize=1024)
+ def weight(self, node):
+ return self._adjust_weight(self.num_samples(node)) / \
+ self._adjusted_child_weight(self.parent(node))
+
+ @lru_cache(maxsize=1024)
+ def _adjusted_child_weight(self, node):
+ """Helps when dealing with adjusted weights.
+
+ It is needed when dealing with non linear weights e.g. when calculating
+ the log weight, the sum of logs of all the children will not be equal
+ to the log of all the data instances.
+ A simple example: log(2) + log(2) != log(4)
+
+ Parameters
+ ----------
+ node : int
+ The label of the node.
+
+ Returns
+ -------
+ float
+ The sum of all of the weights of the children of a given node.
+
+ """
+ return sum(self._adjust_weight(self.num_samples(c))
+ for c in self.children(node)) \
+ if self.has_children(node) else 0
+
+ def num_samples(self, node):
+ return self._tree.n_node_samples[node]
+
+ @lru_cache(maxsize=1024)
+ def parent(self, node):
+ for children in (self._tree.children_left, self._tree.children_right):
+ try:
+ return (children == node).nonzero()[0][0]
+ except IndexError:
+ continue
+ return self.ROOT_PARENT
+
+ def has_children(self, node):
+ return self._tree.children_left[node] != self.NO_CHILD \
+ or self._tree.children_right[node] != self.NO_CHILD
+
+ def children(self, node):
+ if self.has_children(node):
+ return self.__left_child(node), self.__right_child(node)
+ return ()
+
+ def __left_child(self, node):
+ return self._tree.children_left[node]
+
+ def __right_child(self, node):
+ return self._tree.children_right[node]
+
+ def get_distribution(self, node):
+ return self._tree.value[node]
+
+ def get_impurity(self, node):
+ return self._tree.impurity[node]
+
+ @property
+ def max_depth(self):
+ return self._tree.max_depth
+
+ @property
+ def num_nodes(self):
+ return self._tree.node_count
+
+ @property
+ def root(self):
+ return 0
+
+ @property
+ def domain(self):
+ return self._domain
+
+ @lru_cache(maxsize=1024)
+ def rules(self, node):
+ if node != self.root:
+ parent = self.parent(node)
+ # Convert the parent list of rules into an ordered dict
+ pr = OrderedDict([(r.attr_name, r) for r in self.rules(parent)])
+
+ parent_attr = self.attribute(parent)
+ # Get the parent attribute type
+ parent_attr_cv = parent_attr.compute_value
+
+ is_left_child = self.__left_child(parent) == node
+
+ # The parent split variable is discrete
+ if isinstance(parent_attr_cv, Indicator) and \
+ hasattr(parent_attr_cv.variable, 'values'):
+ values = parent_attr_cv.variable.values
+ attr_name = parent_attr_cv.variable.name
+ eq = not is_left_child * (len(values) != 2)
+ value = values[abs(parent_attr_cv.value -
+ is_left_child * (len(values) == 2))]
+ new_rule = DiscreteRule(attr_name, eq, value)
+ # Since discrete variables should appear in their own lines
+ # they must not be merged, so the dict key is set with the
+ # value, so the same keys can exist with different values
+ # e.g. #legs ≠ 2 and #legs ≠ 4
+ attr_name = attr_name + '_' + value
+ # The parent split variable is continuous
+ else:
+ attr_name = parent_attr.name
+ sign = not is_left_child
+ value = self._tree.threshold[self.parent(node)]
+ new_rule = ContinuousRule(attr_name, sign, value,
+ inclusive=is_left_child)
+
+ # Check if a rule with that attribute exists
+ if attr_name in pr:
+ pr[attr_name] = pr[attr_name].merge_with(new_rule)
+ pr.move_to_end(attr_name)
+ else:
+ pr[attr_name] = new_rule
+
+ return list(pr.values())
+ else:
+ return []
+
+ def attribute(self, node):
+ feature_idx = self.splitting_attribute(node)
+ if feature_idx != self.FEATURE_UNDEFINED:
+ return self.domain.attributes[self.splitting_attribute(node)]
+
+ def splitting_attribute(self, node):
+ return self._tree.feature[node]
+
+ @lru_cache(maxsize=1024)
+ def leaves(self, node):
+ start, stop = self._subnode_range(node)
+ if start == stop:
+ # leaf
+ return np.array([node], dtype=int)
+ else:
+ is_leaf = self._tree.children_left[start:stop] == self.NO_CHILD
+ assert np.flatnonzero(is_leaf).size > 0
+ return start + np.flatnonzero(is_leaf)
+
+ def _subnode_range(self, node):
+ """
+ Get the range of indices where there are subnodes of the given node.
+
+ See Also
+ --------
+ Orange.widgets.classify.owclassificationtreegraph.OWTreeGraph
+ """
+
+ def find_largest_idx(n):
+ """It is necessary to locate the node with the largest index in the
+ children in order to get a good range. This is necessary with trees
+ that are not right aligned, which can happen when visualising
+ random forest trees."""
+ if self._tree.children_left[n] == self.NO_CHILD:
+ return n
+
+ l_node = find_largest_idx(self._tree.children_left[n])
+ r_node = find_largest_idx(self._tree.children_right[n])
+
+ return l_node if l_node > r_node else r_node
+
+ right = left = node
+ if self._tree.children_left[left] == self.NO_CHILD:
+ assert self._tree.children_right[node] == self.NO_CHILD
+ return node, node
+ else:
+ left = self._tree.children_left[left]
+ right = find_largest_idx(right)
+
+ return left, right + 1
+
+ def get_samples_in_leaves(self, data):
+ """Get an array of instance indices that belong to each leaf.
+
+ For a given dataset X, separate the instances out into an array, so
+ they are grouped together based on what leaf they belong to.
+
+ Examples
+ --------
+ Given a tree with two leaf nodes ( A <- R -> B ) and the dataset X =
+ [ 10, 20, 30, 40, 50, 60 ], where 10, 20 and 40 belong to leaf A, and
+ the rest to leaf B, the following structure will be returned (where
+ array is the numpy array):
+ [array([ 0, 1, 3 ]), array([ 2, 4, 5 ])]
+
+ The first array represents the indices of the values that belong to the
+ first leaft, so calling X[ 0, 1, 3 ] = [ 10, 20, 40 ]
+
+ Parameters
+ ----------
+ data
+ A matrix containing the data instances.
+
+ Returns
+ -------
+ np.array
+ The indices of instances belonging to a given leaf.
+
+ """
+
+ def assign(node_id, indices):
+ if self._tree.children_left[node_id] == self.NO_CHILD:
+ return [indices]
+ else:
+ feature_idx = self._tree.feature[node_id]
+ thresh = self._tree.threshold[node_id]
+
+ column = data[indices, feature_idx]
+ leftmask = column <= thresh
+ leftind = assign(self._tree.children_left[node_id],
+ indices[leftmask])
+ rightind = assign(self._tree.children_right[node_id],
+ indices[~leftmask])
+ return list.__iadd__(leftind, rightind)
+
+ # TODO this kind of cache can lead to all sorts of problems, but numpy
+ # arrays are unhashable, and this gives huge performance boosts
+ # also this would only become a problem if the function required to
+ # handle multiple datasets, which it doesn't, it just deals with the
+ # one the classification tree was fit to.
+ if self._all_leaves is not None:
+ return self._all_leaves
+
+ n, _ = data.shape
+
+ items = np.arange(n, dtype=int)
+ leaf_indices = assign(0, items)
+ self._all_leaves = leaf_indices
+ return leaf_indices
+
+ def get_instances_in_nodes(self, dataset, nodes):
+ if not isinstance(nodes, (list, tuple)):
+ nodes = [nodes]
+
+ node_leaves = [self.leaves(n.label) for n in nodes]
+ if len(node_leaves) > 0:
+ # get the leaves of the selected tree node
+ node_leaves = np.unique(np.hstack(node_leaves))
+
+ all_leaves = self.leaves(self.root)
+
+ indices = np.searchsorted(all_leaves, node_leaves, side='left')
+ # all the leaf samples for each leaf
+ leaf_samples = self.get_samples_in_leaves(dataset.X)
+ # filter out the leaf samples array that are not selected
+ leaf_samples = [leaf_samples[i] for i in indices]
+ indices = np.hstack(leaf_samples)
+ else:
+ indices = []
+
+ return dataset[indices] if len(indices) else None
diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/__init__.py b/Orange/widgets/visualize/widgetutils/tree/tests/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
new file mode 100644
index 00000000000..a6a4a86e0e4
--- /dev/null
+++ b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
@@ -0,0 +1,157 @@
+import unittest
+
+from Orange.widgets.visualize.widgetutils.tree.rules import (
+ ContinuousRule,
+ IntervalRule,
+)
+
+
+class TestRules(unittest.TestCase):
+ # CONTINUOUS RULES
+ def test_merging_two_gt_continuous_rules(self):
+ """Merging `x > 1` and `x > 2` should produce `x > 2`."""
+ rule1 = ContinuousRule('Rule', True, 1)
+ rule2 = ContinuousRule('Rule', True, 2)
+ new_rule = rule1.merge_with(rule2)
+ self.assertEqual(new_rule.value, 2)
+
+ def test_merging_gt_with_gte_continuous_rule(self):
+ """Merging `x > 1` and `x ≥ 1` should produce `x > 1`."""
+ rule1 = ContinuousRule('Rule', True, 1, inclusive=True)
+ rule2 = ContinuousRule('Rule', True, 1, inclusive=False)
+ new_rule = rule1.merge_with(rule2)
+ self.assertEqual(new_rule.inclusive, False)
+
+ def test_merging_two_lt_continuous_rules(self):
+ """Merging `x < 1` and `x < 2` should produce `x < 1`."""
+ rule1 = ContinuousRule('Rule', False, 1)
+ rule2 = ContinuousRule('Rule', False, 2)
+ new_rule = rule1.merge_with(rule2)
+ self.assertEqual(new_rule.value, 1)
+
+ def test_merging_lt_with_lte_rule(self):
+ """Merging `x < 1` and `x ≤ 1` should produce `x < 1`."""
+ rule1 = ContinuousRule('Rule', False, 1, inclusive=True)
+ rule2 = ContinuousRule('Rule', False, 1, inclusive=False)
+ new_rule = rule1.merge_with(rule2)
+ self.assertEqual(new_rule.inclusive, False)
+
+ def test_merging_lt_with_gt_continuous_rules(self):
+ """Merging `x > 1` and `x < 2` should produce `1 < x < 2`."""
+ rule1 = ContinuousRule('Rule', True, 1)
+ rule2 = ContinuousRule('Rule', False, 2)
+ new_rule = rule1.merge_with(rule2)
+ self.assertIsInstance(new_rule, IntervalRule)
+ self.assertEquals(new_rule.left_rule, rule1)
+ self.assertEquals(new_rule.right_rule, rule2)
+
+ # INTERVAL RULES
+ def test_merging_interval_rule_with_smaller_continuous_rule(self):
+ """Merging `1 < x < 2` and `x < 3` should produce `1 < x < 2`."""
+ rule1 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 1),
+ ContinuousRule('Rule', False, 2))
+ rule2 = ContinuousRule('Rule', False, 2)
+ new_rule = rule1.merge_with(rule2)
+ self.assertIsInstance(new_rule, IntervalRule)
+ self.assertEquals(new_rule.right_rule.value, 2)
+
+ def test_merging_interval_rule_with_larger_continuous_rule(self):
+ """Merging `1 < x < 2` and `x < 3` should produce `1 < x < 2`."""
+ rule1 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 1),
+ ContinuousRule('Rule', False, 2))
+ rule2 = ContinuousRule('Rule', False, 3)
+ new_rule = rule1.merge_with(rule2)
+ self.assertIsInstance(new_rule, IntervalRule)
+ self.assertEquals(new_rule.left_rule.value, 1)
+
+ def test_merging_interval_rule_with_larger_lt_continuous_rule(self):
+ """Merging `0 < x < 3` and `x > 1` should produce `1 < x < 3`."""
+ rule1 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 0),
+ ContinuousRule('Rule', False, 3))
+ rule2 = ContinuousRule('Rule', True, 1)
+ new_rule = rule1.merge_with(rule2)
+ self.assertIsInstance(new_rule, IntervalRule)
+ self.assertEquals(new_rule.left_rule.value, 1)
+
+ def test_merging_interval_rule_with_smaller_gt_continuous_rule(self):
+ """Merging `0 < x < 3` and `x < 2` should produce `0 < x < 2`."""
+ rule1 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 0),
+ ContinuousRule('Rule', False, 3))
+ rule2 = ContinuousRule('Rule', False, 2)
+ new_rule = rule1.merge_with(rule2)
+ self.assertIsInstance(new_rule, IntervalRule)
+ self.assertEquals(new_rule.right_rule.value, 2)
+
+ def test_merging_interval_rules_with_smaller_lt_component(self):
+ """Merging `1 < x < 2` and `0 < x < 2` should produce `1 < x < 2`."""
+ rule1 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 1),
+ ContinuousRule('Rule', False, 2))
+ rule2 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 0),
+ ContinuousRule('Rule', False, 2))
+ new_rule = rule1.merge_with(rule2)
+ self.assertEquals(new_rule.left_rule.value, 1)
+ self.assertEquals(new_rule.right_rule.value, 2)
+
+ def test_merging_interval_rules_with_larger_lt_component(self):
+ """Merging `0 < x < 4` and `1 < x < 4` should produce `1 < x < 4`."""
+ rule1 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 0),
+ ContinuousRule('Rule', False, 4))
+ rule2 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 1),
+ ContinuousRule('Rule', False, 4))
+ new_rule = rule1.merge_with(rule2)
+ self.assertEquals(new_rule.left_rule.value, 1)
+ self.assertEquals(new_rule.right_rule.value, 4)
+
+ def test_merging_interval_rules_generally(self):
+ """Merging `0 < x < 4` and `2 < x < 6` should produce `2 < x < 4`."""
+ rule1 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 0),
+ ContinuousRule('Rule', False, 4))
+ rule2 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 2),
+ ContinuousRule('Rule', False, 6))
+ new_rule = rule1.merge_with(rule2)
+ self.assertEquals(new_rule.left_rule.value, 2)
+ self.assertEquals(new_rule.right_rule.value, 4)
+
+ # ALL RULES
+ def test_merge_commutativity_on_continuous_rules(self):
+ rule1 = ContinuousRule('Rule1', True, 1)
+ rule2 = ContinuousRule('Rule1', True, 2)
+ new_rule1 = rule1.merge_with(rule2)
+ new_rule2 = rule2.merge_with(rule1)
+ self.assertEqual(new_rule1.value, new_rule2.value)
+
+ def test_merge_commutativity_on_interval_rules(self):
+ rule1 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 0),
+ ContinuousRule('Rule', False, 4))
+ rule2 = IntervalRule('Rule',
+ ContinuousRule('Rule', True, 2),
+ ContinuousRule('Rule', False, 6))
+ new_rule1 = rule1.merge_with(rule2)
+ new_rule2 = rule2.merge_with(rule1)
+ self.assertEquals(new_rule1.left_rule.value,
+ new_rule2.left_rule.value)
+ self.assertEquals(new_rule1.right_rule.value,
+ new_rule2.right_rule.value)
+
+ def test_merge_keeps_sign_on_continuous_rules(self):
+ rule1 = ContinuousRule('Rule1', True, 1)
+ rule2 = ContinuousRule('Rule1', True, 2)
+ new_rule = rule1.merge_with(rule2)
+ self.assertEquals(new_rule.sign, True)
+
+ def test_merge_keeps_attr_name_on_continuous_rules(self):
+ rule1 = ContinuousRule('Rule1', True, 1)
+ rule2 = ContinuousRule('Rule1', True, 2)
+ new_rule = rule1.merge_with(rule2)
+ self.assertEquals(new_rule.attr_name, 'Rule1')
diff --git a/Orange/widgets/visualize/widgetutils/tree/treeadapter.py b/Orange/widgets/visualize/widgetutils/tree/treeadapter.py
new file mode 100644
index 00000000000..7bf07368c5f
--- /dev/null
+++ b/Orange/widgets/visualize/widgetutils/tree/treeadapter.py
@@ -0,0 +1,274 @@
+from abc import ABCMeta, abstractmethod
+
+
+class TreeAdapter(metaclass=ABCMeta):
+ """Base class for tree representation.
+
+ Any subclass should implement the methods listed in this base class. Note
+ that some simple methods do not need to reimplemented e.g. is_leaf since
+ it that is the opposite of has_children.
+
+ """
+
+ ROOT_PARENT = -1
+ NO_CHILD = -1
+ FEATURE_UNDEFINED = -2
+
+ @abstractmethod
+ def weight(self, node):
+ """Get the weight of the given node.
+
+ The weights of the children always sum up to 1.
+
+ Parameters
+ ----------
+ node : object
+ The label of the node.
+
+ Returns
+ -------
+ float
+ The weight of the node relative to its siblings.
+
+ """
+ pass
+
+ @abstractmethod
+ def num_samples(self, node):
+ """Get the number of samples that a given node contains.
+
+ Parameters
+ ----------
+ node : object
+ A unique identifier of a node.
+
+ Returns
+ -------
+ int
+
+ """
+ pass
+
+ @abstractmethod
+ def parent(self, node):
+ """Get the parent of a given node. Return -1 if the node is the root.
+
+ Parameters
+ ----------
+ node : object
+
+ Returns
+ -------
+ object
+
+ """
+ pass
+
+ @abstractmethod
+ def has_children(self, node):
+ """Check if the given node has any children.
+
+ Parameters
+ ----------
+ node : object
+
+ Returns
+ -------
+ bool
+
+ """
+ pass
+
+ def is_leaf(self, node):
+ """Check if the given node is a leaf node.
+
+ Parameters
+ ----------
+ node : object
+
+ Returns
+ -------
+ object
+
+ """
+ return not self.has_children(node)
+
+ @abstractmethod
+ def children(self, node):
+ """Get all the children of a given node.
+
+ Parameters
+ ----------
+ node : object
+
+ Returns
+ -------
+ Iterable[object]
+ A iterable object containing the labels of the child nodes.
+
+ """
+ pass
+
+ @abstractmethod
+ def get_distribution(self, node):
+ """Get the distribution of types for a given node.
+
+ This may be the number of nodes that belong to each different classe in
+ a node.
+
+ Parameters
+ ----------
+ node : object
+
+ Returns
+ -------
+ Iterable[int, ...]
+ The return type is an iterable with as many fields as there are
+ different classes in the given node. The values of the fields are
+ the number of nodes that belong to a given class inside the node.
+
+ """
+ pass
+
+ @abstractmethod
+ def get_impurity(self, node):
+ """Get the impurity of a given node.
+
+ Parameters
+ ----------
+ node : object
+
+ Returns
+ -------
+ object
+
+ """
+ pass
+
+ @abstractmethod
+ def rules(self, node):
+ """Get a list of rules that define the given node.
+
+ Parameters
+ ----------
+ node : object
+
+ Returns
+ -------
+ Iterable[Rule]
+ A list of Rule objects, can be of any type.
+
+ """
+ pass
+
+ @abstractmethod
+ def attribute(self, node):
+ """Get the attribute that splits the given tree.
+
+ Parameters
+ ----------
+ node
+
+ Returns
+ -------
+
+ """
+ pass
+
+ def is_root(self, node):
+ """Check if a given node is the root node.
+
+ Parameters
+ ----------
+ node
+
+ Returns
+ -------
+
+ """
+ return node == self.root
+
+ @abstractmethod
+ def leaves(self, node):
+ """Get all the leavse that belong to the subtree of a given node.
+
+ Parameters
+ ----------
+ node
+
+ Returns
+ -------
+
+ """
+ pass
+
+ @abstractmethod
+ def get_instances_in_nodes(self, dataset, nodes):
+ """Get all the instances belonging to a set of nodes for a given
+ dataset.
+
+ Parameters
+ ----------
+ dataset : Table
+ A Orange Table dataset.
+ nodes : iterable[TreeNode]
+ A list of tree nodes for which we want the instances.
+
+ Returns
+ -------
+
+ """
+ pass
+
+ @property
+ @abstractmethod
+ def max_depth(self):
+ """Get the maximum depth that the tree reaches.
+
+ Returns
+ -------
+ int
+
+ """
+ pass
+
+ @property
+ @abstractmethod
+ def num_nodes(self):
+ """Get the total number of nodes that the tree contains.
+
+ This does not mean the number of samples inside the entire tree, just
+ the number of nodes.
+
+ Returns
+ -------
+ int
+
+ """
+ pass
+
+ @property
+ @abstractmethod
+ def root(self):
+ """Get the label of the root node.
+
+ Returns
+ -------
+ object
+
+ """
+ pass
+
+ @property
+ @abstractmethod
+ def domain(self):
+ """Get the domain of the given tree.
+
+ The domain contains information about the classes what the tree
+ represents.
+
+ Returns
+ -------
+
+ """
+ pass
From df43036337d7a8ee6265bc8fb523f9f684aa5124 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Tue, 12 Jul 2016 10:31:30 +0200
Subject: [PATCH 02/12] Pythagorean tree: Add memoize_method to avoid memory
leaks
---
Orange/misc/cache.py | 41 +++++++++++++++++++
.../widgetutils/tree/skltreeadapter.py | 18 ++++----
2 files changed, 48 insertions(+), 11 deletions(-)
diff --git a/Orange/misc/cache.py b/Orange/misc/cache.py
index bf5f295b7d9..2dc9f39751d 100644
--- a/Orange/misc/cache.py
+++ b/Orange/misc/cache.py
@@ -1,3 +1,7 @@
+from functools import wraps, lru_cache
+import weakref
+
+
def single_cache(f):
last_args = ()
last_kwargs = set()
@@ -14,3 +18,40 @@ def cached(*args, **kwargs):
return last_result
return cached
+
+
+def memoize_method(*lru_args, **lru_kwargs):
+ """Memoize methods without keeping reference to `self`.
+
+ Parameters
+ ----------
+ lru_args
+ lru_kwargs
+
+ Returns
+ -------
+
+ See Also
+ --------
+ https://stackoverflow.com/questions/33672412/python-functools-lru-cache-with-class-methods-release-object
+
+ """
+ def _decorator(func):
+
+ @wraps(func)
+ def _wrapped_func(self, *args, **kwargs):
+ self_weak = weakref.ref(self)
+ # We're storing the wrapped method inside the instance. If we had
+ # a strong reference to self the instance would never die.
+
+ @wraps(func)
+ @lru_cache(*lru_args, **lru_kwargs)
+ def _cached_method(*args, **kwargs):
+ return func(self_weak(), *args, **kwargs)
+
+ setattr(self, func.__name__, _cached_method)
+ return _cached_method(*args, **kwargs)
+
+ return _wrapped_func
+
+ return _decorator
diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
index 61886eab7ea..307fad83d1c 100644
--- a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
+++ b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
@@ -1,7 +1,8 @@
from collections import OrderedDict
-from functools import lru_cache
import numpy as np
+
+from Orange.misc.cache import memoize_method
from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter
from Orange.preprocess.transformation import Indicator
@@ -38,19 +39,14 @@ def __init__(self, tree, domain, adjust_weight=lambda x: x):
self._domain = domain
self._adjust_weight = adjust_weight
- # clear memoized functions
- self.weight.cache_clear()
- self._adjusted_child_weight.cache_clear()
- self.parent.cache_clear()
-
self._all_leaves = None
- @lru_cache(maxsize=1024)
+ @memoize_method(maxsize=1024)
def weight(self, node):
return self._adjust_weight(self.num_samples(node)) / \
self._adjusted_child_weight(self.parent(node))
- @lru_cache(maxsize=1024)
+ @memoize_method(maxsize=1024)
def _adjusted_child_weight(self, node):
"""Helps when dealing with adjusted weights.
@@ -77,7 +73,7 @@ def _adjusted_child_weight(self, node):
def num_samples(self, node):
return self._tree.n_node_samples[node]
- @lru_cache(maxsize=1024)
+ @memoize_method(maxsize=1024)
def parent(self, node):
for children in (self._tree.children_left, self._tree.children_right):
try:
@@ -123,7 +119,7 @@ def root(self):
def domain(self):
return self._domain
- @lru_cache(maxsize=1024)
+ @memoize_method(maxsize=1024)
def rules(self, node):
if node != self.root:
parent = self.parent(node)
@@ -177,7 +173,7 @@ def attribute(self, node):
def splitting_attribute(self, node):
return self._tree.feature[node]
- @lru_cache(maxsize=1024)
+ @memoize_method(maxsize=1024)
def leaves(self, node):
start, stop = self._subnode_range(node)
if start == stop:
From d860fcd4142b8bd5867043abf4eccb077e17cac5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Tue, 12 Jul 2016 10:39:19 +0200
Subject: [PATCH 03/12] Pythagorean tree: Make pylint happier
---
Orange/widgets/visualize/owpythagorastree.py | 6 +--
.../widgets/visualize/pythagorastreeviewer.py | 4 +-
.../visualize/widgetutils/common/view.py | 44 ++++++++--------
.../visualize/widgetutils/tree/rules.py | 52 ++++++++++---------
.../widgetutils/tree/skltreeadapter.py | 4 +-
.../widgetutils/tree/tests/test_rules.py | 38 +++++++-------
6 files changed, 74 insertions(+), 74 deletions(-)
diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py
index 2ce4972e22a..4da683dcdca 100644
--- a/Orange/widgets/visualize/owpythagorastree.py
+++ b/Orange/widgets/visualize/owpythagorastree.py
@@ -294,10 +294,8 @@ def _update_legend_visibility(self):
def _update_log_scale_slider(self):
"""On calc method combo box changed."""
- if self.SIZE_CALCULATION[self.size_calc_idx][0] == 'Logarithmic':
- self.log_scale_box.parent().setEnabled(True)
- else:
- self.log_scale_box.parent().setEnabled(False)
+ self.log_scale_box.parent().setEnabled(
+ self.SIZE_CALCULATION[self.size_calc_idx][0] == 'Logarithmic')
# MODEL REMOVED CONTROL ELEMENTS CLEAR METHODS
def _clear_info_box(self):
diff --git a/Orange/widgets/visualize/pythagorastreeviewer.py b/Orange/widgets/visualize/pythagorastreeviewer.py
index 42e79119ac2..6fa6ef0f2d8 100644
--- a/Orange/widgets/visualize/pythagorastreeviewer.py
+++ b/Orange/widgets/visualize/pythagorastreeviewer.py
@@ -442,7 +442,7 @@ def __init__(self, tree_node, parent=None, **kwargs):
self.timer.setSingleShot(True)
- def hoverEnterEvent(self, ev):
+ def hoverEnterEvent(self, event):
self.timer.stop()
def fnc(graphics_item):
@@ -466,7 +466,7 @@ def other_fnc(graphics_item):
self._propagate_z_values(self, fnc, other_fnc)
- def hoverLeaveEvent(self, ev):
+ def hoverLeaveEvent(self, event):
def fnc(graphics_item):
# No need to set opacity in this branch since it was just selected
diff --git a/Orange/widgets/visualize/widgetutils/common/view.py b/Orange/widgets/visualize/widgetutils/common/view.py
index e080828f757..1803c9512f2 100644
--- a/Orange/widgets/visualize/widgetutils/common/view.py
+++ b/Orange/widgets/visualize/widgetutils/common/view.py
@@ -44,35 +44,35 @@ def __init__(self, scene, padding=(0, 0), **kwargs):
super().__init__(scene, **kwargs)
- def resizeEvent(self, ev):
- super().resizeEvent(ev)
+ def resizeEvent(self, event):
+ super().resizeEvent(event)
self.__needs_to_recalculate_initial = True
- def wheelEvent(self, ev):
- self.__handle_zoom(ev.delta())
- super().wheelEvent(ev)
+ def wheelEvent(self, event):
+ self.__handle_zoom(event.delta())
+ super().wheelEvent(event)
- def mousePressEvent(self, ev):
+ def mousePressEvent(self, event):
# right click resets the zoom factor
- if ev.button() == Qt.RightButton:
+ if event.button() == Qt.RightButton:
self.reset_zoom()
- super().mousePressEvent(ev)
+ super().mousePressEvent(event)
- def keyPressEvent(self, ev):
- if ev.key() == Qt.Key_Plus:
+ def keyPressEvent(self, event):
+ if event.key() == Qt.Key_Plus:
self.__handle_zoom(1)
- elif ev.key() == Qt.Key_Minus:
+ elif event.key() == Qt.Key_Minus:
self.__handle_zoom(-1)
- super().keyPressEvent(ev)
+ super().keyPressEvent(event)
def __set_padding(self, padding):
# Allow for multiple formats of padding for convenience
if isinstance(padding, int):
- padding = list(repeat(padding, 4))
+ padding = tuple(repeat(padding, 4))
elif isinstance(padding, list) or isinstance(padding, tuple):
if len(padding) == 2:
- padding = (*padding, *padding)
+ padding = tuple(padding * 2)
else:
padding = 0, 0, 0, 0
@@ -99,8 +99,8 @@ def __handle_zoom(self, direction):
def __zooming_out(direction):
return direction < 0
- def __zooming_in(self, ev):
- return not self.__zooming_out(ev)
+ def __zooming_in(self, event):
+ return not self.__zooming_out(event)
def __reset_zoomout_limit(self):
self.__zoomout_limit_reached = False
@@ -166,15 +166,15 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setDragMode(QtGui.QGraphicsView.ScrollHandDrag)
- def enterEvent(self, ev):
+ def enterEvent(self, event):
self.viewport().setCursor(Qt.ArrowCursor)
- super().enterEvent(ev)
+ super().enterEvent(event)
- def mouseReleaseEvent(self, ev):
- super().mouseReleaseEvent(ev)
+ def mouseReleaseEvent(self, event):
+ super().mouseReleaseEvent(event)
self.viewport().setCursor(Qt.ArrowCursor)
class PreventDefaultWheelEvent(QtGui.QGraphicsView):
- def wheelEvent(self, ev):
- ev.accept()
+ def wheelEvent(self, event):
+ event.accept()
diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py
index c113e096f55..587226a23d1 100644
--- a/Orange/widgets/visualize/widgetutils/tree/rules.py
+++ b/Orange/widgets/visualize/widgetutils/tree/rules.py
@@ -1,3 +1,6 @@
+import warnings
+
+
class Rule:
"""The base Rule class for tree rules."""
@@ -13,7 +16,7 @@ def merge_with(self, rule):
Rule
"""
- raise NotImplemented()
+ raise NotImplementedError()
class DiscreteRule(Rule):
@@ -30,32 +33,31 @@ class DiscreteRule(Rule):
--------
Age = 30
>>> rule = DiscreteRule('age', True, 30)
+
Name ≠ John
>>> rule = DiscreteRule('name', False, 'John')
Notes
-----
- - Merging discrete rules is currently not implemented, the new rule is
- simply returned and a warning is printed to stderr.
+ .. Note:: Merging discrete rules is currently not implemented, the new rule
+ is simply returned and a warning is issued.
"""
def __init__(self, attr_name, eq, value):
self.attr_name = attr_name
- self.sign = eq
+ self.eq = eq
self.value = value
def merge_with(self, rule):
# It does not make sense to merge discrete rules, since they can only
# be eq or not eq.
- from sys import stderr
- print('WARNING: Merged two discrete rules `%s` and `%s`'
- % (self, rule), file=stderr)
+ warnings.warn('Merged two discrete rules `%s` and `%s`' % (self, rule))
return rule
def __str__(self):
return '{} {} {}'.format(
- self.attr_name, '=' if self.sign else '≠', self.value)
+ self.attr_name, '=' if self.eq else '≠', self.value)
class ContinuousRule(Rule):
@@ -75,45 +77,45 @@ class ContinuousRule(Rule):
--------
x ≤ 30
>>> rule = ContinuousRule('age', False, 30, inclusive=True)
+
x > 30
>>> rule = ContinuousRule('age', True, 30)
Notes
-----
- - Continuous rules can currently only be merged with other continuous
- rules.
+ .. Note:: Continuous rules can currently only be merged with other
+ continuous rules.
"""
def __init__(self, attr_name, gt, value, inclusive=False):
self.attr_name = attr_name
- self.sign = gt
+ self.gt = gt
self.value = value
self.inclusive = inclusive
def merge_with(self, rule):
if not isinstance(rule, ContinuousRule):
- raise NotImplemented('Continuous rules can currently only be '
- 'merged with other continuous rules')
+ raise NotImplementedError('Continuous rules can currently only be '
+ 'merged with other continuous rules')
# Handle when both have same sign
- if self.sign == rule.sign:
+ if self.gt == rule.gt:
# When both are GT
- if self.sign is True:
- larger = self.value if self.value > rule.value else rule.value
- return ContinuousRule(self.attr_name, self.sign, larger)
+ if self.gt is True:
+ larger = max(self.value, rule.value)
+ return ContinuousRule(self.attr_name, self.gt, larger)
# When both are LT
else:
smaller = self.value if self.value < rule.value else rule.value
- return ContinuousRule(self.attr_name, self.sign, smaller)
+ return ContinuousRule(self.attr_name, self.gt, smaller)
# When they have different signs we need to return an interval rule
else:
- lt_rule = self if self.sign is False else rule
- gt_rule = self if lt_rule != self else rule
+ lt_rule, gt_rule = (rule, self) if self.gt else (self, rule)
return IntervalRule(self.attr_name, gt_rule, lt_rule)
def __str__(self):
return '%s %s %.3f' % (
- self.attr_name, '>' if self.sign else '≤', self.value)
+ self.attr_name, '>' if self.gt else '≤', self.value)
class IntervalRule(Rule):
@@ -136,9 +138,9 @@ class IntervalRule(Rule):
Notes
-----
- - Currently, only cases which appear in classification and regression
- trees are implemented. An interval can not be made up of two parts
- (e.g. (-∞, -1) ∪ (1, ∞)).
+ .. Note:: Currently, only cases which appear in classification and
+ regression trees are implemented. An interval can not be made up of two
+ parts (e.g. (-∞, -1) ∪ (1, ∞)).
"""
@@ -158,7 +160,7 @@ def __init__(self, attr_name, left_rule, right_rule):
def merge_with(self, rule):
if isinstance(rule, ContinuousRule):
- if rule.sign:
+ if rule.gt:
return IntervalRule(
self.attr_name, self.left_rule.merge_with(rule),
self.right_rule)
diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
index 307fad83d1c..759542cea18 100644
--- a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
+++ b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
@@ -204,7 +204,7 @@ def find_largest_idx(n):
l_node = find_largest_idx(self._tree.children_left[n])
r_node = find_largest_idx(self._tree.children_right[n])
- return l_node if l_node > r_node else r_node
+ return max(l_node, r_node)
right = left = node
if self._tree.children_left[left] == self.NO_CHILD:
@@ -286,7 +286,7 @@ def get_instances_in_nodes(self, dataset, nodes):
all_leaves = self.leaves(self.root)
- indices = np.searchsorted(all_leaves, node_leaves, side='left')
+ indices = np.searchsorted(all_leaves, node_leaves)
# all the leaf samples for each leaf
leaf_samples = self.get_samples_in_leaves(dataset.X)
# filter out the leaf samples array that are not selected
diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
index a6a4a86e0e4..0a8633b26f6 100644
--- a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
+++ b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
@@ -42,8 +42,8 @@ def test_merging_lt_with_gt_continuous_rules(self):
rule2 = ContinuousRule('Rule', False, 2)
new_rule = rule1.merge_with(rule2)
self.assertIsInstance(new_rule, IntervalRule)
- self.assertEquals(new_rule.left_rule, rule1)
- self.assertEquals(new_rule.right_rule, rule2)
+ self.assertEqual(new_rule.left_rule, rule1)
+ self.assertEqual(new_rule.right_rule, rule2)
# INTERVAL RULES
def test_merging_interval_rule_with_smaller_continuous_rule(self):
@@ -54,7 +54,7 @@ def test_merging_interval_rule_with_smaller_continuous_rule(self):
rule2 = ContinuousRule('Rule', False, 2)
new_rule = rule1.merge_with(rule2)
self.assertIsInstance(new_rule, IntervalRule)
- self.assertEquals(new_rule.right_rule.value, 2)
+ self.assertEqual(new_rule.right_rule.value, 2)
def test_merging_interval_rule_with_larger_continuous_rule(self):
"""Merging `1 < x < 2` and `x < 3` should produce `1 < x < 2`."""
@@ -64,7 +64,7 @@ def test_merging_interval_rule_with_larger_continuous_rule(self):
rule2 = ContinuousRule('Rule', False, 3)
new_rule = rule1.merge_with(rule2)
self.assertIsInstance(new_rule, IntervalRule)
- self.assertEquals(new_rule.left_rule.value, 1)
+ self.assertEqual(new_rule.left_rule.value, 1)
def test_merging_interval_rule_with_larger_lt_continuous_rule(self):
"""Merging `0 < x < 3` and `x > 1` should produce `1 < x < 3`."""
@@ -74,7 +74,7 @@ def test_merging_interval_rule_with_larger_lt_continuous_rule(self):
rule2 = ContinuousRule('Rule', True, 1)
new_rule = rule1.merge_with(rule2)
self.assertIsInstance(new_rule, IntervalRule)
- self.assertEquals(new_rule.left_rule.value, 1)
+ self.assertEqual(new_rule.left_rule.value, 1)
def test_merging_interval_rule_with_smaller_gt_continuous_rule(self):
"""Merging `0 < x < 3` and `x < 2` should produce `0 < x < 2`."""
@@ -84,7 +84,7 @@ def test_merging_interval_rule_with_smaller_gt_continuous_rule(self):
rule2 = ContinuousRule('Rule', False, 2)
new_rule = rule1.merge_with(rule2)
self.assertIsInstance(new_rule, IntervalRule)
- self.assertEquals(new_rule.right_rule.value, 2)
+ self.assertEqual(new_rule.right_rule.value, 2)
def test_merging_interval_rules_with_smaller_lt_component(self):
"""Merging `1 < x < 2` and `0 < x < 2` should produce `1 < x < 2`."""
@@ -95,8 +95,8 @@ def test_merging_interval_rules_with_smaller_lt_component(self):
ContinuousRule('Rule', True, 0),
ContinuousRule('Rule', False, 2))
new_rule = rule1.merge_with(rule2)
- self.assertEquals(new_rule.left_rule.value, 1)
- self.assertEquals(new_rule.right_rule.value, 2)
+ self.assertEqual(new_rule.left_rule.value, 1)
+ self.assertEqual(new_rule.right_rule.value, 2)
def test_merging_interval_rules_with_larger_lt_component(self):
"""Merging `0 < x < 4` and `1 < x < 4` should produce `1 < x < 4`."""
@@ -107,8 +107,8 @@ def test_merging_interval_rules_with_larger_lt_component(self):
ContinuousRule('Rule', True, 1),
ContinuousRule('Rule', False, 4))
new_rule = rule1.merge_with(rule2)
- self.assertEquals(new_rule.left_rule.value, 1)
- self.assertEquals(new_rule.right_rule.value, 4)
+ self.assertEqual(new_rule.left_rule.value, 1)
+ self.assertEqual(new_rule.right_rule.value, 4)
def test_merging_interval_rules_generally(self):
"""Merging `0 < x < 4` and `2 < x < 6` should produce `2 < x < 4`."""
@@ -119,8 +119,8 @@ def test_merging_interval_rules_generally(self):
ContinuousRule('Rule', True, 2),
ContinuousRule('Rule', False, 6))
new_rule = rule1.merge_with(rule2)
- self.assertEquals(new_rule.left_rule.value, 2)
- self.assertEquals(new_rule.right_rule.value, 4)
+ self.assertEqual(new_rule.left_rule.value, 2)
+ self.assertEqual(new_rule.right_rule.value, 4)
# ALL RULES
def test_merge_commutativity_on_continuous_rules(self):
@@ -139,19 +139,19 @@ def test_merge_commutativity_on_interval_rules(self):
ContinuousRule('Rule', False, 6))
new_rule1 = rule1.merge_with(rule2)
new_rule2 = rule2.merge_with(rule1)
- self.assertEquals(new_rule1.left_rule.value,
- new_rule2.left_rule.value)
- self.assertEquals(new_rule1.right_rule.value,
- new_rule2.right_rule.value)
+ self.assertEqual(new_rule1.left_rule.value,
+ new_rule2.left_rule.value)
+ self.assertEqual(new_rule1.right_rule.value,
+ new_rule2.right_rule.value)
- def test_merge_keeps_sign_on_continuous_rules(self):
+ def test_merge_keeps_gt_on_continuous_rules(self):
rule1 = ContinuousRule('Rule1', True, 1)
rule2 = ContinuousRule('Rule1', True, 2)
new_rule = rule1.merge_with(rule2)
- self.assertEquals(new_rule.sign, True)
+ self.assertEqual(new_rule.gt, True)
def test_merge_keeps_attr_name_on_continuous_rules(self):
rule1 = ContinuousRule('Rule1', True, 1)
rule2 = ContinuousRule('Rule1', True, 2)
new_rule = rule1.merge_with(rule2)
- self.assertEquals(new_rule.attr_name, 'Rule1')
+ self.assertEqual(new_rule.attr_name, 'Rule1')
From 7ef0d70540c76b62b62e3da27d5fea06794792c3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Tue, 12 Jul 2016 12:07:13 +0200
Subject: [PATCH 04/12] Pythagorean tree: Add docstrings and reduce pylint
errors
---
Orange/misc/cache.py | 1 +
Orange/widgets/visualize/owpythagorastree.py | 18 +--
.../widgets/visualize/owpythagoreanforest.py | 10 ++
.../widgets/visualize/pythagorastreeviewer.py | 13 ++-
.../visualize/tests/test_owpythagorastree.py | 17 ++-
.../visualize/widgetutils/common/owgrid.py | 13 ++-
.../visualize/widgetutils/common/owlegend.py | 109 +++++++++++-------
.../visualize/widgetutils/common/scene.py | 5 +-
.../visualize/widgetutils/common/view.py | 28 ++++-
.../visualize/widgetutils/tree/rules.py | 47 +++++---
.../widgetutils/tree/skltreeadapter.py | 1 +
.../widgetutils/tree/tests/test_rules.py | 12 ++
.../visualize/widgetutils/tree/treeadapter.py | 1 +
13 files changed, 198 insertions(+), 77 deletions(-)
diff --git a/Orange/misc/cache.py b/Orange/misc/cache.py
index 2dc9f39751d..6d2d3f3baa7 100644
--- a/Orange/misc/cache.py
+++ b/Orange/misc/cache.py
@@ -1,3 +1,4 @@
+"""Common caching methods, using `lru_cahce` sometimes has its downsides."""
from functools import wraps, lru_cache
import weakref
diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py
index 4da683dcdca..0ba5b48306c 100644
--- a/Orange/widgets/visualize/owpythagorastree.py
+++ b/Orange/widgets/visualize/owpythagorastree.py
@@ -397,8 +397,7 @@ def _classification_update_legend_colors(self):
else:
items = (
(self.target_class_combo.itemText(self.target_class_index),
- self.color_palette[self.target_class_index - 1]
- ),
+ self.color_palette[self.target_class_index - 1]),
('other', QtGui.QColor('#ffffff'))
)
self.legend = OWDiscreteLegend(items=items, **self.LEGEND_OPTIONS)
@@ -450,12 +449,12 @@ def _classification_get_tooltip(self, node):
return '
' \
+ text \
+ '{}/{} samples ({:2.3f}%)'.format(
- int(samples), total, ratio * 100) \
+ int(samples), total, ratio * 100) \
+ '
' \
+ ('Split by ' + splitting_attr.name
- if not self.tree_adapter.is_leaf(node.label) else '') \
+ if not self.tree_adapter.is_leaf(node.label) else '') \
+ ('
' if len(rules) and not self.tree_adapter.is_leaf(
- node.label) else '') \
+ node.label) else '') \
+ rules_str \
+ ''
@@ -556,12 +555,12 @@ def _regression_get_tooltip(self, node):
return '
Mean: {:2.3f}'.format(mean) \
+ ' Standard deviation: {:2.3f}'.format(std) \
+ ' {}/{} samples ({:2.3f}%)'.format(
- int(samples), total, ratio * 100) \
+ int(samples), total, ratio * 100) \
+ '
' \
+ ('Split by ' + splitting_attr.name
- if not self.tree_adapter.is_leaf(node.label) else '') \
+ if not self.tree_adapter.is_leaf(node.label) else '') \
+ ('
' if len(rules) and not self.tree_adapter.is_leaf(
- node.label) else '') \
+ node.label) else '') \
+ rules_str \
+ ''
@@ -572,10 +571,13 @@ class TreeGraphicsView(
AnchorableGraphicsView,
PreventDefaultWheelEvent
):
+ """QGraphicsView that contains all functionality we will use to display
+ tree."""
pass
class TreeGraphicsScene(UpdateItemsOnSelectGraphicsScene):
+ """QGraphicsScene that the tree uses."""
pass
diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py
index d007aa61eaf..e53b9c25c6c 100644
--- a/Orange/widgets/visualize/owpythagoreanforest.py
+++ b/Orange/widgets/visualize/owpythagoreanforest.py
@@ -1,3 +1,4 @@
+"""Pythagorean forest widget for visualizing random forests."""
from math import log, sqrt
import numpy as np
@@ -163,14 +164,17 @@ def clear(self):
# CONTROL AREA CALLBACKS
def max_depth_changed(self):
+ """When the max depth slider is changed."""
for tree in self.ptrees:
tree.set_depth_limit(self.depth_limit)
def target_colors_changed(self):
+ """When the target class or coloring method is changed."""
for tree in self.ptrees:
tree.target_class_has_changed()
def size_calc_changed(self):
+ """When the size calculation of the trees is changed."""
if self.model is not None:
self.forest_adapter = self._get_forest_adapter(self.model)
self.grid.clear()
@@ -181,6 +185,7 @@ def size_calc_changed(self):
self.max_depth_changed()
def zoom_changed(self):
+ """When we update the "Zoom" slider."""
for item in self.grid_items:
item.set_max_size(self._calculate_zoom(self.zoom))
@@ -385,10 +390,13 @@ def _color_stddev(self, adapter, tree_node):
class GridItem(SelectableGridItem, ZoomableGridItem):
+ """The grid item we will use in our grid."""
pass
class SklRandomForestAdapter:
+ """Take a `RandomForest` and wrap all the trees into the `TreeAdapter`
+ instances that Pythagorean trees use."""
def __init__(self, model, domain, adjust_weight=lambda x: x):
self._adapters = []
@@ -399,6 +407,7 @@ def __init__(self, model, domain, adjust_weight=lambda x: x):
self._adjust_weight = adjust_weight
def get_trees(self):
+ """Get the tree adapters in the random forest."""
if len(self._adapters) > 0:
return self._adapters
if len(self._trees) < 1:
@@ -412,4 +421,5 @@ def get_trees(self):
@property
def domain(self):
+ """Get the domain."""
return self._domain
diff --git a/Orange/widgets/visualize/pythagorastreeviewer.py b/Orange/widgets/visualize/pythagorastreeviewer.py
index 6fa6ef0f2d8..41d4ac12d3e 100644
--- a/Orange/widgets/visualize/pythagorastreeviewer.py
+++ b/Orange/widgets/visualize/pythagorastreeviewer.py
@@ -21,8 +21,6 @@
from PyQt4 import QtCore, QtGui
from PyQt4.QtCore import Qt
-from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter
-
# z index range, increase if needed
Z_STEP = 5000000
@@ -38,6 +36,9 @@ class PythagorasTreeViewer(QtGui.QGraphicsWidget):
Examples
--------
+ >>> from Orange.widgets.visualize.widgetutils.tree.treeadapter import (
+ >>> TreeAdapter
+ >>> )
Pass tree through constructor.
>>> tree_view = PythagorasTreeViewer(parent=scene, adapter=tree_adapter)
@@ -64,7 +65,7 @@ class PythagorasTreeViewer(QtGui.QGraphicsWidget):
Notes
-----
- .. Note:: The class contains two clear methods: `clear` and `clear_tree`.
+ .. note:: The class contains two clear methods: `clear` and `clear_tree`.
Each has their own use.
`clear_tree` will clear out the tree and remove any graphics items.
`clear` will, on the other hand, clear everything, all settings
@@ -195,10 +196,12 @@ def _get_tooltip(self, *args):
return 'Tooltip'
def target_class_has_changed(self):
+ """When the target class has changed, perform appropriate updates."""
self._update_node_colors()
self._update_node_tooltips()
def tooltip_has_changed(self):
+ """When the tooltip should change, perform appropriate updates."""
self._update_node_tooltips()
def _update_node_colors(self):
@@ -512,7 +515,7 @@ def _propagate_to_parents(self, graphics_item, fnc, other_fnc):
self._propagate_to_parents(parent, fnc, other_fnc)
def selection_changed(self):
- # Handle selection changed
+ """Handle selection changed."""
self.any_selected = len(self.scene().selectedItems()) > 0
if self.any_selected:
if self.isSelected():
@@ -554,7 +557,7 @@ class TreeNode:
parent : TreeNode or object
The parent of the current node. In the case of root, an object
containing the root label of the tree adapter should be passed.
- children : tuple of TreeNode, optional
+ children : tuple of TreeNode, optional, default is empty tuple
All the children that belong to this node.
"""
diff --git a/Orange/widgets/visualize/tests/test_owpythagorastree.py b/Orange/widgets/visualize/tests/test_owpythagorastree.py
index c0f7d05ed40..4015bef0a3e 100644
--- a/Orange/widgets/visualize/tests/test_owpythagorastree.py
+++ b/Orange/widgets/visualize/tests/test_owpythagorastree.py
@@ -1,6 +1,6 @@
+"""Tests for the Pythagorean tree widget and associated classes."""
import math
import unittest
-import Orange.widgets
from Orange.widgets.visualize.pythagorastreeviewer import (
PythagorasTree,
@@ -10,10 +10,17 @@
class TestPythagorasTree(unittest.TestCase):
+ """Pythagorean tree testing, make sure calculating square positions works
+ properly.
+
+ Most of the data is non trivial since the rotations and translations don't
+ generally produce trivial results.
+ """
def setUp(self):
self.builder = PythagorasTree()
def test_get_point_on_square_edge_with_no_angle(self):
+ """Get central point on square edge that is not angled."""
point = self.builder._get_point_on_square_edge(
center=Point(0, 0), length=2, angle=0
)
@@ -22,6 +29,7 @@ def test_get_point_on_square_edge_with_no_angle(self):
self.assertAlmostEqual(point.y, expected_point.y, places=1)
def test_get_point_on_square_edge_with_non_zero_angle(self):
+ """Get central point on square edge that has angle"""
point = self.builder._get_point_on_square_edge(
center=Point(2.7, 2.77), length=1.65, angle=math.radians(20.97)
)
@@ -30,6 +38,8 @@ def test_get_point_on_square_edge_with_non_zero_angle(self):
self.assertAlmostEqual(point.y, expected_point.y, places=1)
def test_compute_center_with_simple_square_angle(self):
+ """Compute the center of the square in the next step given a right
+ angle."""
initial_square = Square(Point(0, 0), length=2, angle=math.pi / 2)
point = self.builder._compute_center(
initial_square, length=1.13, alpha=math.radians(68.57))
@@ -38,6 +48,8 @@ def test_compute_center_with_simple_square_angle(self):
self.assertAlmostEqual(point.y, expected_point.y, places=1)
def test_compute_center_with_complex_square_angle(self):
+ """Compute the center of the square in the next step given a more
+ complex angle."""
initial_square = Square(
Point(1.5, 1.5), length=2.24, angle=math.radians(63.43)
)
@@ -48,6 +60,9 @@ def test_compute_center_with_complex_square_angle(self):
self.assertAlmostEqual(point.y, expected_point.y, places=1)
def test_compute_center_with_complex_square_angle_with_base_angle(self):
+ """Compute the center of the square in the next step when there is a
+ base angle - when the square does not touch the base square on the left
+ edge."""
initial_square = Square(
Point(1.5, 1.5), length=2.24, angle=math.radians(63.43)
)
diff --git a/Orange/widgets/visualize/widgetutils/common/owgrid.py b/Orange/widgets/visualize/widgetutils/common/owgrid.py
index 6e091288177..e6656e9ff40 100644
--- a/Orange/widgets/visualize/widgetutils/common/owgrid.py
+++ b/Orange/widgets/visualize/widgetutils/common/owgrid.py
@@ -1,3 +1,9 @@
+"""Grid widget.
+
+Positions items into a grid. This has been tested with widgets that have their
+`boundingBox` and `sizeHint` methods properly defined.
+
+"""
from itertools import zip_longest
from PyQt4 import QtGui, QtCore
@@ -84,11 +90,14 @@ def paint(self, painter, options, widget=None):
class ZoomableGridItem(GridItem):
"""Makes a grid item "zoomable" through the `set_max_size` method.
+ "Zoomable" here means there is a `Zoom` slider through which the grid items
+ can be made larger and smaller in the grid.
+
Notes
-----
- .. Note:: This grid item will override any bounding box or size hint
+ .. note:: This grid item will override any bounding box or size hint
defined in the class hierarchy with its own.
- .. Note:: This makes the grid item square.
+ .. note:: This makes the grid item square.
Parameters
----------
diff --git a/Orange/widgets/visualize/widgetutils/common/owlegend.py b/Orange/widgets/visualize/widgetutils/common/owlegend.py
index 1047d7a43b0..47f90f27cea 100644
--- a/Orange/widgets/visualize/widgetutils/common/owlegend.py
+++ b/Orange/widgets/visualize/widgetutils/common/owlegend.py
@@ -1,12 +1,24 @@
-"""
-Legend classes to use with `QGraphicsScene` objects.
-"""
+"""Legend classes to use with `QGraphicsScene` objects."""
import numpy as np
from PyQt4 import QtGui, QtCore
from PyQt4.QtCore import Qt
class Anchorable(QtGui.QGraphicsWidget):
+ """Anchorable base class.
+
+ Subclassing the `Anchorable` class will anchor the given
+ `QGraphicsWidget` to a position on the viewport. This does require you to
+ use the `AnchorableGraphicsView` class, it is made to be composable, so
+ that should not be a problem.
+
+ Notes
+ -----
+ .. note:: Subclassing this class will not make your widget movable, you
+ have to do that yourself. If you do make your widget movable, this will
+ handle any further positioning when the widget is moved.
+
+ """
__corners = ['topLeft', 'topRight', 'bottomLeft', 'bottomRight']
TOP_LEFT, TOP_RIGHT, BOTTOM_LEFT, BOTTOM_RIGHT = __corners
@@ -27,30 +39,30 @@ def __init__(self, parent=None, corner='bottomRight', offset=(10, 10)):
elif isinstance(offset, QtCore.QPoint):
self.__offset = offset
- def moveEvent(self, ev):
- super().moveEvent(ev)
+ def moveEvent(self, event):
+ super().moveEvent(event)
# This check is needed because simply resizing the window will cause
# the item to move and trigger a `moveEvent` therefore we need to check
# that the movement was done intentionally by the user using the mouse
if QtGui.QApplication.mouseButtons() == Qt.LeftButton:
self.recalculate_offset()
- def resizeEvent(self, ev):
+ def resizeEvent(self, event):
# When the item is first shown, we need to update its position
- super().resizeEvent(ev)
+ super().resizeEvent(event)
if not self.__has_been_drawn:
self.__offset = self.__calculate_actual_offset(self.__offset)
self.update_pos()
self.__has_been_drawn = True
- def showEvent(self, ev):
+ def showEvent(self, event):
# When the item is first shown, we need to update its position
- super().showEvent(ev)
+ super().showEvent(event)
self.update_pos()
def recalculate_offset(self):
- # This is called whenever the item is being moved and needs to
- # recalculate its offset
+ """This is called whenever the item is being moved and needs to
+ recalculate its offset."""
view = self.__get_view()
# Get the view box and position of legend relative to the view,
# not the scene
@@ -63,9 +75,11 @@ def recalculate_offset(self):
self.__offset = viewbox_corner - pos
def update_pos(self):
- # This is called whenever something happened with the view that caused
- # this item to move from its anchored position, so we have to adjust
- # the position to maintain the effect of being anchored
+ """Update the widget position relative to the viewport.
+
+ This is called whenever something happened with the view that caused
+ this item to move from its anchored position, so we have to adjust the
+ position to maintain the effect of being anchored."""
view = self.__get_view()
if self.__corner_str and view is not None:
box = self.__usable_viewbox()
@@ -125,45 +139,48 @@ def __usable_viewbox(self):
view = self.__get_view()
if view.horizontalScrollBar().isVisible():
- h = view.horizontalScrollBar().size().height()
+ height = view.horizontalScrollBar().size().height()
else:
- h = 0
+ height = 0
if view.verticalScrollBar().isVisible():
- w = view.verticalScrollBar().size().width()
+ width = view.verticalScrollBar().size().width()
else:
- w = 0
+ width = 0
- size = view.size() - QtCore.QSize(w, h)
+ size = view.size() - QtCore.QSize(width, height)
return QtCore.QRect(QtCore.QPoint(0, 0), size)
class AnchorableGraphicsView(QtGui.QGraphicsView):
+ """Subclass when wanting to use Anchorable items in your view."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
+ # Handle scroll bar hiding or showing
self.horizontalScrollBar().valueChanged.connect(
self.update_anchored_items)
self.verticalScrollBar().valueChanged.connect(
self.update_anchored_items)
- def resizeEvent(self, ev):
- super().resizeEvent(ev)
+ def resizeEvent(self, event):
+ super().resizeEvent(event)
self.update_anchored_items()
- def mousePressEvent(self, ev):
- super().mousePressEvent(ev)
+ def mousePressEvent(self, event):
+ super().mousePressEvent(event)
self.update_anchored_items()
- def wheelEvent(self, ev):
- super().wheelEvent(ev)
+ def wheelEvent(self, event):
+ super().wheelEvent(event)
self.update_anchored_items()
- def mouseMoveEvent(self, ev):
- super().mouseMoveEvent(ev)
+ def mouseMoveEvent(self, event):
+ super().mouseMoveEvent(event)
self.update_anchored_items()
def update_anchored_items(self):
+ """Update all the items that subclass the `Anchorable` class."""
for item in self.__anchorable_items():
item.update_pos()
@@ -172,6 +189,9 @@ def __anchorable_items(self):
class ColorIndicator(QtGui.QGraphicsWidget):
+ """Base class for an item indicator.
+
+ Usually the little square or circle in the legend in front of the text."""
pass
@@ -320,7 +340,7 @@ class LegendGradient(QtGui.QGraphicsWidget):
Notes
-----
- .. Note:: While the gradient does support any number of colors, any more
+ .. note:: While the gradient does support any number of colors, any more
than 3 is not very readable. This should not be a problem, since Orange
only implements 2 or 3 colors.
@@ -348,11 +368,11 @@ def __init__(self, palette, parent, orientation):
# Get the appropriate rectangle dimensions based on orientation
if orientation == Qt.Vertical:
- w, h = self.GRADIENT_WIDTH, self.GRADIENT_HEIGHT
+ width, height = self.GRADIENT_WIDTH, self.GRADIENT_HEIGHT
elif orientation == Qt.Horizontal:
- w, h = self.GRADIENT_HEIGHT, self.GRADIENT_WIDTH
+ width, height = self.GRADIENT_HEIGHT, self.GRADIENT_WIDTH
- self.__rect_item = QtGui.QGraphicsRectItem(0, 0, w, h, self)
+ self.__rect_item = QtGui.QGraphicsRectItem(0, 0, width, height, self)
self.__rect_item.setPen(QtGui.QPen(QtGui.QColor(0, 0, 0, 0)))
self.__rect_item.setBrush(QtGui.QBrush(self.__gradient))
@@ -418,13 +438,10 @@ def _format_values(values):
class Legend(Anchorable):
"""Base legend class.
- This class provides common attributes for any legend derivates:
+ This class provides common attributes for any legend subclasses:
- Behaviour on `QGraphicsScene`
- Appearance of legend
- If you have access to the `domain` property, the `LegendBuilder` class
- can be used to automatically build a legend for you.
-
Parameters
----------
parent : QtGui.QGraphicsItem, optional
@@ -447,7 +464,7 @@ class Legend(Anchorable):
Notes
-----
- .. Warning:: If the domain parameter is supplied, the items parameter will
+ .. warning:: If the domain parameter is supplied, the items parameter will
be ignored.
"""
@@ -457,6 +474,7 @@ def __init__(self, parent=None, orientation=Qt.Vertical, domain=None,
font=None, color_indicator_cls=LegendItemSquare, **kwargs):
super().__init__(parent, **kwargs)
+ self._layout = None
self.orientation = orientation
self.bg_color = QtGui.QBrush(bg_color)
self.color_indicator_cls = color_indicator_cls
@@ -510,7 +528,7 @@ def set_domain(self, domain):
If the domain does not contain the correct type of class variable.
"""
- raise NotImplemented()
+ raise NotImplementedError()
def set_items(self, values):
"""Handle receiving an array of items.
@@ -523,7 +541,7 @@ def set_items(self, values):
-------
"""
- raise NotImplemented()
+ raise NotImplementedError()
@staticmethod
def _convert_to_color(obj):
@@ -587,7 +605,7 @@ class OWContinuousLegend(Legend):
OWDiscreteLegend
"""
-
+
def __init__(self, *args, **kwargs):
# Variables used in the `set_` methods must be set before calling super
self.__range = kwargs.get('range', ())
@@ -637,6 +655,19 @@ def set_items(self, values):
class OWBinnedContinuousLegend(Legend):
+ """Binned continuous legend in case you don't like gradients.
+
+ This is not implemented yet, but in case it ever needs to be, the stub is
+ available.
+
+ See Also
+ --------
+ Legend
+ OWDiscreteLegend
+ OWContinuousLegend
+
+ """
+
def set_domain(self, domain):
pass
diff --git a/Orange/widgets/visualize/widgetutils/common/scene.py b/Orange/widgets/visualize/widgetutils/common/scene.py
index ca10a1993fe..7c68f146e83 100644
--- a/Orange/widgets/visualize/widgetutils/common/scene.py
+++ b/Orange/widgets/visualize/widgetutils/common/scene.py
@@ -1,3 +1,4 @@
+"""Common QGraphicsScene components that can be composed when needed."""
from PyQt4 import QtGui
@@ -9,8 +10,8 @@ class UpdateItemsOnSelectGraphicsScene(QtGui.QGraphicsScene):
Notes
-----
- ..Note:: I suspect this is completely unncessary, but have not been able to
- find a reasonable way to keep the selection logic inside the actual
+ .. note:: I suspect this is completely unncessary, but have not been able
+ to find a reasonable way to keep the selection logic inside the actual
`QGraphicsItem` objects
"""
diff --git a/Orange/widgets/visualize/widgetutils/common/view.py b/Orange/widgets/visualize/widgetutils/common/view.py
index 1803c9512f2..e318e510bdb 100644
--- a/Orange/widgets/visualize/widgetutils/common/view.py
+++ b/Orange/widgets/visualize/widgetutils/common/view.py
@@ -1,3 +1,5 @@
+"""Common useful `QGraphicsView` classes that can be composed to achieve
+desired functionality."""
from itertools import repeat
import numpy as np
@@ -25,7 +27,10 @@ class ZoomableGraphicsView(QtGui.QGraphicsView):
Notes
-----
- - This view will consume wheel scrolling and right mouse click events.
+ .. note:: This view will NOT consume the wheel event, so it would be wise
+ to use this component in conjuction with the `PreventDefaultWheelEvent`
+ in most cases.
+ .. note:: This view does however consume the right mouse click event.
"""
@@ -76,8 +81,8 @@ def __set_padding(self, padding):
else:
padding = 0, 0, 0, 0
- l, t, r, b = padding
- self.__padding = -l, -t, r, b
+ left, top, right, bottom = padding
+ self.__padding = -left, -top, right, bottom
def __handle_zoom(self, direction):
"""Handle zoom event, direction is positive if zooming in, otherwise
@@ -106,6 +111,16 @@ def __reset_zoomout_limit(self):
self.__zoomout_limit_reached = False
def set_central_widget(self, widget):
+ """Set the central widget in the view.
+
+ This means that the initial zoom will fit the central widget, and may
+ cut out any other widgets.
+
+ Parameters
+ ----------
+ widget : QGraphicsWidget
+
+ """
self.__central_widget = widget
def central_widget_rect(self):
@@ -176,5 +191,12 @@ def mouseReleaseEvent(self, event):
class PreventDefaultWheelEvent(QtGui.QGraphicsView):
+ """Prevent the default wheel event.
+
+ The default wheel event pans the view around, if using the
+ `ZoomableGraphicsView`, this will prevent that behaviour.
+
+ """
+
def wheelEvent(self, event):
event.accept()
diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py
index 587226a23d1..5be8a6262d3 100644
--- a/Orange/widgets/visualize/widgetutils/tree/rules.py
+++ b/Orange/widgets/visualize/widgetutils/tree/rules.py
@@ -1,3 +1,16 @@
+"""Rules for classification and regression trees.
+
+Tree visualisations usually need to show the rules of nodes, these classes make
+merging these rules simple (otherwise you have repeating rules e.g. `age < 3`
+and `age < 2` which can be merged into `age < 2`.
+
+Subclasses of the `Rule` class should provide a nice interface to merge rules
+together through the `merge_with` method. Of course, this should not be forced
+where it doesn't make sense e.g. merging a discrete rule (e.g.
+:math:`x \in \{red, blue, green\}`) and a continuous rule (e.g.
+:math:`x \leq 5`).
+
+"""
import warnings
@@ -25,7 +38,7 @@ class DiscreteRule(Rule):
Parameters
----------
attr_name : str
- eq : bool
+ equals : bool
Should indicate whether or not the rule equals the value or not.
value : object
@@ -39,14 +52,14 @@ class DiscreteRule(Rule):
Notes
-----
- .. Note:: Merging discrete rules is currently not implemented, the new rule
+ .. note:: Merging discrete rules is currently not implemented, the new rule
is simply returned and a warning is issued.
"""
- def __init__(self, attr_name, eq, value):
+ def __init__(self, attr_name, equals, value):
self.attr_name = attr_name
- self.eq = eq
+ self.equals = equals
self.value = value
def merge_with(self, rule):
@@ -57,7 +70,7 @@ def merge_with(self, rule):
def __str__(self):
return '{} {} {}'.format(
- self.attr_name, '=' if self.eq else '≠', self.value)
+ self.attr_name, '=' if self.equals else '≠', self.value)
class ContinuousRule(Rule):
@@ -66,7 +79,7 @@ class ContinuousRule(Rule):
Parameters
----------
attr_name : str
- gt : bool
+ greater : bool
Should indicate whether the variable must be greater than the value.
value : int
inclusive : bool, optional
@@ -83,14 +96,14 @@ class ContinuousRule(Rule):
Notes
-----
- .. Note:: Continuous rules can currently only be merged with other
+ .. note:: Continuous rules can currently only be merged with other
continuous rules.
"""
- def __init__(self, attr_name, gt, value, inclusive=False):
+ def __init__(self, attr_name, greater, value, inclusive=False):
self.attr_name = attr_name
- self.gt = gt
+ self.greater = greater
self.value = value
self.inclusive = inclusive
@@ -99,23 +112,23 @@ def merge_with(self, rule):
raise NotImplementedError('Continuous rules can currently only be '
'merged with other continuous rules')
# Handle when both have same sign
- if self.gt == rule.gt:
+ if self.greater == rule.greater:
# When both are GT
- if self.gt is True:
+ if self.greater is True:
larger = max(self.value, rule.value)
- return ContinuousRule(self.attr_name, self.gt, larger)
+ return ContinuousRule(self.attr_name, self.greater, larger)
# When both are LT
else:
smaller = self.value if self.value < rule.value else rule.value
- return ContinuousRule(self.attr_name, self.gt, smaller)
+ return ContinuousRule(self.attr_name, self.greater, smaller)
# When they have different signs we need to return an interval rule
else:
- lt_rule, gt_rule = (rule, self) if self.gt else (self, rule)
+ lt_rule, gt_rule = (rule, self) if self.greater else (self, rule)
return IntervalRule(self.attr_name, gt_rule, lt_rule)
def __str__(self):
return '%s %s %.3f' % (
- self.attr_name, '>' if self.gt else '≤', self.value)
+ self.attr_name, '>' if self.greater else '≤', self.value)
class IntervalRule(Rule):
@@ -138,7 +151,7 @@ class IntervalRule(Rule):
Notes
-----
- .. Note:: Currently, only cases which appear in classification and
+ .. note:: Currently, only cases which appear in classification and
regression trees are implemented. An interval can not be made up of two
parts (e.g. (-∞, -1) ∪ (1, ∞)).
@@ -160,7 +173,7 @@ def __init__(self, attr_name, left_rule, right_rule):
def merge_with(self, rule):
if isinstance(rule, ContinuousRule):
- if rule.gt:
+ if rule.greater:
return IntervalRule(
self.attr_name, self.left_rule.merge_with(rule),
self.right_rule)
diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
index 759542cea18..980811decf3 100644
--- a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
+++ b/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
@@ -1,3 +1,4 @@
+"""Tree adapter class for sklearn trees."""
from collections import OrderedDict
import numpy as np
diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
index 0a8633b26f6..abf64c727de 100644
--- a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
+++ b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
@@ -1,3 +1,4 @@
+"""Test rules for classification and regression trees."""
import unittest
from Orange.widgets.visualize.widgetutils.tree.rules import (
@@ -7,6 +8,13 @@
class TestRules(unittest.TestCase):
+ """Rules for classification and regression trees.
+
+ See Also
+ --------
+ Orange.widgets.visualize.widgetutils.tree.rules
+
+ """
# CONTINUOUS RULES
def test_merging_two_gt_continuous_rules(self):
"""Merging `x > 1` and `x > 2` should produce `x > 2`."""
@@ -124,6 +132,7 @@ def test_merging_interval_rules_generally(self):
# ALL RULES
def test_merge_commutativity_on_continuous_rules(self):
+ """Continuous rule merging should be commutative."""
rule1 = ContinuousRule('Rule1', True, 1)
rule2 = ContinuousRule('Rule1', True, 2)
new_rule1 = rule1.merge_with(rule2)
@@ -131,6 +140,7 @@ def test_merge_commutativity_on_continuous_rules(self):
self.assertEqual(new_rule1.value, new_rule2.value)
def test_merge_commutativity_on_interval_rules(self):
+ """Interval rule merging should be commutative."""
rule1 = IntervalRule('Rule',
ContinuousRule('Rule', True, 0),
ContinuousRule('Rule', False, 4))
@@ -145,12 +155,14 @@ def test_merge_commutativity_on_interval_rules(self):
new_rule2.right_rule.value)
def test_merge_keeps_gt_on_continuous_rules(self):
+ """Merging ccontinuous rules should keep GT property."""
rule1 = ContinuousRule('Rule1', True, 1)
rule2 = ContinuousRule('Rule1', True, 2)
new_rule = rule1.merge_with(rule2)
self.assertEqual(new_rule.gt, True)
def test_merge_keeps_attr_name_on_continuous_rules(self):
+ """Merging continuous rules should keep the name of the rule."""
rule1 = ContinuousRule('Rule1', True, 1)
rule2 = ContinuousRule('Rule1', True, 2)
new_rule = rule1.merge_with(rule2)
diff --git a/Orange/widgets/visualize/widgetutils/tree/treeadapter.py b/Orange/widgets/visualize/widgetutils/tree/treeadapter.py
index 7bf07368c5f..8d78b392a27 100644
--- a/Orange/widgets/visualize/widgetutils/tree/treeadapter.py
+++ b/Orange/widgets/visualize/widgetutils/tree/treeadapter.py
@@ -1,3 +1,4 @@
+"""Base tree adapter class with common methods needed for visualisations."""
from abc import ABCMeta, abstractmethod
From 0732d822487add9f675c4c073a9edd46c4cdf185 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Tue, 12 Jul 2016 12:24:39 +0200
Subject: [PATCH 05/12] Pythagorean tree: Add more pylint fixes
---
Orange/misc/cache.py | 12 +++++-----
Orange/widgets/visualize/owpythagorastree.py | 22 ++++++++++++-------
.../widgets/visualize/owpythagoreanforest.py | 1 +
.../visualize/tests/test_owpythagorastree.py | 1 +
.../visualize/widgetutils/common/owlegend.py | 14 +++++++-----
.../visualize/widgetutils/tree/rules.py | 2 +-
6 files changed, 32 insertions(+), 20 deletions(-)
diff --git a/Orange/misc/cache.py b/Orange/misc/cache.py
index 6d2d3f3baa7..adc10f5c5b2 100644
--- a/Orange/misc/cache.py
+++ b/Orange/misc/cache.py
@@ -1,24 +1,26 @@
-"""Common caching methods, using `lru_cahce` sometimes has its downsides."""
+"""Common caching methods, using `lru_cache` sometimes has its downsides."""
from functools import wraps, lru_cache
import weakref
-def single_cache(f):
+def single_cache(func):
+ """Cache with size 1."""
last_args = ()
last_kwargs = set()
last_result = None
- def cached(*args, **kwargs):
+ @wraps(func)
+ def _cached(*args, **kwargs):
nonlocal last_args, last_kwargs, last_result
if len(last_args) != len(args) or \
not all(x is y for x, y in zip(args, last_args)) or \
last_kwargs != set(kwargs) or \
any(last_kwargs[k] != kwargs[k] for k in last_kwargs):
- last_result = f(*args, **kwargs)
+ last_result = func(*args, **kwargs)
last_args, last_kwargs = args, kwargs
return last_result
- return cached
+ return _cached
def memoize_method(*lru_args, **lru_kwargs):
diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py
index 0ba5b48306c..934a0808904 100644
--- a/Orange/widgets/visualize/owpythagorastree.py
+++ b/Orange/widgets/visualize/owpythagorastree.py
@@ -108,7 +108,7 @@ def __init__(self):
# CONTROL AREA
# Tree info area
box_info = gui.widgetBox(self.controlArea, 'Tree Info')
- self.info = gui.widgetLabel(box_info, label='')
+ self.info = gui.widgetLabel(box_info)
# Display settings area
box_display = gui.widgetBox(self.controlArea, 'Display Settings')
@@ -249,10 +249,12 @@ def update_depth(self):
self.ptree.set_depth_limit(self.depth_limit)
def update_colors(self):
+ """When the target class / node coloring needs to be updated."""
self.ptree.target_class_has_changed()
self._tree_specific('_update_legend_colors')()
def update_size_calc(self):
+ """When the tree size calculation is updated."""
self._update_log_scale_slider()
self.invalidate_tree()
@@ -265,6 +267,7 @@ def invalidate_tree(self):
self._update_main_area()
def update_tooltip_enabled(self):
+ """When the tooltip visibility is changed and need to be updated."""
if self.tooltips_enabled:
self.ptree.set_tooltip_func(
self._tree_specific('_get_tooltip')
@@ -274,6 +277,7 @@ def update_tooltip_enabled(self):
self.ptree.tooltip_has_changed()
def update_show_legend(self):
+ """When the legend visibility needs to be updated."""
self._update_legend_visibility()
# MODEL CHANGED CONTROL ELEMENTS UPDATE METHODS
@@ -350,6 +354,7 @@ def commit(self):
self.send('Selected Data', data)
def send_report(self):
+ """Send report."""
self.report_plot()
def _tree_specific(self, method):
@@ -449,12 +454,13 @@ def _classification_get_tooltip(self, node):
return '
' \
+ text \
+ '{}/{} samples ({:2.3f}%)'.format(
- int(samples), total, ratio * 100) \
+ int(samples), total, ratio * 100) \
+ '
' \
+ ('Split by ' + splitting_attr.name
if not self.tree_adapter.is_leaf(node.label) else '') \
- + ('
' if len(rules) and not self.tree_adapter.is_leaf(
- node.label) else '') \
+ + ('
'
+ if len(rules) and not self.tree_adapter.is_leaf(node.label)
+ else '') \
+ rules_str \
+ ''
@@ -566,10 +572,10 @@ def _regression_get_tooltip(self, node):
class TreeGraphicsView(
- PannableGraphicsView,
- ZoomableGraphicsView,
- AnchorableGraphicsView,
- PreventDefaultWheelEvent
+ PannableGraphicsView,
+ ZoomableGraphicsView,
+ AnchorableGraphicsView,
+ PreventDefaultWheelEvent
):
"""QGraphicsView that contains all functionality we will use to display
tree."""
diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py
index e53b9c25c6c..08bec0c1ea5 100644
--- a/Orange/widgets/visualize/owpythagoreanforest.py
+++ b/Orange/widgets/visualize/owpythagoreanforest.py
@@ -290,6 +290,7 @@ def commit(self):
self.send('Tree', obj)
def send_report(self):
+ """Send report."""
self.report_plot()
def _update_scene_rect(self):
diff --git a/Orange/widgets/visualize/tests/test_owpythagorastree.py b/Orange/widgets/visualize/tests/test_owpythagorastree.py
index 4015bef0a3e..555ee70d0ed 100644
--- a/Orange/widgets/visualize/tests/test_owpythagorastree.py
+++ b/Orange/widgets/visualize/tests/test_owpythagorastree.py
@@ -9,6 +9,7 @@
)
+# pylint: disable=protected-access
class TestPythagorasTree(unittest.TestCase):
"""Pythagorean tree testing, make sure calculating square positions works
properly.
diff --git a/Orange/widgets/visualize/widgetutils/common/owlegend.py b/Orange/widgets/visualize/widgetutils/common/owlegend.py
index 47f90f27cea..38c4a005822 100644
--- a/Orange/widgets/visualize/widgetutils/common/owlegend.py
+++ b/Orange/widgets/visualize/widgetutils/common/owlegend.py
@@ -92,15 +92,17 @@ def __calculate_actual_offset(self, offset):
actual offset from the top left corner of the item so positioning can
be done correctly."""
off_x, off_y = offset.x(), offset.y()
- w, h = self.boundingRect().width(), self.boundingRect().height()
+ width = self.boundingRect().width()
+ height = self.boundingRect().height()
+
if self.__corner_str == self.TOP_LEFT:
return QtCore.QPoint(-off_x, -off_y)
elif self.__corner_str == self.TOP_RIGHT:
- return QtCore.QPoint(off_x + w, -off_y)
+ return QtCore.QPoint(off_x + width, -off_y)
elif self.__corner_str == self.BOTTOM_RIGHT:
- return QtCore.QPoint(off_x + w, off_y + h)
+ return QtCore.QPoint(off_x + width, off_y + height)
elif self.__corner_str == self.BOTTOM_LEFT:
- return QtCore.QPoint(-off_x, off_y + h)
+ return QtCore.QPoint(-off_x, off_y + height)
def __get_closest_corner(self):
view = self.__get_view()
@@ -110,9 +112,9 @@ def __get_closest_corner(self):
legend_box = QtCore.QRect(pos, self.size().toSize())
view_box = QtCore.QRect(QtCore.QPoint(0, 0), view.size())
- def distance(t1, t2):
+ def distance(pt1, pt2):
# 2d euclidean distance
- return np.sqrt((t1.x() - t2.x()) ** 2 + (t1.y() - t2.y()) ** 2)
+ return np.sqrt((pt1.x() - pt2.x()) ** 2 + (pt1.y() - pt2.y()) ** 2)
distances = [
(distance(getattr(view_box, corner)(),
diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py
index 5be8a6262d3..a384bccaed4 100644
--- a/Orange/widgets/visualize/widgetutils/tree/rules.py
+++ b/Orange/widgets/visualize/widgetutils/tree/rules.py
@@ -1,4 +1,4 @@
-"""Rules for classification and regression trees.
+r"""Rules for classification and regression trees.
Tree visualisations usually need to show the rules of nodes, these classes make
merging these rules simple (otherwise you have repeating rules e.g. `age < 3`
From d98c37d7946f2947fa769f5f747ec45fc314ea55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Tue, 12 Jul 2016 12:29:01 +0200
Subject: [PATCH 06/12] Pythagorean tree: Fix tests that failed due to pylint
induced renaming of variables
---
Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
index abf64c727de..a098b60af41 100644
--- a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
+++ b/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
@@ -159,7 +159,7 @@ def test_merge_keeps_gt_on_continuous_rules(self):
rule1 = ContinuousRule('Rule1', True, 1)
rule2 = ContinuousRule('Rule1', True, 2)
new_rule = rule1.merge_with(rule2)
- self.assertEqual(new_rule.gt, True)
+ self.assertEqual(new_rule.greater, True)
def test_merge_keeps_attr_name_on_continuous_rules(self):
"""Merging continuous rules should keep the name of the rule."""
From 266c8805972263a2c341a0b2c6bfb6ec6fd0fd64 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Wed, 13 Jul 2016 12:33:53 +0200
Subject: [PATCH 07/12] Pythagorean tree: Add repr to tree rules and fix
docstring example ordering
---
.../visualize/widgetutils/tree/rules.py | 38 +++++++++++--------
1 file changed, 23 insertions(+), 15 deletions(-)
diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/widgetutils/tree/rules.py
index a384bccaed4..9a1f84fad75 100644
--- a/Orange/widgets/visualize/widgetutils/tree/rules.py
+++ b/Orange/widgets/visualize/widgetutils/tree/rules.py
@@ -44,11 +44,11 @@ class DiscreteRule(Rule):
Examples
--------
- Age = 30
- >>> rule = DiscreteRule('age', True, 30)
+ >>> DiscreteRule('age', True, 30)
+ age = 30
- Name ≠ John
- >>> rule = DiscreteRule('name', False, 'John')
+ >>> DiscreteRule('name', False, 'John')
+ name ≠ John
Notes
-----
@@ -72,6 +72,8 @@ def __str__(self):
return '{} {} {}'.format(
self.attr_name, '=' if self.equals else '≠', self.value)
+ __repr__ = __str__
+
class ContinuousRule(Rule):
"""Continuous rule class for handling numeric rules.
@@ -88,11 +90,11 @@ class ContinuousRule(Rule):
Examples
--------
- x ≤ 30
- >>> rule = ContinuousRule('age', False, 30, inclusive=True)
+ >>> ContinuousRule('age', False, 30, inclusive=True)
+ age ≤ 30.000
- x > 30
- >>> rule = ContinuousRule('age', True, 30)
+ >>> ContinuousRule('age', True, 30)
+ age > 30.000
Notes
-----
@@ -130,6 +132,8 @@ def __str__(self):
return '%s %s %.3f' % (
self.attr_name, '>' if self.greater else '≤', self.value)
+ __repr__ = __str__
+
class IntervalRule(Rule):
"""Interval rule class for ranges of continuous values.
@@ -144,10 +148,10 @@ class IntervalRule(Rule):
Examples
--------
- 1 ≤ x < 3
- >>> rule = IntervalRule('Rule',
- >>> ContinuousRule('Rule', True, 1, inclusive=True),
- >>> ContinuousRule('Rule', False, 3))
+ >>> IntervalRule('Rule',
+ >>> ContinuousRule('Rule', True, 1, inclusive=True),
+ >>> ContinuousRule('Rule', False, 3))
+ Rule ∈ [1.000, 3.000)
Notes
-----
@@ -189,8 +193,12 @@ def merge_with(self, rule):
self.right_rule.merge_with(rule.right_rule))
def __str__(self):
- return '{} ∈ {}{:.3}, {:.3}{}'.format(
+ return '%s ∈ %s%.3f, %.3f%s' % (
self.attr_name,
- '[' if self.left_rule.inclusive else '(', self.left_rule.value,
- self.right_rule.value, ']' if self.right_rule.inclusive else ')'
+ '[' if self.left_rule.inclusive else '(',
+ self.left_rule.value,
+ self.right_rule.value,
+ ']' if self.right_rule.inclusive else ')'
)
+
+ __repr__ = __str__
From 04ec528215da48723512a4f188512fb56aa7b543 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Fri, 15 Jul 2016 11:42:19 +0200
Subject: [PATCH 08/12] Misc: Add tests for cache
---
Orange/tests/test_misc.py | 30 ++++++++++++++++++++++++++++++
1 file changed, 30 insertions(+)
create mode 100644 Orange/tests/test_misc.py
diff --git a/Orange/tests/test_misc.py b/Orange/tests/test_misc.py
new file mode 100644
index 00000000000..bf5b59f20da
--- /dev/null
+++ b/Orange/tests/test_misc.py
@@ -0,0 +1,30 @@
+import unittest
+
+from Orange.misc.cache import memoize_method, single_cache
+
+
+class Calculator:
+ @memoize_method()
+ def my_sum(self, *nums):
+ return sum(nums)
+
+
+@single_cache
+def my_sum(*nums):
+ return sum(nums)
+
+
+class TestCache(unittest.TestCase):
+
+ def test_single_cache(self):
+ self.assertEqual(my_sum(1, 2, 3, 4, 5), 15)
+ self.assertEqual(my_sum(1, 2, 3, 4, 5), 15)
+ # Make sure different args produce different results
+ self.assertEqual(my_sum(1, 2, 3, 4), 10)
+
+ def test_memoize_method(self):
+ calc = Calculator()
+ self.assertEqual(calc.my_sum(1, 2, 3, 4, 5), 15)
+ self.assertEqual(calc.my_sum(1, 2, 3, 4, 5), 15)
+ # Make sure different args produce different results
+ self.assertEqual(calc.my_sum(1, 2, 3, 4), 10)
From e03d4e8bf2a13737b69c5d6a1a9b699a2b165801 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Fri, 15 Jul 2016 11:53:08 +0200
Subject: [PATCH 09/12] Pythagorean tree: Move util classes to util folder and
remove widgetutils
---
Orange/widgets/visualize/owpythagorastree.py | 10 +++++-----
Orange/widgets/visualize/owpythagoreanforest.py | 4 ++--
Orange/widgets/visualize/pythagorastreeviewer.py | 2 +-
.../widgets/visualize/{utils.py => utils/__init__.py} | 0
.../visualize/{widgetutils/common => utils}/owgrid.py | 0
.../{widgetutils/common => utils}/owlegend.py | 0
.../visualize/{widgetutils/common => utils}/scene.py | 0
.../visualize/{widgetutils => utils/tree}/__init__.py | 0
.../visualize/{widgetutils => utils}/tree/rules.py | 0
.../{widgetutils => utils}/tree/skltreeadapter.py | 5 ++---
.../common => utils/tree/tests}/__init__.py | 0
.../{widgetutils => utils}/tree/tests/test_rules.py | 2 +-
.../{widgetutils => utils}/tree/treeadapter.py | 0
.../visualize/{widgetutils/common => utils}/view.py | 0
Orange/widgets/visualize/widgetutils/tree/__init__.py | 0
.../visualize/widgetutils/tree/tests/__init__.py | 0
16 files changed, 11 insertions(+), 12 deletions(-)
rename Orange/widgets/visualize/{utils.py => utils/__init__.py} (100%)
rename Orange/widgets/visualize/{widgetutils/common => utils}/owgrid.py (100%)
rename Orange/widgets/visualize/{widgetutils/common => utils}/owlegend.py (100%)
rename Orange/widgets/visualize/{widgetutils/common => utils}/scene.py (100%)
rename Orange/widgets/visualize/{widgetutils => utils/tree}/__init__.py (100%)
rename Orange/widgets/visualize/{widgetutils => utils}/tree/rules.py (100%)
rename Orange/widgets/visualize/{widgetutils => utils}/tree/skltreeadapter.py (98%)
rename Orange/widgets/visualize/{widgetutils/common => utils/tree/tests}/__init__.py (100%)
rename Orange/widgets/visualize/{widgetutils => utils}/tree/tests/test_rules.py (99%)
rename Orange/widgets/visualize/{widgetutils => utils}/tree/treeadapter.py (100%)
rename Orange/widgets/visualize/{widgetutils/common => utils}/view.py (100%)
delete mode 100644 Orange/widgets/visualize/widgetutils/tree/__init__.py
delete mode 100644 Orange/widgets/visualize/widgetutils/tree/tests/__init__.py
diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py
index 934a0808904..23619d62af0 100644
--- a/Orange/widgets/visualize/owpythagorastree.py
+++ b/Orange/widgets/visualize/owpythagorastree.py
@@ -19,15 +19,13 @@
from math import sqrt, log
import numpy as np
-from Orange.widgets.visualize.widgetutils.common.scene import \
+from Orange.widgets.visualize.utils.scene import \
UpdateItemsOnSelectGraphicsScene
-from Orange.widgets.visualize.widgetutils.common.view import (
+from Orange.widgets.visualize.utils.view import (
PannableGraphicsView,
ZoomableGraphicsView,
PreventDefaultWheelEvent
)
-from Orange.widgets.visualize.widgetutils.tree.skltreeadapter import \
- SklTreeAdapter
from PyQt4 import QtGui
from Orange.base import Tree
@@ -40,12 +38,14 @@
PythagorasTreeViewer,
SquareGraphicsItem
)
-from Orange.widgets.visualize.widgetutils.common.owlegend import (
+from Orange.widgets.visualize.utils.owlegend import (
AnchorableGraphicsView,
Anchorable,
OWDiscreteLegend,
OWContinuousLegend
)
+from Orange.widgets.visualize.utils.tree.skltreeadapter import \
+ SklTreeAdapter
from Orange.widgets.widget import OWWidget
diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py
index 08bec0c1ea5..6cd51f5b26d 100644
--- a/Orange/widgets/visualize/owpythagoreanforest.py
+++ b/Orange/widgets/visualize/owpythagoreanforest.py
@@ -14,12 +14,12 @@
from Orange.widgets import gui, settings
from Orange.widgets.utils.colorpalette import ContinuousPaletteGenerator
from Orange.widgets.visualize.pythagorastreeviewer import PythagorasTreeViewer
-from Orange.widgets.visualize.widgetutils.common.owgrid import (
+from Orange.widgets.visualize.utils.owgrid import (
OWGrid,
SelectableGridItem,
ZoomableGridItem
)
-from Orange.widgets.visualize.widgetutils.tree.skltreeadapter import \
+from Orange.widgets.visualize.utils.tree.skltreeadapter import \
SklTreeAdapter
from Orange.widgets.widget import OWWidget
diff --git a/Orange/widgets/visualize/pythagorastreeviewer.py b/Orange/widgets/visualize/pythagorastreeviewer.py
index 41d4ac12d3e..25de8a588f7 100644
--- a/Orange/widgets/visualize/pythagorastreeviewer.py
+++ b/Orange/widgets/visualize/pythagorastreeviewer.py
@@ -36,7 +36,7 @@ class PythagorasTreeViewer(QtGui.QGraphicsWidget):
Examples
--------
- >>> from Orange.widgets.visualize.widgetutils.tree.treeadapter import (
+ >>> from Orange.widgets.visualize.utils.tree.treeadapter import (
>>> TreeAdapter
>>> )
Pass tree through constructor.
diff --git a/Orange/widgets/visualize/utils.py b/Orange/widgets/visualize/utils/__init__.py
similarity index 100%
rename from Orange/widgets/visualize/utils.py
rename to Orange/widgets/visualize/utils/__init__.py
diff --git a/Orange/widgets/visualize/widgetutils/common/owgrid.py b/Orange/widgets/visualize/utils/owgrid.py
similarity index 100%
rename from Orange/widgets/visualize/widgetutils/common/owgrid.py
rename to Orange/widgets/visualize/utils/owgrid.py
diff --git a/Orange/widgets/visualize/widgetutils/common/owlegend.py b/Orange/widgets/visualize/utils/owlegend.py
similarity index 100%
rename from Orange/widgets/visualize/widgetutils/common/owlegend.py
rename to Orange/widgets/visualize/utils/owlegend.py
diff --git a/Orange/widgets/visualize/widgetutils/common/scene.py b/Orange/widgets/visualize/utils/scene.py
similarity index 100%
rename from Orange/widgets/visualize/widgetutils/common/scene.py
rename to Orange/widgets/visualize/utils/scene.py
diff --git a/Orange/widgets/visualize/widgetutils/__init__.py b/Orange/widgets/visualize/utils/tree/__init__.py
similarity index 100%
rename from Orange/widgets/visualize/widgetutils/__init__.py
rename to Orange/widgets/visualize/utils/tree/__init__.py
diff --git a/Orange/widgets/visualize/widgetutils/tree/rules.py b/Orange/widgets/visualize/utils/tree/rules.py
similarity index 100%
rename from Orange/widgets/visualize/widgetutils/tree/rules.py
rename to Orange/widgets/visualize/utils/tree/rules.py
diff --git a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py b/Orange/widgets/visualize/utils/tree/skltreeadapter.py
similarity index 98%
rename from Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
rename to Orange/widgets/visualize/utils/tree/skltreeadapter.py
index 980811decf3..3aed7f88680 100644
--- a/Orange/widgets/visualize/widgetutils/tree/skltreeadapter.py
+++ b/Orange/widgets/visualize/utils/tree/skltreeadapter.py
@@ -2,12 +2,11 @@
from collections import OrderedDict
import numpy as np
+from Orange.widgets.visualize.utils.tree.treeadapter import TreeAdapter
from Orange.misc.cache import memoize_method
-from Orange.widgets.visualize.widgetutils.tree.treeadapter import TreeAdapter
-
from Orange.preprocess.transformation import Indicator
-from Orange.widgets.visualize.widgetutils.tree.rules import (
+from Orange.widgets.visualize.utils.tree.rules import (
DiscreteRule,
ContinuousRule
)
diff --git a/Orange/widgets/visualize/widgetutils/common/__init__.py b/Orange/widgets/visualize/utils/tree/tests/__init__.py
similarity index 100%
rename from Orange/widgets/visualize/widgetutils/common/__init__.py
rename to Orange/widgets/visualize/utils/tree/tests/__init__.py
diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py b/Orange/widgets/visualize/utils/tree/tests/test_rules.py
similarity index 99%
rename from Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
rename to Orange/widgets/visualize/utils/tree/tests/test_rules.py
index a098b60af41..562d926eae4 100644
--- a/Orange/widgets/visualize/widgetutils/tree/tests/test_rules.py
+++ b/Orange/widgets/visualize/utils/tree/tests/test_rules.py
@@ -1,7 +1,7 @@
"""Test rules for classification and regression trees."""
import unittest
-from Orange.widgets.visualize.widgetutils.tree.rules import (
+from Orange.widgets.visualize.utils.tree.rules import (
ContinuousRule,
IntervalRule,
)
diff --git a/Orange/widgets/visualize/widgetutils/tree/treeadapter.py b/Orange/widgets/visualize/utils/tree/treeadapter.py
similarity index 100%
rename from Orange/widgets/visualize/widgetutils/tree/treeadapter.py
rename to Orange/widgets/visualize/utils/tree/treeadapter.py
diff --git a/Orange/widgets/visualize/widgetutils/common/view.py b/Orange/widgets/visualize/utils/view.py
similarity index 100%
rename from Orange/widgets/visualize/widgetutils/common/view.py
rename to Orange/widgets/visualize/utils/view.py
diff --git a/Orange/widgets/visualize/widgetutils/tree/__init__.py b/Orange/widgets/visualize/widgetutils/tree/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/Orange/widgets/visualize/widgetutils/tree/tests/__init__.py b/Orange/widgets/visualize/widgetutils/tree/tests/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
From d72ebceb951e910d288e54a41dae562fba0b4c27 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Fri, 15 Jul 2016 15:08:27 +0200
Subject: [PATCH 10/12] Pythagorean tree: Added repr for rules and added prints
to docstring
---
Orange/widgets/visualize/utils/tree/rules.py | 27 ++++++++++++--------
1 file changed, 17 insertions(+), 10 deletions(-)
diff --git a/Orange/widgets/visualize/utils/tree/rules.py b/Orange/widgets/visualize/utils/tree/rules.py
index 9a1f84fad75..75aaac946ab 100644
--- a/Orange/widgets/visualize/utils/tree/rules.py
+++ b/Orange/widgets/visualize/utils/tree/rules.py
@@ -44,10 +44,10 @@ class DiscreteRule(Rule):
Examples
--------
- >>> DiscreteRule('age', True, 30)
+ >>> print(DiscreteRule('age', True, 30))
age = 30
- >>> DiscreteRule('name', False, 'John')
+ >>> print(DiscreteRule('name', False, 'John'))
name ≠ John
Notes
@@ -72,7 +72,9 @@ def __str__(self):
return '{} {} {}'.format(
self.attr_name, '=' if self.equals else '≠', self.value)
- __repr__ = __str__
+ def __repr__(self):
+ return "DiscreteRule(attr_name='%s', equals=%s, value=%s)" % (
+ self.attr_name, self.equals, self.value)
class ContinuousRule(Rule):
@@ -90,10 +92,10 @@ class ContinuousRule(Rule):
Examples
--------
- >>> ContinuousRule('age', False, 30, inclusive=True)
+ >>> print(ContinuousRule('age', False, 30, inclusive=True))
age ≤ 30.000
- >>> ContinuousRule('age', True, 30)
+ >>> print(ContinuousRule('age', True, 30))
age > 30.000
Notes
@@ -132,7 +134,10 @@ def __str__(self):
return '%s %s %.3f' % (
self.attr_name, '>' if self.greater else '≤', self.value)
- __repr__ = __str__
+ def __repr__(self):
+ return "ContinuousRule(attr_name='%s', greater=%s, value=%s, " \
+ "inclusive=%s)" % (self.attr_name, self.greater, self.value,
+ self.inclusive)
class IntervalRule(Rule):
@@ -148,9 +153,9 @@ class IntervalRule(Rule):
Examples
--------
- >>> IntervalRule('Rule',
- >>> ContinuousRule('Rule', True, 1, inclusive=True),
- >>> ContinuousRule('Rule', False, 3))
+ >>> print(IntervalRule('Rule',
+ >>> ContinuousRule('Rule', True, 1, inclusive=True),
+ >>> ContinuousRule('Rule', False, 3)))
Rule ∈ [1.000, 3.000)
Notes
@@ -201,4 +206,6 @@ def __str__(self):
']' if self.right_rule.inclusive else ')'
)
- __repr__ = __str__
+ def __repr__(self):
+ return "IntervalRule(attr_name='%s', left_rule=%s, right_rule=%s)" % (
+ self.attr_name, repr(self.left_rule), repr(self.right_rule))
From 3102f7f25e84360e4fc2b977ac7caee5b4b71ce6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Mon, 18 Jul 2016 09:44:52 +0200
Subject: [PATCH 11/12] Pythagorean tree: Add cosmetic fix
---
Orange/widgets/visualize/utils/tree/rules.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Orange/widgets/visualize/utils/tree/rules.py b/Orange/widgets/visualize/utils/tree/rules.py
index 75aaac946ab..a491da2893e 100644
--- a/Orange/widgets/visualize/utils/tree/rules.py
+++ b/Orange/widgets/visualize/utils/tree/rules.py
@@ -123,7 +123,7 @@ def merge_with(self, rule):
return ContinuousRule(self.attr_name, self.greater, larger)
# When both are LT
else:
- smaller = self.value if self.value < rule.value else rule.value
+ smaller = min(self.value, rule.value)
return ContinuousRule(self.attr_name, self.greater, smaller)
# When they have different signs we need to return an interval rule
else:
From 596a4526865c8c1a0195cda955ad64ff8ad37c65 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pavlin=20Poli=C4=8Dar?=
Date: Tue, 19 Jul 2016 14:48:50 +0200
Subject: [PATCH 12/12] Pythagorean tree: Change text in target class with
regression, make forest name camel case
---
Orange/widgets/visualize/owpythagorastree.py | 8 ++++++++
Orange/widgets/visualize/owpythagoreanforest.py | 2 +-
2 files changed, 9 insertions(+), 1 deletion(-)
diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py
index 23619d62af0..ed7e1e43e6b 100644
--- a/Orange/widgets/visualize/owpythagorastree.py
+++ b/Orange/widgets/visualize/owpythagorastree.py
@@ -387,6 +387,10 @@ def _tree_specific(self, method):
# CLASSIFICATION TREE SPECIFIC METHODS
def _classification_update_target_class_combo(self):
self._clear_target_class_combo()
+ list(filter(
+ lambda x: isinstance(x, QtGui.QLabel),
+ self.target_class_combo.parent().children()
+ ))[0].setText('Target class')
self.target_class_combo.addItem('None')
values = [c.title() for c in
self.tree_adapter.domain.class_vars[0].values]
@@ -467,6 +471,10 @@ def _classification_get_tooltip(self, node):
# REGRESSION TREE SPECIFIC METHODS
def _regression_update_target_class_combo(self):
self._clear_target_class_combo()
+ list(filter(
+ lambda x: isinstance(x, QtGui.QLabel),
+ self.target_class_combo.parent().children()
+ ))[0].setText('Node color')
self.target_class_combo.addItems(
list(zip(*self.REGRESSION_COLOR_CALC))[0])
self.target_class_combo.setCurrentIndex(self.target_class_index)
diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py
index 6cd51f5b26d..f37ac8210c5 100644
--- a/Orange/widgets/visualize/owpythagoreanforest.py
+++ b/Orange/widgets/visualize/owpythagoreanforest.py
@@ -25,7 +25,7 @@
class OWPythagoreanForest(OWWidget):
- name = 'Pythagorean forest'
+ name = 'Pythagorean Forest'
description = 'Pythagorean forest for visualising random forests.'
icon = 'icons/PythagoreanForest.svg'