Skip to content

Commit

Permalink
heatmap: Add legend for color annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
ales-erjavec committed Mar 5, 2020
1 parent 0cb4735 commit 6336200
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 33 deletions.
238 changes: 206 additions & 32 deletions Orange/widgets/visualize/utils/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from itertools import chain, zip_longest

from typing import (
Optional, List, NamedTuple, Sequence, Tuple, Dict, Union,
Optional, List, NamedTuple, Sequence, Tuple, Dict, Union, Iterable
)

import numpy as np
Expand All @@ -15,7 +15,8 @@
from AnyQt.QtWidgets import (
QGraphicsWidget, QSizePolicy, QGraphicsGridLayout, QGraphicsRectItem,
QApplication, QGraphicsSceneMouseEvent, QGraphicsLinearLayout,
QGraphicsItem, QGraphicsSimpleTextItem
QGraphicsItem, QGraphicsSimpleTextItem, QGraphicsLayout,
QGraphicsLayoutItem
)

import pyqtgraph as pg
Expand Down Expand Up @@ -309,6 +310,8 @@ def __init__(self, parent=None, **kwargs):
self.col_dendrograms = [] # type: List[Optional[DendrogramWidget]]
self.row_dendrograms = [] # type: List[Optional[DendrogramWidget]]
self.right_side_colors = [] # type: List[Optional[GraphicsPixmapWidget]]
self.heatmap_colormap_legend = None
self.bottom_legend_container = None
self.__layout = GridLayout()
self.__layout.setSpacing(self.__spacing)
self.setLayout(self.__layout)
Expand Down Expand Up @@ -338,6 +341,8 @@ def clear(self):
self.col_dendrograms = []
self.row_dendrograms = []
self.right_side_colors = []
self.heatmap_colormap_legend = None
self.bottom_legend_container = None
self.parts = None
self.updateGeometry()

Expand Down Expand Up @@ -491,6 +496,7 @@ def setHeatmaps(self, parts: 'Parts') -> None:
rowauxsidecolor = GraphicsPixmapWidget(
parent=self, visible=False,
scaleContents=True, aspectMode=Qt.IgnoreAspectRatio,
sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Ignored)
)
rowauxsidecolor.setVisible(False)
grid.addItem(rowauxsidecolor, Row0 + i, RightLabelColumn - 1)
Expand Down Expand Up @@ -532,6 +538,17 @@ def setHeatmaps(self, parts: 'Parts') -> None:
col_annotation_widgets.append(labelslist)
col_annotation_widgets_bottom.append(labelslist)

row_color_annotation_header = QGraphicsSimpleTextItem("", self)
row_color_annotation_header.rotate(-90)

grid.addItem(SimpleLayoutItem(
row_color_annotation_header, anchor=(0, 1), resizeContents=True,
aspectMode=Qt.KeepAspectRatio,
sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Preferred),
),
self.TopLabelsRow, RightLabelColumn - 1,
)

legend = GradientLegendWidget(
parts.span[0], parts.span[1],
colormap,
Expand All @@ -542,6 +559,14 @@ def setHeatmaps(self, parts: 'Parts') -> None:
)
legend.setMaximumWidth(300)
grid.addItem(legend, self.LegendRow, self.LegendCol)
legend_container = QGraphicsWidget(
visible=False,
sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Fixed)
)
legend_container.setLayout(QGraphicsLinearLayout())
legend_container.layout().setContentsMargins(0, 0, 0, 0)
grid.addItem(legend_container, BottomLabelsRow + 1, Col0 + 1, 1, M * 2,
alignment=Qt.AlignRight)

self.heatmap_widget_grid = heatmap_widgets
self.row_annotation_widgets = row_annotation_widgets
Expand All @@ -551,6 +576,8 @@ def setHeatmaps(self, parts: 'Parts') -> None:
self.col_dendrograms = column_dendrograms
self.row_dendrograms = row_dendrograms
self.right_side_colors = right_side_colors
self.heatmap_colormap_legend = legend
self.bottom_legend_container = legend_container
self.parts = parts
self.__selection_manager.set_heatmap_widgets(heatmap_widgets)

