Skip to content

Commit

Permalink
OWtSNE: Use concurrent widget mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
pavlin-policar committed Mar 1, 2019
1 parent d6bdc4b commit dc1f337
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 90 deletions.
92 changes: 11 additions & 81 deletions Orange/widgets/unsupervised/owtsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from AnyQt.QtCore import Qt, pyqtSlot as Slot, pyqtSignal as Signal, QObject
from AnyQt.QtCore import Qt, pyqtSlot as Slot
from AnyQt.QtWidgets import QFormLayout

from Orange.data import Table, Domain
Expand All @@ -15,7 +15,7 @@
from Orange.projection import manifold
from Orange.widgets import gui
from Orange.widgets.settings import Setting, SettingProvider
from Orange.widgets.utils.concurrent import FutureWatcher
from Orange.widgets.utils.concurrent import FutureWatcher, TaskState, InterruptRequested
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase
from Orange.widgets.visualize.utils.widget import OWDataProjectionWidget
Expand All @@ -26,76 +26,6 @@
_DEFAULT_PCA_COMPONENTS = 20


class TaskState(QObject):
status_changed = Signal(str)
_p_status_changed = Signal(str)

progress_changed = Signal(float)
_p_progress_changed = Signal(float)

partial_result_ready = Signal(object)
_p_partial_result_ready = Signal(object)

def __init__(self, *args):
super().__init__(*args)
self.__future = None
self.watcher = FutureWatcher()
self.__interuption_requested = False
self.__progress = 0
# Helpers to route the signal emits via a this object's queue.
# This ensures 'atomic' disconnect from signals for targets/slots
# in the same thread. Requires that the event loop is running in this
# object's thread.
self._p_status_changed.connect(
self.status_changed, Qt.QueuedConnection)
self._p_progress_changed.connect(
self.progress_changed, Qt.QueuedConnection)
self._p_partial_result_ready.connect(
self.partial_result_ready, Qt.QueuedConnection)

@property
def future(self):
# type: () -> Future
return self.__future

def set_status(self, text):
self._p_status_changed.emit(text)

def set_progress_value(self, value):
if round(value, 1) > round(self.__progress, 1):
# Only emit progress when it has changed sufficiently
self._p_progress_changed.emit(value)
self.__progress = value

def set_partial_results(self, value):
self._p_partial_result_ready.emit(value)

def is_interuption_requested(self):
return self.__interuption_requested

def start(self, executor, func=None):
# type: (futures.Executor, Callable[[], Any]) -> Future
assert self.future is None
assert not self.__interuption_requested
self.__future = executor.submit(func)
self.watcher.setFuture(self.future)
return self.future

def cancel(self):
assert not self.__interuption_requested
self.__interuption_requested = True
if self.future is not None:
rval = self.future.cancel()
else:
# not even scheduled yet
rval = True
return rval


class InteruptRequested(BaseException):
pass


class Task(namespace):
"""Completely determines the t-SNE task spec and intermediate results."""
data = None # type: Optional[Table]
Expand Down Expand Up @@ -203,8 +133,8 @@ def compute_tsne(task, state, progress_callback=None):
)
state.set_partial_results(("tsne_embedding", task))

if state.is_interuption_requested():
raise InteruptRequested()
if state.is_interruption_requested():
raise InterruptRequested()

total_iterations_needed = tsne.early_exaggeration_iter + tsne.n_iter
# If optimization has already been partially run, then the number of
Expand All @@ -231,8 +161,8 @@ def compute_tsne(task, state, progress_callback=None):
progress_callback((task.iterations_done - initial_iterations_done)
/ actual_iterations_needed)

if state.is_interuption_requested():
raise InteruptRequested()
if state.is_interruption_requested():
raise InterruptRequested()

# Run regular optimization phase
while task.iterations_done < total_iterations_needed:
Expand All @@ -249,8 +179,8 @@ def compute_tsne(task, state, progress_callback=None):
progress_callback((task.iterations_done - initial_iterations_done)
/ actual_iterations_needed)

if state.is_interuption_requested():
raise InteruptRequested()
if state.is_interruption_requested():
raise InterruptRequested()

@classmethod
def run(cls, task, state):
Expand Down Expand Up @@ -298,8 +228,8 @@ def _progress_callback(val):
100 * progress_done + 100 * val * job_weight
)

if state.is_interuption_requested():
raise InteruptRequested()
if state.is_interruption_requested():
raise InterruptRequested()

# Execute the job
job(progress_callback=_progress_callback)
Expand All @@ -308,7 +238,7 @@ def _progress_callback(val):
progress_done += job_weight
state.set_progress_value(100 * progress_done)

except InteruptRequested:
except InterruptRequested:
pass

return task
Expand Down
20 changes: 19 additions & 1 deletion Orange/widgets/unsupervised/tests/test_owtsne.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,19 @@ def setUp(self):
self.domain = Domain(self.attributes, class_vars=self.class_var)
self.empty_domain = Domain([], class_vars=self.class_var)

def tearDown(self):
# Some tests may not wait for the widget to finish, and the patched
# methods might be unpatched before the widget finishes, resulting in
# a very confusing crash.
self.widget.cancel()
try:
self.tsne.stop()
self.tsne_model.stop()
# If `restore_mocked_functions` was called, stopping the patchers will raise
except RuntimeError as e:
if str(e) != "stop called on unstarted patcher":
raise e

def restore_mocked_functions(self):
self.tsne.stop()
self.tsne_model.stop()
Expand Down Expand Up @@ -187,15 +200,20 @@ def _check_exaggeration(call, exaggeration):
self.assertIn("exaggeration", kwargs)
self.assertEqual(kwargs["exaggeration"], exaggeration)

# Since optimize needs to return a valid TSNEModel instance and it is
# impossible to return `self` in a mock, we'll prepare this one ahead
# of time and use this
optimize.return_value = DummyTSNE()(self.data)

# Set value to 1
self.widget.controls.exaggeration.setValue(1)
self.send_signal(self.widget.Inputs.data, self.data)
self.wait_until_stop_blocking()
_check_exaggeration(optimize, 1)

# Reset and clear state
optimize.reset_mock()
self.send_signal(self.widget.Inputs.data, None)
optimize.reset_mock()

# Change to 3
self.widget.controls.exaggeration.setValue(3)
Expand Down
8 changes: 0 additions & 8 deletions Orange/widgets/visualize/utils/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,14 +480,6 @@ def _handle_subset_data(self):
elif self.subset_indices - ids:
self.Warning.subset_not_subset()

if self._invalidated:
self._invalidated = False
self.setup_plot()
else:
self.graph.update_point_props()

self.commit()

def get_subset_mask(self):
if not self.subset_indices:
return None
Expand Down

0 comments on commit dc1f337

Please sign in to comment.