Skip to content

Commit

Permalink
Merge pull request #4234 from ales-erjavec/owheatmap-split-by
Browse files Browse the repository at this point in the history
[ENH] owheatmap: Add Split By combo box
  • Loading branch information
ajdapretnar authored Nov 29, 2019
2 parents b08fc7e + f56c634 commit c58eaa4
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 143 deletions.
227 changes: 84 additions & 143 deletions Orange/widgets/visualize/owheatmap.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import itertools

from collections import defaultdict, namedtuple
from collections import namedtuple
from types import SimpleNamespace as namespace

import numpy as np
Expand All @@ -11,7 +11,7 @@
QSizePolicy, QGraphicsScene, QGraphicsView, QGraphicsRectItem,
QGraphicsWidget, QGraphicsSimpleTextItem, QGraphicsPixmapItem,
QGraphicsGridLayout, QGraphicsLinearLayout, QGraphicsLayoutItem,
QFormLayout, QApplication
QFormLayout, QApplication, QComboBox
)
from AnyQt.QtGui import (
QFontMetrics, QPen, QPixmap, QColor, QLinearGradient, QPainter,
Expand All @@ -24,12 +24,14 @@
)
import pyqtgraph as pg

from orangewidget.utils.combobox import ComboBox

from Orange.data import Domain, Table
from Orange.data.sql.table import SqlTable
import Orange.distance

from Orange.clustering import hierarchical, kmeans
from Orange.widgets.utils.itemmodels import DomainModel
from Orange.widgets.utils.itemmodels import DomainModel, VariableListModel
from Orange.widgets.utils.stickygraphicsview import StickyGraphicsView
from Orange.widgets.utils import colorbrewer
from Orange.widgets.utils.annotated_data import (create_annotated_table,
Expand All @@ -41,108 +43,11 @@
from Orange.widgets.widget import Msg, Input, Output


def split_domain(domain, split_label):
"""Split the domain based on values of `split_label` value.
"""
groups = defaultdict(list)
for attr in domain.attributes:
groups[attr.attributes.get(split_label)].append(attr)

attr_values = [attr.attributes.get(split_label)
for attr in domain.attributes]

domains = []
for value, attrs in groups.items():
group_domain = Domain(attrs, domain.class_vars, domain.metas)

domains.append((value, group_domain))

if domains:
assert all(len(dom) == len(domains[0][1]) for _, dom in domains)

return sorted(domains, key=lambda t: attr_values.index(t[0]))


def vstack_by_subdomain(data, sub_domains):
domain = sub_domains[0]
newtable = Table(domain)

for sub_dom in sub_domains:
sub_data = data.transform(sub_dom)
# TODO: improve O(N ** 2)
newtable.extend(sub_data)

return newtable


def select_by_class(data, class_):
indices = select_by_class_indices(data, class_)
return data[indices]


def select_by_class_indices(data, class_):
col, _ = data.get_column_view(data.domain.class_var)
return col == class_


def group_by_unordered(iterable, key):
groups = defaultdict(list)
for item in iterable:
groups[key(item)].append(item)
return groups.items()


def barycenter(a, axis=0):
assert 0 <= axis < 2
a = np.asarray(a)
N = a.shape[axis]
tileshape = [1 if i != axis else a.shape[i] for i in range(a.ndim)]
xshape = list(a.shape)
xshape[axis] = 1
X = np.tile(np.reshape(np.arange(N), tileshape), xshape)
amin = np.nanmin(a, axis=axis, keepdims=True)
weights = a - amin
weights[np.isnan(weights)] = 0
wsum = np.sum(weights, axis=axis)
mask = wsum <= np.finfo(float).eps
if axis == 1:
weights[mask, :] = 1
else:
weights[:, mask] = 1

return np.average(X, weights=weights, axis=axis)


def kmeans_compress(X, k=50):
km = kmeans.KMeans(n_clusters=k, n_init=5, random_state=42)
return km.get_model(X)


def candidate_split_labels(data):
"""
Return candidate labels on which we can split the data.
"""
groups = defaultdict(list)
for attr in data.domain.attributes:
for item in attr.attributes.items():
groups[item].append(attr)

by_keys = defaultdict(list)
for (key, _), attrs in groups.items():
by_keys[key].append(attrs)

# Find the keys for which all values have the same number
# of attributes.
candidates = []
for key, groups in by_keys.items():
count = len(groups[0])
if all(len(attrs) == count for attrs in groups) and \
len(groups) > 1 and count > 1:
candidates.append(key)

return candidates


def leaf_indices(tree):
return [leaf.value.index for leaf in hierarchical.leaves(tree)]

Expand Down Expand Up @@ -384,9 +289,23 @@ def cluster_ord(self):
[name for name, _, in _color_palettes].index("Blue-Yellow")


def cbselect(cb: QComboBox, value, role: Qt.ItemDataRole = Qt.EditRole) -> None:
"""
Find and select the `value` in the `cb` QComboBox.
Parameters
----------
cb: QComboBox
value: Any
role: Qt.ItemDataRole
The data role in the combo box model to match value against
"""
cb.setCurrentIndex(cb.findData(value, role))


class OWHeatMap(widget.OWWidget):
name = "Heat Map"
description = "Plot a heat map for a pair of attributes."
description = "Plot a data matrix heatmap."
icon = "icons/Heatmap.svg"
priority = 260
keywords = []
Expand Down Expand Up @@ -418,6 +337,8 @@ class Outputs:
legend = settings.Setting(True)
# Annotations
annotation_var = settings.ContextSetting(None)
split_by_var = settings.ContextSetting(None)

# Stored color palette settings
color_settings = settings.Setting(None)
user_palettes = settings.Setting([])
Expand Down Expand Up @@ -542,6 +463,33 @@ def __init__(self):
cluster_box, self, "row_clustering", "Rows",
callback=self.update_clustering_examples)

