Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Scatter Plot Graph: max discrete values colors and shape #2804

Merged
merged 3 commits into from
Dec 11, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 42 additions & 18 deletions Orange/widgets/visualize/owscatterplotgraph.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from collections import Counter
import sys
import itertools
from xml.sax.saxutils import escape
from math import log10, floor, ceil

import numpy as np
import scipy.sparse as sp
from scipy.stats import linregress

from AnyQt.QtCore import Qt, QObject, QEvent, QRectF, QPointF, QSize
Expand All @@ -25,7 +25,7 @@
from Orange.widgets import gui
from Orange.widgets.utils import classdensity, get_variable_values_sorted
from Orange.widgets.utils.colorpalette import (ColorPaletteGenerator,
ContinuousPaletteGenerator)
ContinuousPaletteGenerator, DefaultRGBColors)
from Orange.widgets.utils.plot import \
OWPalette, OWPlotGUI, SELECT, PANNING, ZOOMING
from Orange.widgets.utils.scaling import ScaleScatterPlotData
Expand All @@ -35,6 +35,7 @@
# TODO Move utility classes to another module, so they can be used elsewhere

SELECTION_WIDTH = 5
MAX = 11 # maximum number of colors or shapes (including Other)

class PaletteItemSample(ItemSample):
"""A color strip to insert into legends for discretized continuous values"""
Expand Down Expand Up @@ -485,7 +486,7 @@ class OWScatterPlotGraph(gui.OWComponent, ScaleScatterPlotData):
show_reg_line = Setting(False)
resolution = 256

CurveSymbols = np.array("o x t + d s ?".split())
CurveSymbols = np.array("o x t + d s t2 t3 p h star ?".split())
MinShapeSize = 6
DarkerValue = 120
UnknownColor = (168, 50, 168)
Expand Down Expand Up @@ -818,7 +819,8 @@ def get_color(self):
colors = self.attr_color.colors
if self.attr_color.is_discrete:
self.discrete_palette = ColorPaletteGenerator(
number_of_colors=len(colors), rgb_colors=colors)
number_of_colors=min(len(colors), MAX), rgb_colors=colors if len(colors) <= MAX
else DefaultRGBColors)
else:
self.continuous_palette = ContinuousPaletteGenerator(*colors)
return self.attr_color
Expand Down Expand Up @@ -853,6 +855,35 @@ def make_pen(color, width):
brush = [QBrush(QColor(255, 255, 255, 0))] * self.n_points
return pen, brush

def _reduce_values(self, attr):
    """
    Rank the values of a discrete variable by how often they occur.

    If a discrete variable has more than the maximum allowed number of
    values (MAX), the less used values are later joined as "Other"; this
    method provides the frequency ranking needed for that.

    Returns a pair `(values_to_replace, c_data)`:

    - `values_to_replace` is None when `attr` is continuous or has at
      most MAX values; otherwise it is the list of distinct non-missing
      entries of the column, most frequent first (ties keep the order of
      first occurrence in the data).
    - `c_data` is the column of `attr` restricted to `self.valid_data`.
    """
    c_data = self.data.get_column_view(attr)[0][self.valid_data]
    if attr.is_continuous or len(attr.values) <= MAX:
        return None, c_data
    # NaN never compares equal to NaN, so Counter would register every
    # missing entry as a separate key. Those keys would take up rank
    # positions and shift the codes assigned to real values, so missing
    # entries are dropped before counting.
    counts = Counter(c_data[~np.isnan(c_data)])
    values_to_replace = sorted(counts, key=counts.get, reverse=True)
    return values_to_replace, c_data

def _get_values(self, attr):
    """
    Return the value labels to display for the discrete variable `attr`.

    When `attr` has at most MAX values, its own list of values is returned
    unchanged. Otherwise the labels of the first MAX - 1 non-missing
    entries in the ranking from `_reduce_values` are returned, followed by
    the label "Other".
    """
    if len(attr.values) <= MAX:
        return attr.values
    ranked_values, _ = self._reduce_values(attr)
    labels = []
    for code in ranked_values:
        # Missing entries have no label of their own.
        if np.isnan(code):
            continue
        labels.append(attr.values[int(code)])
    return labels[:MAX - 1] + ["Other"]

def _get_data(self, attr):
    """
    Return the column of `attr`, recoded when it has too many values.

    When `_reduce_values` supplies no ranking, the column is returned as
    is. Otherwise each entry is replaced by the position of its value in
    the ranking, and every position from MAX - 1 onwards is collapsed
    into the single code MAX - 1.
    """
    ranked_values, column = self._reduce_values(attr)
    if ranked_values is None:
        return column
    # Compare against an untouched copy so that codes written in earlier
    # iterations are not mistaken for original values.
    original = column.copy()
    other_code = MAX - 1
    for rank, value in enumerate(ranked_values):
        column[original == value] = min(rank, other_code)
    return column

