Skip to content

Commit

Permalink
OwTSNE: Offload work to separate thread
Browse files Browse the repository at this point in the history
  • Loading branch information
pavlin-policar committed Feb 15, 2019
1 parent bd906f7 commit db0a98e
Show file tree
Hide file tree
Showing 5 changed files with 701 additions and 285 deletions.
19 changes: 12 additions & 7 deletions Orange/projection/manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,7 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:

return embedding

def __call__(self, data: Table) -> TSNEModel:
# Preprocess the data - convert discrete to continuous
data = self.preprocess(data)

# Run tSNE optimization
embedding = self.fit(data.X, data.Y)

def convert_embedding_to_model(self, data, embedding):
# The results should be accessible in an Orange table, which doesn't
# need the full embedding attributes and is cast into a regular array
n = self.n_components
Expand All @@ -518,6 +512,17 @@ def __call__(self, data: Table) -> TSNEModel:

return model

def __call__(self, data: Table) -> TSNEModel:
# Preprocess the data - convert discrete to continuous
data = self.preprocess(data)

# Run tSNE optimization
embedding = self.fit(data.X, data.Y)

# Convert the t-SNE embedding object to a TSNEModel and prepare the
# embedding table with t-SNE meta variables
return self.convert_embedding_to_model(data, embedding)

@staticmethod
def default_initialization(data, n_components=2, random_state=None):
return openTSNE.initialization.pca(
Expand Down
37 changes: 30 additions & 7 deletions Orange/widgets/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,10 +857,15 @@ def _compare_selected_annotated_domains(self, selected, annotated):
annotated_vars = annotated.domain.variables
self.assertLessEqual(set(selected_vars), set(annotated_vars))

def test_setup_graph(self):
def test_setup_graph(self, timeout=DEFAULT_TIMEOUT):
"""Plot should exist after data has been sent in order to be
properly set/updated"""
self.send_signal(self.widget.Inputs.data, self.data)

if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.assertIsNotNone(self.widget.graph.scatterplot_item)

def test_default_attrs(self, timeout=DEFAULT_TIMEOUT):
Expand Down Expand Up @@ -934,16 +939,21 @@ def test_plot_once(self, timeout=DEFAULT_TIMEOUT):
table = Table("heart_disease")
self.widget.setup_plot = Mock()
self.widget.commit = Mock()

self.send_signal(self.widget.Inputs.data, table)
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.setup_plot.assert_called_once()
self.widget.commit.assert_called_once()

self.widget.commit.reset_mock()
self.send_signal(self.widget.Inputs.data_subset, table[::10])
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.commit.reset_mock()
self.send_signal(self.widget.Inputs.data_subset, table[::10])
self.widget.setup_plot.assert_called_once()
self.widget.commit.assert_called_once()

Expand Down Expand Up @@ -985,25 +995,38 @@ def test_invalidated_embedding(self, timeout=DEFAULT_TIMEOUT):
self.widget.graph.update_coordinates = Mock()
self.widget.graph.update_point_props = Mock()
self.send_signal(self.widget.Inputs.data, self.data)
self.widget.graph.update_coordinates.assert_called_once()
self.widget.graph.update_point_props.assert_called_once()

if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.graph.update_coordinates.assert_called()
self.widget.graph.update_point_props.assert_called()

self.widget.graph.update_coordinates.reset_mock()
self.widget.graph.update_point_props.reset_mock()
self.send_signal(self.widget.Inputs.data, self.data)
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.graph.update_coordinates.assert_not_called()
self.widget.graph.update_point_props.assert_called_once()

def test_saved_selection(self):
def test_saved_selection(self, timeout=DEFAULT_TIMEOUT):
self.send_signal(self.widget.Inputs.data, self.data)
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.graph.select_by_indices(list(range(0, len(self.data), 10)))
settings = self.widget.settingsHandler.pack_data(self.widget)
w = self.create_widget(self.widget.__class__, stored_settings=settings)

self.send_signal(self.widget.Inputs.data, self.data, widget=w)
if w.isBlocking():
spy = QSignalSpy(w.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.assertEqual(np.sum(w.graph.selection), 15)
np.testing.assert_equal(self.widget.graph.selection, w.graph.selection)

Expand Down
Loading

0 comments on commit db0a98e

Please sign in to comment.