diff --git a/Orange/widgets/unsupervised/owtsne.py b/Orange/widgets/unsupervised/owtsne.py index 7fb8a0cc6f0..6ba36436518 100644 --- a/Orange/widgets/unsupervised/owtsne.py +++ b/Orange/widgets/unsupervised/owtsne.py @@ -6,7 +6,7 @@ import numpy as np -from AnyQt.QtCore import Qt, pyqtSlot as Slot, pyqtSignal as Signal, QObject +from AnyQt.QtCore import Qt, pyqtSlot as Slot from AnyQt.QtWidgets import QFormLayout from Orange.data import Table, Domain @@ -15,7 +15,7 @@ from Orange.projection import manifold from Orange.widgets import gui from Orange.widgets.settings import Setting, SettingProvider -from Orange.widgets.utils.concurrent import FutureWatcher +from Orange.widgets.utils.concurrent import FutureWatcher, TaskState, InterruptRequested from Orange.widgets.utils.widgetpreview import WidgetPreview from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotBase from Orange.widgets.visualize.utils.widget import OWDataProjectionWidget @@ -26,76 +26,6 @@ _DEFAULT_PCA_COMPONENTS = 20 -class TaskState(QObject): - status_changed = Signal(str) - _p_status_changed = Signal(str) - - progress_changed = Signal(float) - _p_progress_changed = Signal(float) - - partial_result_ready = Signal(object) - _p_partial_result_ready = Signal(object) - - def __init__(self, *args): - super().__init__(*args) - self.__future = None - self.watcher = FutureWatcher() - self.__interuption_requested = False - self.__progress = 0 - # Helpers to route the signal emits via a this object's queue. - # This ensures 'atomic' disconnect from signals for targets/slots - # in the same thread. Requires that the event loop is running in this - # object's thread. - self._p_status_changed.connect( - self.status_changed, Qt.QueuedConnection) - self._p_progress_changed.connect( - self.progress_changed, Qt.QueuedConnection) - self._p_partial_result_ready.connect( - self.partial_result_ready, Qt.QueuedConnection) - - @property - def future(self): - # type: () -> Future - return self.__future - - def set_status(self, text): - self._p_status_changed.emit(text) - - def set_progress_value(self, value): - if round(value, 1) > round(self.__progress, 1): - # Only emit progress when it has changed sufficiently - self._p_progress_changed.emit(value) - self.__progress = value - - def set_partial_results(self, value): - self._p_partial_result_ready.emit(value) - - def is_interuption_requested(self): - return self.__interuption_requested - - def start(self, executor, func=None): - # type: (futures.Executor, Callable[[], Any]) -> Future - assert self.future is None - assert not self.__interuption_requested - self.__future = executor.submit(func) - self.watcher.setFuture(self.future) - return self.future - - def cancel(self): - assert not self.__interuption_requested - self.__interuption_requested = True - if self.future is not None: - rval = self.future.cancel() - else: - # not even scheduled yet - rval = True - return rval - - -class InteruptRequested(BaseException): - pass - - class Task(namespace): """Completely determines the t-SNE task spec and intermediate results.""" data = None # type: Optional[Table] @@ -203,8 +133,8 @@ def compute_tsne(task, state, progress_callback=None): ) state.set_partial_results(("tsne_embedding", task)) - if state.is_interuption_requested(): - raise InteruptRequested() + if state.is_interruption_requested(): + raise InterruptRequested() total_iterations_needed = tsne.early_exaggeration_iter + tsne.n_iter # If optimization has already been partially run, then the number of @@ -231,8 +161,8 @@ def compute_tsne(task, state, progress_callback=None): progress_callback((task.iterations_done - initial_iterations_done) / actual_iterations_needed) - if state.is_interuption_requested(): - raise InteruptRequested() + if state.is_interruption_requested(): + raise InterruptRequested() # Run regular optimization phase while task.iterations_done < total_iterations_needed: @@ -249,8 +179,8 @@ def compute_tsne(task, state, progress_callback=None): progress_callback((task.iterations_done - initial_iterations_done) / actual_iterations_needed) - if state.is_interuption_requested(): - raise InteruptRequested() + if state.is_interruption_requested(): + raise InterruptRequested() @classmethod def run(cls, task, state): @@ -298,8 +228,8 @@ def _progress_callback(val): 100 * progress_done + 100 * val * job_weight ) - if state.is_interuption_requested(): - raise InteruptRequested() + if state.is_interruption_requested(): + raise InterruptRequested() # Execute the job job(progress_callback=_progress_callback) @@ -308,7 +238,7 @@ def _progress_callback(val): progress_done += job_weight state.set_progress_value(100 * progress_done) - except InteruptRequested: + except InterruptRequested: pass return task diff --git a/Orange/widgets/unsupervised/tests/test_owtsne.py b/Orange/widgets/unsupervised/tests/test_owtsne.py index 8f8d3b8fac0..4cc5e656918 100644 --- a/Orange/widgets/unsupervised/tests/test_owtsne.py +++ b/Orange/widgets/unsupervised/tests/test_owtsne.py @@ -50,6 +50,19 @@ def setUp(self): self.domain = Domain(self.attributes, class_vars=self.class_var) self.empty_domain = Domain([], class_vars=self.class_var) + def tearDown(self): + # Some tests may not wait for the widget to finish, and the patched + # methods might be unpatched before the widget finishes, resulting in + # a very confusing crash. + self.widget.cancel() + try: + self.tsne.stop() + self.tsne_model.stop() + # If `restore_mocked_functions` was called, stopping the patchers will raise + except RuntimeError as e: + if str(e) != "stop called on unstarted patcher": + raise e + def restore_mocked_functions(self): self.tsne.stop() self.tsne_model.stop() @@ -187,6 +200,11 @@ def _check_exaggeration(call, exaggeration): self.assertIn("exaggeration", kwargs) self.assertEqual(kwargs["exaggeration"], exaggeration) + # Since optimize needs to return a valid TSNEModel instance and it is + # impossible to return `self` in a mock, we'll prepare this one ahead + # of time and use this + optimize.return_value = DummyTSNE()(self.data) + # Set value to 1 self.widget.controls.exaggeration.setValue(1) self.send_signal(self.widget.Inputs.data, self.data) @@ -194,8 +212,8 @@ def _check_exaggeration(call, exaggeration): _check_exaggeration(optimize, 1) # Reset and clear state - optimize.reset_mock() self.send_signal(self.widget.Inputs.data, None) + optimize.reset_mock() # Change to 3 self.widget.controls.exaggeration.setValue(3) diff --git a/Orange/widgets/visualize/utils/widget.py b/Orange/widgets/visualize/utils/widget.py index e5060bb0e33..56dac68dd25 100644 --- a/Orange/widgets/visualize/utils/widget.py +++ b/Orange/widgets/visualize/utils/widget.py @@ -480,14 +480,6 @@ def _handle_subset_data(self): elif self.subset_indices - ids: self.Warning.subset_not_subset() - if self._invalidated: - self._invalidated = False - self.setup_plot() - else: - self.graph.update_point_props() - - self.commit() - def get_subset_mask(self): if not self.subset_indices: return None