Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] t-SNE: Updates 2. #3475

Merged
merged 3 commits into from
Dec 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Orange/projection/manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,8 @@ def __call__(self, data: Table) -> TSNEModel:
model.name = self.name

return model

@staticmethod
def default_initialization(data, n_components=2, random_state=None):
return fastTSNE.initialization.pca(
data, n_components, random_state=random_state)
53 changes: 37 additions & 16 deletions Orange/widgets/unsupervised/owtsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from AnyQt.QtCore import Qt, QTimer
from AnyQt.QtWidgets import QFormLayout

import fastTSNE.initialization

from Orange.data import Table, Domain
from Orange.preprocess.preprocess import Preprocess, ApplyDomain
from Orange.projection import PCA, TSNE, TruncatedSVD
Expand Down Expand Up @@ -82,7 +80,7 @@ class OWtSNE(OWDataProjectionWidget):
embedding_variables_names = ("t-SNE-x", "t-SNE-y")

#: Runtime state
Running, Finished, Waiting = 1, 2, 3
Running, Finished, Waiting, Paused = 1, 2, 3, 4

class Outputs(OWDataProjectionWidget.Outputs):
preprocessor = Output("Preprocessor", Preprocess)
Expand All @@ -100,6 +98,7 @@ def __init__(self):
self.pca_data = None
self.projection = None
self.tsne_runner = None
self.tsne_iterator = None
self.__update_loop = None
# timer for scheduling updates
self.__timer = QTimer(self, singleShot=True, interval=1,
Expand All @@ -122,31 +121,42 @@ def _add_controls_start_box(self):
)

self.perplexity_spin = gui.spin(
box, self, "perplexity", 1, 500, step=1, alignment=Qt.AlignRight)
box, self, "perplexity", 1, 500, step=1, alignment=Qt.AlignRight,
callback=self._params_changed
)
form.addRow("Perplexity:", self.perplexity_spin)
self.perplexity_spin.setEnabled(not self.multiscale)
form.addRow(gui.checkBox(
box, self, "multiscale", label="Preserve global structure",
callback=self._multiscale_changed
))
self._multiscale_changed()

sbe = gui.hBox(self.controlArea, False, addToLayout=False)
gui.hSlider(
sbe, self, "exaggeration", minValue=1, maxValue=4, step=1)
sbe, self, "exaggeration", minValue=1, maxValue=4, step=1,
callback=self._params_changed
)
form.addRow("Exaggeration:", sbe)

sbp = gui.hBox(self.controlArea, False, addToLayout=False)
gui.hSlider(
sbp, self, "pca_components", minValue=2, maxValue=50, step=1)
sbp, self, "pca_components", minValue=2, maxValue=50, step=1,
callback=self._params_changed
)
form.addRow("PCA components:", sbp)

box.layout().addLayout(form)

gui.separator(box, 10)
self.runbutton = gui.button(box, self, "Run", callback=self._toggle_run)

def _params_changed(self):
self.__state = OWtSNE.Finished
self.__set_update_loop(None)

def _multiscale_changed(self):
self.perplexity_spin.setEnabled(not self.multiscale)
self._params_changed()

def check_data(self):
def error(err):
Expand Down Expand Up @@ -181,6 +191,8 @@ def _toggle_run(self):
if self.__state == OWtSNE.Running:
self.stop()
self.commit()
elif self.__state == OWtSNE.Paused:
self.resume()
else:
self.start()

Expand All @@ -191,8 +203,11 @@ def start(self):
self.__start()

def stop(self):
if self.__state == OWtSNE.Running:
self.__set_update_loop(None)
self.__state = OWtSNE.Paused
self.__set_update_loop(None)

def resume(self):
self.__set_update_loop(self.tsne_iterator)

def pca_preprocessing(self):
if self.pca_data is not None and \
Expand All @@ -208,7 +223,7 @@ def __start(self):

# We call PCA through fastTSNE because it involves scaling. Instead of
# worrying about this ourselves, we'll let the library worry for us.
initialization = fastTSNE.initialization.pca(
initialization = TSNE.default_initialization(
self.pca_data.X, n_components=2, random_state=0)

# Compute perplexity settings for multiscale
Expand All @@ -233,13 +248,14 @@ def __start(self):
)(self.pca_data)

