Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] t-SNE: Add Normalize data checkbox #3570

Merged
merged 2 commits into from
Feb 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions Orange/widgets/unsupervised/owtsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from AnyQt.QtWidgets import QFormLayout

from Orange.data import Table, Domain
from Orange.preprocess.preprocess import Preprocess, ApplyDomain
from Orange.projection import PCA, TSNE, TruncatedSVD
from Orange.preprocess import preprocess
from Orange.projection import PCA, TSNE
from Orange.projection.manifold import TSNEModel
from Orange.widgets import gui
from Orange.widgets.settings import Setting, SettingProvider
Expand Down Expand Up @@ -76,6 +76,7 @@ class OWtSNE(OWDataProjectionWidget):
multiscale = Setting(True)
exaggeration = Setting(1)
pca_components = Setting(20)
normalize = Setting(True)

GRAPH_CLASS = OWtSNEGraph
graph = SettingProvider(OWtSNEGraph)
Expand All @@ -85,7 +86,7 @@ class OWtSNE(OWDataProjectionWidget):
Running, Finished, Waiting, Paused = 1, 2, 3, 4

class Outputs(OWDataProjectionWidget.Outputs):
preprocessor = Output("Preprocessor", Preprocess)
preprocessor = Output("Preprocessor", preprocess.Preprocess)

class Error(OWDataProjectionWidget.Error):
not_enough_rows = Msg("Input data needs at least 2 rows")
Expand Down Expand Up @@ -143,15 +144,25 @@ def _add_controls_start_box(self):
sbp = gui.hBox(self.controlArea, False, addToLayout=False)
gui.hSlider(
sbp, self, "pca_components", minValue=2, maxValue=50, step=1,
callback=self._params_changed
callback=self._invalidate_pca_projection
)
form.addRow("PCA components:", sbp)

self.normalize_cbx = gui.checkBox(
box, self, "normalize", "Normalize data",
callback=self._invalidate_pca_projection,
)
form.addRow(self.normalize_cbx)

box.layout().addLayout(form)

gui.separator(box, 10)
self.runbutton = gui.button(box, self, "Run", callback=self._toggle_run)

def _invalidate_pca_projection(self):
    """Drop the stored PCA projection, then run ``_params_changed``.

    Serves as the callback of the "PCA components" slider and the
    "Normalize data" checkbox: both change the input of the PCA step, so
    ``pca_data`` is reset to ``None`` before the parameter-change handling
    runs.
    """
    self.pca_data = None
    self._params_changed()

def _params_changed(self):
self.__state = OWtSNE.Finished
self.__set_update_loop(None)
Expand Down Expand Up @@ -215,12 +226,32 @@ def stop(self):
def resume(self):
    """Pass the stored ``tsne_iterator`` to the update loop."""
    iterator = self.tsne_iterator
    self.__set_update_loop(iterator)

def set_data(self, data: Table):
    """Forward ``data`` to the parent class and sync the normalize checkbox.

    For sparse input the "Normalize data" checkbox is disabled, the
    ``normalize`` setting is switched off and an explanatory tooltip is
    set. For dense input the checkbox is enabled and its tooltip cleared.
    When ``data`` is ``None`` the checkbox is left untouched.
    """
    super().set_data(data)

    if data is None:
        return

    # PCA doesn't support normalization on sparse data, as this would
    # require centering and normalizing the matrix
    sparse = data.is_sparse()
    checkbox = self.normalize_cbx
    checkbox.setDisabled(sparse)

    if not sparse:
        checkbox.setToolTip("")
        return

    self.normalize = False
    checkbox.setToolTip(
        "Data normalization is not supported on sparse matrices."
    )

def pca_preprocessing(self):
if self.pca_data is not None and \
self.pca_data.X.shape[1] == self.pca_components:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this check because I changed line 146 to invalidate the PCA projection; it will therefore be reset to None whenever the number of components changes.

"""Perform PCA preprocessing before passing off the data to t-SNE."""
if self.pca_data is not None:
return
cls = TruncatedSVD if self.data.is_sparse() else PCA
projector = cls(n_components=self.pca_components, random_state=0)

projector = PCA(n_components=self.pca_components, random_state=0)
# If the normalization box is ticked, we'll add the `Normalize`
# preprocessor to PCA
if self.normalize:
projector.preprocessors += (preprocess.Normalize(),)

model = projector(self.data)
self.pca_data = model(self.data)

Expand Down Expand Up @@ -343,7 +374,7 @@ def _get_projection_data(self):
def send_preprocessor(self):
prep = None
if self.data is not None and self.projection is not None:
prep = ApplyDomain(self.projection.domain, self.projection.name)
prep = preprocess.ApplyDomain(self.projection.domain, self.projection.name)
self.Outputs.preprocessor.send(prep)