Expand All @@ -561,9 +588,11 @@ def legendVisible(self) -> bool:
def setLegendVisible(self, visible: bool) -> None:
"""Set colormap legend visible state."""
self.__legendVisible = visible
item = self.__layout.itemAt(self.LegendRow, self.LegendCol)
if isinstance(item, GradientLegendWidget):
item.setVisible(visible)
legends = [
self.heatmap_colormap_legend,
self.bottom_legend_container
]
apply_all(filter(None, legends), lambda item: item.setVisible(visible))

legendVisible_ = Property(bool, legendVisible, setLegendVisible)

Expand Down Expand Up @@ -634,57 +663,79 @@ def setRowLabelsVisible(self, visible: bool):
for widget in self.row_annotation_widgets:
widget.setVisible(visible)

def setRowSideColorAnnotations(self, colors: Optional[np.ndarray], name=""):
def setRowSideColorAnnotations(
self, data: np.ndarray, colormap: ColorMap = None, name=""
):
"""
Set an optional row side color annotations.
Parameters
----------
colors: An (N, 3) uint8 array, optional
An array specifying the rgb color components for every row.
If None then the side color annotations are cleared.
data: Optional[np.ndarray]
A sequence such that it is accepted by `colormap.apply`. If None
then the color annotations are cleared.
colormap: ColorMap
name: str
Name/title for the annotation stripe.
Name/title for the annotation column.
"""
items = self.right_side_colors
col = self.Col0 + 2 * len(self.parts.columns)
nameitem = self.__layout.itemAt(self.TopLabelsRow, col)
if colors is None:
apply_all(filter(None, items), lambda a: a.setVisible(False))
apply_all(filter(None, items), lambda a: a.setPreferredWidth(-1))
if nameitem is not None:
nameitem.setPreferredWidth(0)
nameitem.item.setVisible(False)
layout = self.__layout
nameitem = layout.itemAt(self.TopLabelsRow, col)
width = QFontMetrics(self.font()).lineSpacing()
legend_container = self.bottom_legend_container
layout_clear(legend_container.layout())

def set_hidden(item: GraphicsPixmapWidget):
item.setVisible(False)
item.setPreferredWidth(-1)

def set_visible(item: GraphicsPixmapWidget):
item.setVisible(True)
item.setPreferredWidth(width)

if data is None:
apply_all(filter(None, items), set_hidden)
layout.setColumnMaximumWidth(col, 0)
nameitem.item.setVisible(False)
nameitem.updateGeometry()
legend_container.setVisible(False)
return
fm = QFontMetrics(self.font())
width = fm.lineSpacing()
else:
apply_all(filter(None, items), set_visible)
layout.setColumnMaximumWidth(col, FLT_MAX)
legend_container.setVisible(True)

parts = self.parts.rows
nrows = sum(p.size for p in parts)
assert len(colors) == nrows
assert len(data) == nrows
for p, item in zip(parts, items):
if item is not None:
subset = colors[p.normalized_indices]
subset = data[p.normalized_indices]
subset = colormap.apply(subset)
img = qimage_from_array(subset.reshape((-1, 1, subset.shape[-1])))
item.setPixmap(img)
item.setVisible(True)
item.setPreferredWidth(width)

if nameitem is None:
item = QGraphicsSimpleTextItem(name, self)
item.rotate(-90)
nameitem = SimpleLayoutItem(
item, anchor=(0, 1), resizeContents=True,
aspectMode=Qt.KeepAspectRatio,
)
nameitem.setSizePolicy(QSizePolicy.Maximum, QSizePolicy.Preferred)
# nameitem spans all header rows
self.__layout.addItem(nameitem, 0, col, self.TopLabelsRow + 1, 1)

nameitem.item.setText(name)
nameitem.item.setVisible(True)
nameitem.setPreferredWidth(width)
nameitem.updateGeometry()

container = legend_container.layout()
if isinstance(colormap, CategoricalColorMap):
legend = CategoricalColorLegend(
colormap, title=name,
orientation=Qt.Horizontal,
sizePolicy=QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Maximum),
visible=self.__legendVisible,
)
container.addItem(legend)
elif isinstance(colormap, GradientColorMap):
legend = GradientLegendWidget(*colormap.span, colormap)
container.addItem(legend)