self.tsne_runner = TSNERunner(self.projection, step_size=50)

self.__set_update_loop(self.tsne_runner.run_optimization())
self.tsne_iterator = self.tsne_runner.run_optimization()
self.__set_update_loop(self.tsne_iterator)
self.progressBarInit(processEvents=None)

def __set_update_loop(self, loop):
if self.__update_loop is not None:
self.__update_loop.close()
if self.__state in (OWtSNE.Finished, OWtSNE.Waiting):
self.__update_loop.close()
self.__update_loop = None
self.progressBarFinished(processEvents=None)

Expand All @@ -255,8 +271,10 @@ def __set_update_loop(self, loop):
else:
self.setBlocking(False)
self.setStatusMessage("")
self.runbutton.setText("Start")
self.__state = OWtSNE.Finished
if self.__state in (OWtSNE.Finished, OWtSNE.Waiting):
self.runbutton.setText("Start")
if self.__state == OWtSNE.Paused:
self.runbutton.setText("Resume")
self.__timer.stop()

def __next_step(self):
Expand All @@ -273,13 +291,16 @@ def __next_step(self):
projection, progress = next(self.__update_loop)
assert self.__update_loop is loop
except StopIteration:
self.__state = OWtSNE.Finished
self.__set_update_loop(None)
self.unconditional_commit()
except MemoryError:
self.Error.out_of_memory()
self.__state = OWtSNE.Finished
self.__set_update_loop(None)
except Exception as exc:
self.Error.optimization_error(str(exc))
self.__state = OWtSNE.Finished
self.__set_update_loop(None)
else:
self.progressBarSet(100.0 * progress, processEvents=None)
Expand Down Expand Up @@ -321,8 +342,8 @@ def send_preprocessor(self):

def clear(self):
super().clear()
self.__set_update_loop(None)
self.__state = OWtSNE.Waiting
self.__set_update_loop(None)
self.pca_data = None
self.projection = None

Expand Down
19 changes: 13 additions & 6 deletions Orange/widgets/unsupervised/tests/test_owtsne.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import unittest
import numpy as np

from AnyQt.QtTest import QSignalSpy

from Orange.data import DiscreteVariable, ContinuousVariable, Domain, Table
from Orange.preprocess import Preprocess
from Orange.projection.manifold import TSNE
Expand Down Expand Up @@ -39,7 +41,8 @@ def optimize(*_, **__):
owtsne.TSNEModel.transform = transform
owtsne.TSNEModel.optimize = optimize

self.widget = self.create_widget(OWtSNE)
self.widget = self.create_widget(OWtSNE,
stored_settings={"multiscale": False})

self.class_var = DiscreteVariable('Stage name', values=['STG1', 'STG2'])
self.attributes = [ContinuousVariable('GeneName' + str(i)) for i in range(5)]
Expand Down Expand Up @@ -110,7 +113,11 @@ def test_attr_models(self):
self.assertIn(var, controls.attr_shape.model())

def test_output_preprocessor(self):
self.reset_tsne()
self.send_signal(self.widget.Inputs.data, self.data)
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(20000))
pp = self.get_output(self.widget.Outputs.preprocessor)
self.assertIsInstance(pp, Preprocess)
transformed = pp(self.data)
Expand All @@ -123,15 +130,15 @@ def test_output_preprocessor(self):
[m.name for m in output.domain.metas[:2]])

def test_multiscale_changed(self):
self.assertTrue(self.widget.controls.multiscale.isChecked())
self.assertFalse(self.widget.perplexity_spin.isEnabled())
self.widget.controls.multiscale.setChecked(False)
self.assertFalse(self.widget.controls.multiscale.isChecked())
self.assertTrue(self.widget.perplexity_spin.isEnabled())
self.widget.controls.multiscale.setChecked(True)
self.assertFalse(self.widget.perplexity_spin.isEnabled())

settings = self.widget.settingsHandler.pack_data(self.widget)
w = self.create_widget(OWtSNE, stored_settings=settings)
self.assertFalse(w.controls.multiscale.isChecked())
self.assertTrue(w.perplexity_spin.isEnabled())
self.assertTrue(w.controls.multiscale.isChecked())
self.assertFalse(w.perplexity_spin.isEnabled())


if __name__ == '__main__':
Expand Down