Skip to content

Commit

Permalink
Merge pull request #2149 from jerneju/index-corresondence
Browse files Browse the repository at this point in the history
[FIX] Correspondence: Prevent crashing when cont attr has one value
  • Loading branch information
lanzagar authored Apr 21, 2017
2 parents e63c72b + 4af8ffd commit 4924983
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 31 deletions.
88 changes: 57 additions & 31 deletions Orange/widgets/unsupervised/owcorrespondence.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sys

import numpy
from collections import namedtuple, OrderedDict

import numpy as np

from AnyQt.QtWidgets import QListView, QApplication
from AnyQt.QtGui import QBrush, QColor, QPainter
Expand Down Expand Up @@ -109,7 +111,8 @@ def set_data(self, data):
self.varlist[:] = [var for var in data.domain.variables
if var.is_discrete]
self.selected_var_indices = [0, 1][:len(self.varlist)]
self.component_x, self.component_y = 0, 1
self.component_x = 0
self.component_y = int(len(self.varlist[self.selected_var_indices[-1]].values) > 1)
self.openContext(data)
self._restore_selection()
# self._invalidate()
Expand Down Expand Up @@ -193,14 +196,27 @@ def update_XY(self):
return rfs

def _setup_plot(self):
def get_minmax(points):
    """
    Return the bounding box of all 2-D points in `points`.

    `points` is an iterable of point groups; every point is indexed
    with ``[0]`` (x) and ``[1]`` (y).

    The result is the list ``[x_min, x_max, y_min, y_max]``.  When no
    point is visited, the initial infinities are returned unchanged:
    ``[inf, -inf, inf, -inf]``.
    """
    x_low = y_low = float('inf')
    x_high = y_high = float('-inf')
    for group in points:
        for point in group:
            # The candidate goes first in min/max, as a tie or an
            # unordered comparison resolves to the first argument.
            x_low = min(point[0], x_low)
            x_high = max(point[0], x_high)
            y_low = min(point[1], y_low)
            y_high = max(point[1], y_high)
    return [x_low, x_high, y_low, y_high]

self.plot.clear()
points = self.ca
variables = self.selected_vars()
colors = colorpalette.ColorPaletteGenerator(len(variables))

p_axes = self._p_axes()

if points == None:
if points is None:
return

if len(variables) == 2:
Expand All @@ -210,10 +226,19 @@ def _setup_plot(self):
else:
points = self.ca.row_factors[:, p_axes]
counts = [len(var.values) for var in variables]
range_indices = numpy.cumsum([0] + counts)
range_indices = np.cumsum([0] + counts)
ranges = zip(range_indices, range_indices[1:])
points = [points[s:e] for s, e in ranges]

minmax = get_minmax(points)

margin = abs(minmax[0] - minmax[1])
margin = margin * 0.05 if margin > 1e-10 else 1
self.plot.setXRange(minmax[0] - margin, minmax[1] + margin)
margin = abs(minmax[2] - minmax[3])
margin = margin * 0.05 if margin > 1e-10 else 1
self.plot.setYRange(minmax[2] - margin, minmax[3] + margin)

for i, (v, points) in enumerate(zip(variables, points)):
color_outline = colors[i]
color_outline.setAlpha(200)
Expand All @@ -222,7 +247,7 @@ def _setup_plot(self):
item = ScatterPlotItem(
x=points[:, 0], y=points[:, 1], brush=QBrush(color),
pen=pg.mkPen(color_outline.darker(120), width=1.5),
size=numpy.full((points.shape[0],), 10.1),
size=np.full((points.shape[0],), 10.1),
)
self.plot.addItem(item)

Expand All @@ -232,10 +257,10 @@ def _setup_plot(self):
item.setPos(point[0], point[1])

inertia = self.ca.inertia_of_axis()
if numpy.sum(inertia) == 0:
if np.sum(inertia) == 0:
inertia = 100 * inertia
else:
inertia = 100 * inertia / numpy.sum(inertia)
inertia = 100 * inertia / np.sum(inertia)