def compute_colors(self, keep_colors=False):
if not keep_colors:
self.pen_colors = self.brush_colors = None
Expand Down Expand Up @@ -880,7 +911,7 @@ def make_pen(color, width):
* self.n_points
return pen, brush

c_data = self.data.get_column_view(self.attr_color)[0][self.valid_data]
c_data = self._get_data(self.attr_color)
if self.attr_color.is_continuous:
if self.pen_colors is None:
self.scale = DiscretizedScale(np.nanmin(c_data), np.nanmax(c_data))
Expand Down Expand Up @@ -992,18 +1023,12 @@ def update_labels(self):
for label, text in zip(self.labels, label_data):
label.setText(text, black)

def get_shape(self):
    """
    Return the shape attribute if it can be drawn with the available symbols.

    The result is `self.attr_shape` when it is set and has no more values
    than there are entries in `self.CurveSymbols`; otherwise None.
    """
    shape_attr = self.attr_shape
    if shape_attr is not None \
            and len(shape_attr.values) <= len(self.CurveSymbols):
        return shape_attr
    return None

def compute_symbols(self):
self.master.Information.missing_shape.clear()
if self.get_shape() is None:
if self.attr_shape is None:
shape_data = self.CurveSymbols[np.zeros(self.n_points, dtype=int)]
else:
shape_data = self.data.get_column_view(self.attr_shape)[0][self.valid_data]
shape_data = self._get_data(self.attr_shape)
nans = np.isnan(shape_data)
if np.any(nans):
shape_data[nans] = len(self.CurveSymbols) - 1
Expand Down Expand Up @@ -1057,12 +1082,12 @@ def make_legend(self):
def make_color_legend(self):
if self.attr_color is None:
return
use_shape = self.get_shape() == self.get_color()
use_shape = self.attr_shape == self.get_color()
if self.attr_color.is_discrete:
if not self.legend:
self.create_legend()
palette = self.discrete_palette
for i, value in enumerate(self.attr_color.values):
for i, value in enumerate(self._get_values(self.attr_color)):
color = QColor(*palette.getRGB(i))
brush = color.lighter(self.DarkerValue)
self.legend.addItem(
Expand All @@ -1080,14 +1105,13 @@ def make_color_legend(self):
legend.setGeometry(label.boundingRect())

def make_shape_legend(self):
shape = self.get_shape()
if shape is None or shape == self.get_color():
if self.attr_shape is None or self.attr_shape == self.get_color():
return
if not self.legend:
self.create_legend()
color = QColor(0, 0, 0)
color.setAlpha(self.alpha_value)
for i, value in enumerate(self.attr_shape.values):
for i, value in enumerate(self._get_values(self.attr_shape)):
self.legend.addItem(
ScatterPlotItem(pen=color, brush=color, size=10,
symbol=self.CurveSymbols[i]), escape(value))
Expand Down
26 changes: 26 additions & 0 deletions Orange/widgets/visualize/tests/test_owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from AnyQt.QtWidgets import QToolTip

from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable
from Orange.widgets.visualize.owscatterplotgraph import MAX
from Orange.widgets.widget import AttributeList
from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin, datasets
from Orange.widgets.visualize.owscatterplot import \
Expand Down Expand Up @@ -593,6 +594,31 @@ def test_tooltip(self):
self.assertFalse(graph.help_event(event))
self.assertEqual(show_text.call_count, 0)

def test_many_discrete_values(self):
    """
    Do not show all discrete values if there are too many.
    The same check is repeated for data containing a nan value.
    GH-2804
    """
    def make_table():
        # Iris attributes with a 15-valued class variable, each value
        # repeated ten times.
        iris = Table("iris")
        codes = list(range(15))
        target = DiscreteVariable("iris5", values=[str(c) for c in codes])
        domain = Domain(attributes=iris.domain.attributes,
                        class_vars=[target])
        table = iris.transform(domain)
        table.Y = np.array(codes * 10, dtype=float)
        return table

    def check_pen_count(table, expected):
        # Pens are compared by identity: the test expects one distinct
        # pen object per displayed colour.
        self.send_signal(self.widget.Inputs.data, table)
        pens, _ = self.widget.graph.compute_colors()
        distinct_pens = np.unique([id(pen) for pen in pens])
        self.assertEqual(expected, len(distinct_pens))

    check_pen_count(make_table(), MAX)

    # data with nan value: one more distinct pen is expected
    with_missing = make_table()
    with_missing.Y[42] = np.nan
    check_pen_count(with_missing, MAX + 1)


if __name__ == "__main__":
import unittest
Expand Down