Skip to content

Commit

Permalink
OWTSNE: Offload computation to separate thread
Browse files Browse the repository at this point in the history
  • Loading branch information
pavlin-policar committed Apr 5, 2019
1 parent d20a5fd commit 077fc02
Show file tree
Hide file tree
Showing 4 changed files with 644 additions and 295 deletions.
90 changes: 66 additions & 24 deletions Orange/projection/manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,10 @@ def optimize(self, n_iter, inplace=False, propagate_exception=False, **kwargs):
new_embedding = self.embedding_.optimize(**kwargs)
table = Table(self.embedding.domain, new_embedding.view(np.ndarray),
self.embedding.Y, self.embedding.metas)
return TSNEModel(new_embedding, table, self.pre_domain)

new_model = TSNEModel(new_embedding, table, self.pre_domain)
new_model.name = self.name
return new_model


class TSNE(Projector):
Expand Down Expand Up @@ -400,7 +403,7 @@ def __init__(self, n_components=2, perplexity=30, learning_rate=200,
self.callbacks_every_iters = callbacks_every_iters
self.random_state = random_state

def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
def compute_affinities(self, X):
# Sparse data are not supported
if sp.issparse(X):
raise TypeError(
Expand All @@ -415,41 +418,75 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
if not isinstance(self.perplexity, Iterable):
raise ValueError(
"Perplexity should be an instance of `Iterable`, `%s` "
"given." % type(self.perplexity).__name__)
"given." % type(self.perplexity).__name__
)
affinities = openTSNE.affinity.Multiscale(
X, perplexities=self.perplexity, metric=self.metric,
method=self.neighbors, random_state=self.random_state, n_jobs=self.n_jobs)
X,
perplexities=self.perplexity,
metric=self.metric,
method=self.neighbors,
random_state=self.random_state,
n_jobs=self.n_jobs,
)
else:
if isinstance(self.perplexity, Iterable):
raise ValueError(
"Perplexity should be an instance of `float`, `%s` "
"given." % type(self.perplexity).__name__)
"given." % type(self.perplexity).__name__
)
affinities = openTSNE.affinity.PerplexityBasedNN(
X, perplexity=self.perplexity, metric=self.metric,
method=self.neighbors, random_state=self.random_state, n_jobs=self.n_jobs)
X,
perplexity=self.perplexity,
metric=self.metric,
method=self.neighbors,
random_state=self.random_state,
n_jobs=self.n_jobs,
)

# Create an initial embedding
return affinities

def compute_initialization(self, X):
# Compute the initial positions of individual points
if isinstance(self.initialization, np.ndarray):
initialization = self.initialization
elif self.initialization == "pca":
initialization = openTSNE.initialization.pca(
X, self.n_components, random_state=self.random_state)
X, self.n_components, random_state=self.random_state
)
elif self.initialization == "random":
initialization = openTSNE.initialization.random(
X, self.n_components, random_state=self.random_state)
X, self.n_components, random_state=self.random_state
)
else:
raise ValueError(
"Invalid initialization `%s`. Please use either `pca` or "
"`random` or provide a numpy array." % self.initialization)
"`random` or provide a numpy array." % self.initialization
)