def headerGeometry(self) -> QRectF:
"""Return the 'header' geometry.
Expand Down Expand Up @@ -1111,6 +1162,129 @@ def changeEvent(self, event: QEvent) -> None:
super().changeEvent(event)


class CategoricalColorLegend(QGraphicsWidget):
def __init__(
self, colormap: CategoricalColorMap, title="",
orientation=Qt.Vertical, parent=None, **kwargs,
) -> None:
self.__colormap = colormap
self.__title = title
self.__names = colormap.names
self.__layout = QGraphicsGridLayout()
self.__layout.setSpacing(2)
self.__orientation = orientation
kwargs.setdefault(
"sizePolicy", QSizePolicy(QSizePolicy.Maximum, QSizePolicy.Maximum)
)
super().__init__(None, **kwargs)
self.setLayout(self.__layout)
self._setup()

if parent is not None:
self.setParent(parent)

def setOrientation(self, orientation):
if self.__orientation != orientation:
self._clear()
self._setup()

def orientation(self):
return self.__orientation

def _clear(self):
items = reversed(list(layout_items(self.__layout)))
for item in items:
self.__layout.removeItem(item)
for item in items:
if isinstance(item, SimpleLayoutItem):
remove_item(item.item)

def _setup(self):
# setup the layout
colors = self.__colormap.colortable
names = self.__colormap.names
title = self.__title
layout = self.__layout
assert layout.count() == 0
font = self.font()
fm = QFontMetrics(font)
size = fm.width("X")
start = 0
if title:
start = 1
item = QGraphicsSimpleTextItem(title)
item.setFont(font)
headeritem = QGraphicsSimpleTextItem(title)
headeritem.setFont(font)
else:
headeritem = None

items = []
for i, (color, label) in enumerate(zip(colors, names), start=start):
colitem = QGraphicsRectItem(0, 0, size, size)
colitem.setBrush(QColor(*color))
textitem = QGraphicsSimpleTextItem()
textitem.setFont(font)
textitem.setText(label)
items.append((colitem, textitem))

def centered(item):
return SimpleLayoutItem(item, anchor=(0.5, 0.5), anchorItem=(0.5, 0.5))

def addrowspan(item):
layout.addItem(centered(item), layout.rowCount(), 0, 1, 2)

def addrow(symbol, label):
row = layout.rowCount()
layout.addItem(centered(symbol), row, 0)
layout.addItem(
SimpleLayoutItem(label), row, 1,
alignment=Qt.AlignVCenter | Qt.AlignLeft
)
if self.__orientation == Qt.Vertical:
if headeritem:
addrowspan(headeritem)
apply_all(items, lambda el: addrow(*el))
else:
for sym, label in items:
layout.addItem(centered(sym), 1, layout.columnCount())
layout.addItem(SimpleLayoutItem(label), 1, layout.columnCount())
if headeritem:
layout.addItem(
centered(headeritem), 0, 0, 1, layout.columnCount())

def changeEvent(self, event: QEvent) -> None:
if event.type() == QEvent.FontChange:
self._updateFont(self.font())
super().changeEvent(event)

def _updateFont(self, font):
w = QFontMetrics(font).width("X")
for item in filter(
lambda item: isinstance(item, SimpleLayoutItem),
layout_items(self.__layout)
):
if isinstance(item.item, QGraphicsSimpleTextItem):
item.item.setFont(font)
elif isinstance(item.item, QGraphicsRectItem):
item.item.setRect(QRectF(0, 0, w, w))
item.updateGeometry()


def layout_items(layout: QGraphicsLayout) -> Iterable[QGraphicsLayoutItem]:
for item in map(layout.itemAt, range(layout.count())):
if item is not None:
yield item


def layout_clear(layout: QGraphicsLayout) -> None:
for i in reversed(range(layout.count())):
item = layout.itemAt(i)
layout.removeAt(i)
if item is not None and item.graphicsItem() is not None:
remove_item(item.graphicsItem())


class SelectionManager(QObject):
"""
Selection manager for heatmap rows
Expand Down
6 changes: 5 additions & 1 deletion Orange/widgets/visualize/utils/tests/test_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ def test_widget_annotations(self):
):
w.setColumnLabelsPosition(pos)

w.setRowSideColorAnnotations(np.array([[255] * 3, [0] * 3]), "c")
w.setRowSideColorAnnotations(
np.array([0, 1]),
CategoricalColorMap(np.array([[255] * 3, [0] * 3]),
names=["a", "b"])
)
w.setRowSideColorAnnotations(None)

def test_selection(self):
Expand Down

0 comments on commit 6336200

Please sign in to comment.