def clear(self):
Expand Down
64 changes: 47 additions & 17 deletions Orange/widgets/unsupervised/tests/test_owtsne.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import unittest
from unittest.mock import patch
import numpy as np

from AnyQt.QtTest import QSignalSpy

from Orange.data import DiscreteVariable, ContinuousVariable, Domain, Table
from Orange.preprocess import Preprocess
from Orange.preprocess import Preprocess, Normalize
from Orange.projection.manifold import TSNE
from Orange.widgets.tests.base import (
WidgetTest, WidgetOutputsTestMixin, ProjectionWidgetTestMixin
Expand Down Expand Up @@ -50,9 +49,9 @@ def optimize(*_, **__):
self.empty_domain = Domain([], class_vars=self.class_var)

def tearDown(self):
self.reset_tsne()
self.restore_mocked_functions()

def reset_tsne(self):
def restore_mocked_functions(self):
owtsne.TSNE.fit = self._fit
owtsne.TSNEModel.transform = self._transform
owtsne.TSNEModel.optimize = self._optimize
Expand Down Expand Up @@ -113,21 +112,26 @@ def test_attr_models(self):
self.assertIn(var, controls.attr_shape.model())

def test_output_preprocessor(self):
self.reset_tsne()
# To test the validity of the preprocessor, we'll have to actually
# compute the projections
self.restore_mocked_functions()

self.send_signal(self.widget.Inputs.data, self.data)
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(20000))
self.wait_until_stop_blocking(wait=20000)
output_data = self.get_output(self.widget.Outputs.annotated_data)

# We send the same data to the widget, we expect the point locations to
# be fairly close to their original ones
pp = self.get_output(self.widget.Outputs.preprocessor)
self.assertIsInstance(pp, Preprocess)
transformed = pp(self.data)
self.assertIsInstance(transformed, Table)
self.assertEqual(transformed.X.shape, (len(self.data), 2))
output = self.get_output(self.widget.Outputs.annotated_data)
np.testing.assert_allclose(transformed.X, output.metas[:, :2],
rtol=1, atol=1)
self.assertEqual([a.name for a in transformed.domain.attributes],
[m.name for m in output.domain.metas[:2]])

transformed_data = pp(self.data)
self.assertIsInstance(transformed_data, Table)
self.assertEqual(transformed_data.X.shape, (len(self.data), 2))
np.testing.assert_allclose(transformed_data.X, output_data.metas[:, :2],
rtol=1, atol=3)
self.assertEqual([a.name for a in transformed_data.domain.attributes],
[m.name for m in output_data.domain.metas[:2]])

def test_multiscale_changed(self):
self.assertFalse(self.widget.controls.multiscale.isChecked())
Expand All @@ -140,6 +144,32 @@ def test_multiscale_changed(self):
self.assertTrue(w.controls.multiscale.isChecked())
self.assertFalse(w.perplexity_spin.isEnabled())

def test_normalize_data(self):
    """Check the "Normalize data" checkbox and its effect on preprocessing.

    Three scenarios are exercised in order:

    1. checkbox ticked (the default) with dense data: ``Normalize`` is
       instantiated exactly once;
    2. checkbox unticked with dense data: ``Normalize`` is never
       instantiated;
    3. checkbox ticked again, then sparse data sent: the checkbox becomes
       disabled and ``Normalize`` is never instantiated.
    """
    normalize_path = "Orange.preprocess.preprocess.Normalize"
    checkbox = self.widget.controls.normalize
    data_input = self.widget.Inputs.data

    # Normalization should be checked by default
    self.assertTrue(checkbox.isChecked())
    with patch(normalize_path, wraps=Normalize) as spy:
        self.send_signal(data_input, self.data)
        self.assertTrue(checkbox.isEnabled())
        spy.assert_called_once()

    # Disable checkbox
    checkbox.setChecked(False)
    self.assertFalse(checkbox.isChecked())
    with patch(normalize_path, wraps=Normalize) as spy:
        self.send_signal(data_input, self.data)
        self.assertTrue(checkbox.isEnabled())
        spy.assert_not_called()

    # Normalization shouldn't work on sparse data
    checkbox.setChecked(True)
    self.assertTrue(checkbox.isChecked())

    sparse_data = self.data.to_sparse()
    with patch(normalize_path, wraps=Normalize) as spy:
        self.send_signal(data_input, sparse_data)
        self.assertFalse(checkbox.isEnabled())
        spy.assert_not_called()


if __name__ == '__main__':
unittest.main()