embedding = openTSNE.TSNEEmbedding(
initialization, affinities, learning_rate=self.learning_rate,
theta=self.theta, min_num_intervals=self.min_num_intervals,
ints_in_interval=self.ints_in_interval, n_jobs=self.n_jobs,
return initialization

def prepare_embedding(self, affinities, initialization):
"""Prepare an embedding object with appropriate parameters, given some
affinities and initialization."""
return openTSNE.TSNEEmbedding(
initialization,
affinities,
learning_rate=self.learning_rate,
theta=self.theta,
min_num_intervals=self.min_num_intervals,
ints_in_interval=self.ints_in_interval,
n_jobs=self.n_jobs,
negative_gradient_method=self.negative_gradient_method,
callbacks=self.callbacks, callbacks_every_iters=self.callbacks_every_iters,
callbacks=self.callbacks,
callbacks_every_iters=self.callbacks_every_iters,
)

def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:
# Compute affinities and initial positions and prepare the embedding object
affinities = self.compute_affinities(X)
initialization = self.compute_initialization(X)
embedding = self.prepare_embedding(affinities, initialization)

# Run standard t-SNE optimization
embedding.optimize(
n_iter=self.early_exaggeration_iter, exaggeration=self.early_exaggeration,
Expand All @@ -462,13 +499,7 @@ def fit(self, X: np.ndarray, Y: np.ndarray = None) -> openTSNE.TSNEEmbedding:

return embedding

def __call__(self, data: Table) -> TSNEModel:
# Preprocess the data - convert discrete to continuous
data = self.preprocess(data)

# Run tSNE optimization
embedding = self.fit(data.X, data.Y)

def convert_embedding_to_model(self, data, embedding):
# The results should be accessible in an Orange table, which doesn't
# need the full embedding attributes and is cast into a regular array
n = self.n_components
Expand All @@ -484,6 +515,17 @@ def __call__(self, data: Table) -> TSNEModel:

return model

def __call__(self, data: Table) -> TSNEModel:
# Preprocess the data - convert discrete to continuous
data = self.preprocess(data)

# Run tSNE optimization
embedding = self.fit(data.X, data.Y)

# Convert the t-SNE embedding object to a TSNEModel and prepare the
# embedding table with t-SNE meta variables
return self.convert_embedding_to_model(data, embedding)

@staticmethod
def default_initialization(data, n_components=2, random_state=None):
return openTSNE.initialization.pca(
Expand Down
24 changes: 19 additions & 5 deletions Orange/widgets/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,9 +890,11 @@ def test_setup_graph(self, timeout=DEFAULT_TIMEOUT):
"""Plot should exist after data has been sent in order to be
properly set/updated"""
self.send_signal(self.widget.Inputs.data, self.data)

if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.assertIsNotNone(self.widget.graph.scatterplot_item)

def test_default_attrs(self, timeout=DEFAULT_TIMEOUT):
Expand Down Expand Up @@ -966,16 +968,21 @@ def test_plot_once(self, timeout=DEFAULT_TIMEOUT):
table = Table("heart_disease")
self.widget.setup_plot = Mock()
self.widget.commit = Mock()

self.send_signal(self.widget.Inputs.data, table)
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.setup_plot.assert_called_once()
self.widget.commit.assert_called_once()

self.widget.commit.reset_mock()
self.send_signal(self.widget.Inputs.data_subset, table[::10])
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.commit.reset_mock()
self.send_signal(self.widget.Inputs.data_subset, table[::10])
self.widget.setup_plot.assert_called_once()
self.widget.commit.assert_called_once()

Expand Down Expand Up @@ -1032,16 +1039,20 @@ def test_invalidated_embedding(self, timeout=DEFAULT_TIMEOUT):
self.widget.graph.update_coordinates = Mock()
self.widget.graph.update_point_props = Mock()
self.send_signal(self.widget.Inputs.data, self.data)
self.widget.graph.update_coordinates.assert_called_once()
self.widget.graph.update_point_props.assert_called_once()

if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.graph.update_coordinates.assert_called()
self.widget.graph.update_point_props.assert_called()

self.widget.graph.update_coordinates.reset_mock()
self.widget.graph.update_point_props.reset_mock()
self.send_signal(self.widget.Inputs.data, self.data)
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.graph.update_coordinates.assert_not_called()
self.widget.graph.update_point_props.assert_called_once()

Expand All @@ -1050,13 +1061,16 @@ def test_saved_selection(self, timeout=DEFAULT_TIMEOUT):
if self.widget.isBlocking():
spy = QSignalSpy(self.widget.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.widget.graph.select_by_indices(list(range(0, len(self.data), 10)))
settings = self.widget.settingsHandler.pack_data(self.widget)
w = self.create_widget(self.widget.__class__, stored_settings=settings)

self.send_signal(self.widget.Inputs.data, self.data, widget=w)
if w.isBlocking():
spy = QSignalSpy(w.blockingStateChanged)
self.assertTrue(spy.wait(timeout))

self.assertEqual(np.sum(w.graph.selection), 15)
np.testing.assert_equal(self.widget.graph.selection, w.graph.selection)

Expand Down
Loading

0 comments on commit 077fc02

Please sign in to comment.