ax = self.plot.getAxis("bottom")
ax.setLabel("Component {} ({:.1f}%)"
Expand All @@ -251,10 +276,10 @@ def _update_info(self):
fmt = ("Axis 1: {:.2f}\n"
"Axis 2: {:.2f}")
inertia = self.ca.inertia_of_axis()
if numpy.sum(inertia) == 0:
if np.sum(inertia) == 0:
inertia = 100 * inertia
else:
inertia = 100 * inertia / numpy.sum(inertia)
inertia = 100 * inertia / np.sum(inertia)

ax1, ax2 = self._p_axes()
self.infotext.setText(fmt.format(inertia[ax1], inertia[ax2]))
Expand Down Expand Up @@ -293,9 +318,9 @@ def burt_table(data, variables):
"""
values = [(var, value) for var in variables for value in var.values]

table = numpy.zeros((len(values), len(values)))
table = np.zeros((len(values), len(values)))
counts = [len(attr.values) for attr in variables]
offsets = numpy.r_[0, numpy.cumsum(counts)]
offsets = np.r_[0, np.cumsum(counts)]

for i in range(len(variables)):
for j in range(i + 1):
Expand All @@ -318,41 +343,42 @@ def correspondence(A):
"""
:param numpy.ndarray A:
"""
A = numpy.asarray(A)
A = np.asarray(A)

total = numpy.sum(A)
total = np.sum(A)
if total > 0:
corr_mat = A / total
else:
# The total is not positive, so A cannot be scaled to sum to 1; use it unscaled.
corr_mat = A

col_sum = numpy.sum(corr_mat, axis=0, keepdims=True)
row_sum = numpy.sum(corr_mat, axis=1, keepdims=True)
col_sum = np.sum(corr_mat, axis=0, keepdims=True)
row_sum = np.sum(corr_mat, axis=1, keepdims=True)
E = row_sum * col_sum

D_r, D_c = row_sum.ravel() ** -1, col_sum.ravel() ** -1
D_r, D_c = numpy.nan_to_num(D_r), numpy.nan_to_num(D_c)
D_r, D_c = np.nan_to_num(D_r), np.nan_to_num(D_c)

def gsvd(M, Wu, Wv):
def gsvd(M, wu, wv):
assert len(M.shape) == 2
assert len(Wu.shape) == 1 and len(Wv.shape) == 1
Wu_sqrt = numpy.sqrt(Wu)
Wv_sqrt = numpy.sqrt(Wv)
B = numpy.c_[Wu_sqrt] * M * numpy.r_[Wv_sqrt]
Ub, D, Vb = numpy.linalg.svd(B, full_matrices=False)
U = numpy.c_[Wu_sqrt ** -1] * Ub
V = (numpy.c_[Wv_sqrt ** -1] * Vb.T).T
assert len(wu.shape) == 1 and len(wv.shape) == 1
Wu_sqrt = np.sqrt(wu)
Wv_sqrt = np.sqrt(wv)
B = np.c_[Wu_sqrt] * M * np.r_[Wv_sqrt]
Ub, D, Vb = np.linalg.svd(B, full_matrices=False)
U = np.c_[Wu_sqrt ** -1] * Ub
V = (np.c_[Wv_sqrt ** -1] * Vb.T).T
return U, D, V

U, D, V = gsvd(corr_mat - E, D_r, D_c)

F = numpy.c_[D_r] * U * D
G = numpy.c_[D_c] * V.T * D
F = np.c_[D_r] * U * D
G = np.c_[D_c] * V.T * D

return CA(U, D, V, F, G, row_sum, col_sum)
if F.shape == (1, 1) and F[0, 0] == 0:
F[0, 0] = 1

from collections import namedtuple, OrderedDict
return CA(U, D, V, F, G, row_sum, col_sum)

CA = namedtuple("CA", ["U", "D", "V", "row_factors", "col_factors",
"row_sums", "column_sums"])
Expand All @@ -366,10 +392,10 @@ def column_inertia(self):
return self.column_sums.T * (self.col_factors ** 2)

def inertia_of_axis(self):
return numpy.sum(self.row_inertia(), axis=0)
return np.sum(self.row_inertia(), axis=0)


def test_main(argv=None):
def main(argv=None):
import sip
if argv is None:
argv = sys.argv[1:]
Expand All @@ -392,4 +418,4 @@ def test_main(argv=None):
return rval

if __name__ == "__main__":
sys.exit(test_main())
sys.exit(main())
14 changes: 14 additions & 0 deletions Orange/widgets/unsupervised/tests/test_owcorrespondence.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,17 @@ def test_data_values_in_column(self):
"klkk"
)))
self.send_signal("Data", table)

def test_data_one_value_zero(self):
    """
    Sending a table whose only attribute is a discrete variable with a
    single value must not crash the widget.
    GH-2149
    """
    single_valued = DiscreteVariable("a", values=["0"])
    # Three identical rows, each holding the variable's only value.
    rows = [(0,)] * 3
    self.send_signal("Data", Table(Domain([single_valued]), rows))

0 comments on commit 4924983

Please sign in to comment.