box = gui.vBox(self.controlArea, "Split By")

self.row_split_model = DomainModel(
placeholder="(None)",
valid_types=(Orange.data.DiscreteVariable,),
parent=self,
)
self.row_split_cb = cb = ComboBox(
enabled=not self.merge_kmeans,
sizeAdjustPolicy=ComboBox.AdjustToMinimumContentsLengthWithIcon,
minimumContentsLength=14,
toolTip="Split the heatmap vertically by a categorical column"
)
self.row_split_cb.setModel(self.row_split_model)
self.connect_control(
"split_by_var", lambda value, cb=cb: cbselect(cb, value)
)
self.connect_control(
"merge_kmeans", self.row_split_cb.setDisabled
)
self.split_by_var = None

self.row_split_cb.activated.connect(
self.__on_split_rows_activated
)
box.layout().addWidget(self.row_split_cb)

box = gui.vBox(self.controlArea, 'Annotation && Legends')

gui.checkBox(box, self, 'legend', 'Show legend',
Expand Down Expand Up @@ -626,6 +574,8 @@ def clear(self):
self.merge_indices = None
self.annotation_model.set_domain(None)
self.annotation_var = None
self.row_split_model.set_domain(None)
self.split_by_var = None
self.clear_scene()
self.selected_rows = []
self.__columns_cache.clear()
Expand Down Expand Up @@ -705,7 +655,14 @@ def set_dataset(self, data=None):
if data is not None:
self.annotation_model.set_domain(self.input_data.domain)
self.annotation_var = None
self.row_split_model.set_domain(data.domain)
if data.domain.has_discrete_class:
self.split_by_var = data.domain.class_var
else:
self.split_by_var = None
self.openContext(self.input_data)
if self.split_by_var not in self.row_split_model:
self.split_by_var = None

self.update_heatmaps()
if data is not None and self.__pending_selection is not None:
Expand All @@ -715,6 +672,14 @@ def set_dataset(self, data=None):

self.unconditional_commit()

def __on_split_rows_activated(self):
self.set_split_variable(self.row_split_cb.currentData(Qt.EditRole))

def set_split_variable(self, var):
if var != self.split_by_var:
self.split_by_var = var
self.update_heatmaps()

def update_heatmaps(self):
if self.data is not None:
self.clear_scene()
Expand All @@ -727,7 +692,9 @@ def update_heatmaps(self):
elif self.merge_kmeans and len(self.data) < 3:
self.Error.not_enough_instances_k_means()
else:
self.construct_heatmaps(self.data)
self.heatmapparts = self.construct_heatmaps(
self.data, self.split_by_var
)
self.construct_heatmaps_scene(
self.heatmapparts, self.effective_data)
self.selected_rows = []
Expand All @@ -741,7 +708,7 @@ def update_merge(self):
self.update_heatmaps()
self.commit()

def _make_parts(self, data, group_var=None, group_key=None):
def _make_parts(self, data, group_var=None):
"""
Make initial `Parts` for data, split by group_var, group_key
"""
Expand All @@ -758,20 +725,11 @@ def _make_parts(self, data, group_var=None, group_key=None):
sortindices=None,
cluster=None, cluster_ordered=None)]

if group_key is not None:
col_groups = split_domain(data.domain, group_key)
assert len(col_groups) > 0
col_indices = [np.array([data.domain.index(var) for var in group])
for _, group in col_groups]
col_groups = [ColumnPart(title=name, domain=d, indices=ind,
cluster=None, cluster_ordered=None)
for (name, d), ind in zip(col_groups, col_indices)]
else:
col_groups = [
ColumnPart(
title=None, indices=slice(0, len(data.domain.attributes)),
domain=data.domain, cluster=None, cluster_ordered=None)
]
col_groups = [
ColumnPart(
title=None, indices=slice(0, len(data.domain.attributes)),
domain=data.domain, cluster=None, cluster_ordered=None)
]

minv, maxv = np.nanmin(data.X), np.nanmax(data.X)
return Parts(row_groups, col_groups, span=(minv, maxv))
Expand Down Expand Up @@ -806,11 +764,10 @@ def cluster_rows(self, data, parts):

row_groups.append(row._replace(cluster=cluster, cluster_ordered=cluster_ord))

return parts._replace(columns=parts.columns, rows=row_groups)
return parts._replace(rows=row_groups)

def cluster_columns(self, data, parts):
if len(parts.columns) > 1:
data = vstack_by_subdomain(data, [col.domain for col in parts.columns])
assert len(parts.columns) == 1, "columns split is no longer supported"
assert all(var.is_continuous for var in data.domain.attributes)

col0 = parts.columns[0]
Expand Down Expand Up @@ -839,21 +796,9 @@ def cluster_columns(self, data, parts):

col_groups = [col._replace(cluster=cluster, cluster_ordered=cluster_ord)
for col in parts.columns]
return parts._replace(columns=col_groups, rows=parts.rows)
return parts._replace(columns=col_groups)

def construct_heatmaps(self, data, split_label=None):
if split_label is not None:
groups = split_domain(data.domain, split_label)
assert len(groups) > 0
else:
groups = [("", data.domain)]

if data.domain.has_discrete_class:
group_var = data.domain.class_var
else:
group_var = None

group_label = split_label
def construct_heatmaps(self, data, group_var=None) -> 'Parts':
if self.merge_kmeans:
if self.kmeans_model is None:
effective_data = self.input_data.transform(
Expand Down Expand Up @@ -890,18 +835,14 @@ def construct_heatmaps(self, data, split_label=None):

self.__update_clustering_enable_state(effective_data)

parts = self._make_parts(effective_data, group_var, group_label)
parts = self._make_parts(effective_data, group_var)
# Restore/update the row/columns items descriptions from cache if
# available
rows_cache_key = (group_var,
self.merge_kmeans_k if self.merge_kmeans else None)
if rows_cache_key in self.__rows_cache:
parts = parts._replace(rows=self.__rows_cache[rows_cache_key].rows)

if group_label in self.__columns_cache:
parts = parts._replace(
columns=self.__columns_cache[group_label].columns)

if self.row_clustering:
assert len(effective_data) <= OWHeatMap._MaxOrderedClustering
parts = self.cluster_rows(effective_data, parts)
Expand All @@ -913,9 +854,7 @@ def construct_heatmaps(self, data, split_label=None):

# Cache the updated parts
self.__rows_cache[rows_cache_key] = parts
self.__columns_cache[group_label] = parts

self.heatmapparts = parts
return parts

def construct_heatmaps_scene(self, parts, data):
def select_row(item):
Expand Down Expand Up @@ -1521,8 +1460,10 @@ def send_report(self):
self.report_items((
("Columns:", "Clustering" if self.col_clustering else "No sorting"),
("Rows:", "Clustering" if self.row_clustering else "No sorting"),
("Split:",
self.split_by_var is not None and self.split_by_var.name),
("Row annotation",
self.annotation_var is not None and self.annotation_var.name)
self.annotation_var is not None and self.annotation_var.name),
))
self.report_plot()

Expand Down
11 changes: 11 additions & 0 deletions Orange/widgets/visualize/tests/test_owheatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ def test_saved_selection(self):
self.send_signal(w.Inputs.data, iris, widget=w)
self.assertEqual(len(self.get_output(w.Outputs.selected_data)), 21)

def test_set_split_var(self):
data = Table("brown-selected")
w = self.widget
self.send_signal(self.widget.Inputs.data, data, widget=w)
self.assertIs(w.split_by_var, data.domain.class_var)
self.assertEqual(len(w.heatmapparts.rows),
len(data.domain.class_var.values))
w.set_split_variable(None)
self.assertIs(w.split_by_var, None)
self.assertEqual(len(w.heatmapparts.rows), 1)


if __name__ == "__main__":
unittest.main()

0 comments on commit c58eaa4

Please sign in to comment.