From eddc60ce58a074e2a247ec41417a728267531dc6 Mon Sep 17 00:00:00 2001 From: Ales Erjavec Date: Mon, 15 Oct 2018 16:55:01 +0200 Subject: [PATCH] owlouvainclustering: Fix race conditions * Refactor the code to eliminate race conditions in the access/setting of `pca_projection`, `graph` and `partitions`. * Limit the number of executing tasks to one. --- .../unsupervised/owlouvainclustering.py | 570 +++++++++++++----- .../unsupervised/tests/test_owlouvain.py | 11 +- Orange/widgets/widget.py | 4 +- 3 files changed, 419 insertions(+), 166 deletions(-) diff --git a/Orange/widgets/unsupervised/owlouvainclustering.py b/Orange/widgets/unsupervised/owlouvainclustering.py index 10738c4252e..c290e7ebcd8 100644 --- a/Orange/widgets/unsupervised/owlouvainclustering.py +++ b/Orange/widgets/unsupervised/owlouvainclustering.py @@ -1,12 +1,20 @@ -from collections import deque -from concurrent.futures import Future # pylint: disable=unused-import +from functools import partial + +from concurrent import futures +from concurrent.futures import Future + from types import SimpleNamespace as namespace -from typing import Optional # pylint: disable=unused-import +from typing import ( + Optional, Callable, Tuple, Any +) import numpy as np -import networkx as nx # pylint: disable=unused-import -from AnyQt.QtCore import Qt, pyqtSignal as Signal, QObject -from AnyQt.QtWidgets import QSlider, QCheckBox, QWidget # pylint: disable=unused-import +import networkx as nx + +from AnyQt.QtCore import ( + Qt, QObject, QTimer, pyqtSignal as Signal, pyqtSlot as Slot +) +from AnyQt.QtWidgets import QSlider, QCheckBox, QWidget from Orange.clustering.louvain import table_to_knn_graph, Louvain from Orange.data import Table, DiscreteVariable @@ -16,7 +24,7 @@ Setting from Orange.widgets.utils.annotated_data import get_next_name, add_columns, \ ANNOTATED_DATA_SIGNAL_NAME -from Orange.widgets.utils.concurrent import ThreadExecutor +from Orange.widgets.utils.concurrent import FutureWatcher from Orange.widgets.utils.signals import Input, Output from Orange.widgets.widget import Msg @@ -35,65 +43,6 @@ METRICS = [('Euclidean', 'l2'), ('Manhattan', 'l1')] -class TaskQueue(QObject): - """Not really a task queue `per-se`. Running start will run the tasks in - the current list and cannot handle adding other tasks while running.""" - on_exception = Signal(Exception) - on_complete = Signal() - on_progress = Signal(float) - on_cancel = Signal() - - def __init__(self, parent=None): - super().__init__(parent=parent) - self.__tasks = deque() - self.__progress = 0 - self.__cancelled = False - - def cancel(self): - self.__cancelled = True - - def push(self, task): - self.__tasks.append(task) - - def __set_progress(self, progress): - # Only emit progress signal when the progress has changed sufficiently - if int(progress * 100) > int(self.__progress * 100): - self.on_progress.emit(progress) - self.__progress = progress - - def start(self): - num_tasks = len(self.__tasks) - - for idx, task_spec in enumerate(self.__tasks): - - if self.__cancelled: - self.on_cancel.emit() - return - - def __task_progress(percentage, index=idx): - current_progress = index / num_tasks - # How much progress can each task contribute to the total - # work to be done - task_percentage = 1 / len(self.__tasks) - # Convert the progress done by the task into the total - # progress to the task - relative_progress = task_percentage * percentage - self.__set_progress(current_progress + relative_progress) - - try: - if getattr(task_spec, 'progress_callback', False): - task_spec.task(progress_callback=__task_progress) - else: - task_spec.task() - self.__set_progress((idx + 1) / num_tasks) - - except Exception as e: # pylint: disable=broad-except - self.on_exception.emit(e) - return - - self.on_complete.emit() - - class OWLouvainClustering(widget.OWWidget): name = 'Louvain Clustering' description = 'Detects communities in a network of nearest neighbors.' @@ -120,7 +69,10 @@ class Outputs: metric_idx = ContextSetting(0) k_neighbors = ContextSetting(_DEFAULT_K_NEIGHBORS) resolution = ContextSetting(1.) - auto_commit = Setting(True) + auto_commit = Setting(False) + + class Information(widget.OWWidget.Information): + modified = Msg("Press commit to recompute clusters and send new data") class Error(widget.OWWidget.Error): empty_dataset = Msg('No features in data') @@ -131,119 +83,128 @@ def __init__(self): self.data = None # type: Optional[Table] self.preprocessed_data = None # type: Optional[Table] + self.pca_projection = None # type: Optional[Table] self.graph = None # type: Optional[nx.Graph] self.partition = None # type: Optional[np.array] - - self.__executor = ThreadExecutor(parent=self) - self.__future = None # type: Optional[Future] - self.__queue = None # type: Optional[TaskQueue] + # Use a executor with a single worker, to limit CPU overcommitment for + # cancelled tasks. The method does not have a fine cancellation + # granularity so we assure that there are not N - 1 jobs executing + # for no reason only to be thrown away. It would be better to use the + # global pool but implement a limit on jobs from this source. + self.__executor = futures.ThreadPoolExecutor(max_workers=1) + self.__task = None # type: Optional[TaskState] + self.__invalidated = False + # coalescing commit timer + self.__commit_timer = QTimer(self, singleShot=True) + self.__commit_timer.timeout.connect(self.commit) pca_box = gui.vBox(self.controlArea, 'PCA Preprocessing') self.apply_pca_cbx = gui.checkBox( pca_box, self, 'apply_pca', label='Apply PCA preprocessing', - callback=self._update_graph, + callback=self._invalidate_graph, ) # type: QCheckBox self.pca_components_slider = gui.hSlider( pca_box, self, 'pca_components', label='Components: ', minValue=2, - maxValue=_MAX_PCA_COMPONENTS, callback=self._update_pca_components, - tracking=False + maxValue=_MAX_PCA_COMPONENTS, + callback=self._invalidate_pca_projection, tracking=False ) # type: QSlider graph_box = gui.vBox(self.controlArea, 'Graph parameters') self.metric_combo = gui.comboBox( graph_box, self, 'metric_idx', label='Distance metric', - items=[m[0] for m in METRICS], callback=self._update_graph, + items=[m[0] for m in METRICS], callback=self._invalidate_graph, orientation=Qt.Horizontal, ) # type: gui.OrangeComboBox self.k_neighbors_spin = gui.spin( graph_box, self, 'k_neighbors', minv=1, maxv=_MAX_K_NEIGBOURS, label='k neighbors', controlWidth=80, alignment=Qt.AlignRight, - callback=self._update_graph, + callback=self._invalidate_graph, ) # type: gui.SpinBoxWFocusOut self.resolution_spin = gui.hSlider( graph_box, self, 'resolution', minValue=0, maxValue=5., step=1e-1, label='Resolution', intOnly=False, labelFormat='%.1f', - callback=self._update_resolution, tracking=False, + callback=self._invalidate_partition, tracking=False, ) # type: QSlider self.resolution_spin.parent().setToolTip( 'The resolution parameter affects the number of clusters to find. ' 'Smaller values tend to produce more clusters and larger values ' 'retrieve less clusters.' ) - self.apply_button = gui.auto_commit( self.controlArea, self, 'auto_commit', 'Apply', box=None, - commit=self.commit, + commit=lambda: self.commit(), + callback=lambda: self._on_auto_commit_changed(), ) # type: QWidget - def _update_graph(self): + def _invalidate_pca_projection(self): + self.pca_projection = None self._invalidate_graph() - self.commit() - - def _update_pca_components(self): - self._invalidate_pca_projection() - self.commit() + self._set_modified(True) - def _update_resolution(self): + def _invalidate_graph(self): + self.graph = None self._invalidate_partition() - self.commit() - - def _compute_pca_projection(self): - if self.pca_projection is None and self.apply_pca: - self.setStatusMessage('Computing PCA...') - - pca = PCA(n_components=self.pca_components, random_state=0) - model = pca(self.preprocessed_data) - self.pca_projection = model(self.preprocessed_data) - - def _compute_graph(self, progress_callback=None): - if self.graph is None: - self.setStatusMessage('Building graph...') + self._set_modified(True) - data = self.pca_projection if self.apply_pca else self.preprocessed_data - - self.graph = table_to_knn_graph( - data, k_neighbors=self.k_neighbors, - metric=METRICS[self.metric_idx][1], - progress_callback=progress_callback, - ) - - def _compute_partition(self): - if self.partition is None: - self.setStatusMessage('Detecting communities...') - self.setBlocking(True) - - louvain = Louvain(resolution=self.resolution) - self.partition = louvain.fit_predict(self.graph) - - def _processing_complete(self): - self.setStatusMessage('') - self.setBlocking(False) - self.progressBarFinished() + def _invalidate_partition(self): + self.partition = None + self._invalidate_output() + self.Information.modified() + self._set_modified(True) + + def _invalidate_output(self): + self.__invalidated = True + if self.__task is not None: + self.__cancel_task(wait=False) + + if self.auto_commit: + self.__commit_timer.start() + else: + self.__set_state_ready() + + def _set_modified(self, state): + """ + Mark the widget (GUI) as containing modified state. + """ + if self.data is None: + # does not apply when we have no data + state = False + elif self.auto_commit: + # does not apply when auto commit is on + state = False + self.Information.modified(shown=state) - def _handle_exceptions(self, ex): - self.Error.general_error(str(ex)) + def _on_auto_commit_changed(self): + if self.auto_commit and self.__invalidated: + self.commit() def cancel(self): """Cancel any running jobs.""" - if self.__future is not None: - assert self.__queue is not None - self.__queue.cancel() - self.__queue = None - self.__future.cancel() - self.__future = None + self.__cancel_task(wait=False) + self.__set_state_ready() def commit(self): + self.__commit_timer.stop() + self.__invalidated = False + self._set_modified(False) self.Error.clear() - # Kill any running jobs - self.cancel() + + # Cancel current running task + self.__cancel_task(wait=False) if self.data is None: + self.__set_state_ready() return # Make sure the dataset is ok if len(self.data.domain.attributes) < 1: self.Error.empty_dataset() + self.__set_state_ready() + return + + if self.partition is not None: + self.__set_state_ready() + self._send_data() return # Preprocess the dataset @@ -251,31 +212,134 @@ def commit(self): louvain = Louvain() self.preprocessed_data = louvain.preprocess(self.data) - # Prepare the tasks to run - queue = TaskQueue(parent=self) - - if self.pca_projection is None and self.apply_pca: - queue.push(namespace(task=self._compute_pca_projection)) + state = TaskState() - if self.graph is None: - queue.push(namespace(task=self._compute_graph, progress_callback=True)) - - if self.partition is None: - queue.push(namespace(task=self._compute_partition)) - - # Prepare callbacks - queue.on_progress.connect(lambda val: self.progressBarSet(100 * val)) - queue.on_complete.connect(self._on_complete) - queue.on_exception.connect(self._handle_exceptions) - self.__queue = queue + # Prepare/assemble the task(s) to run; reuse partial results + if self.apply_pca: + if self.pca_projection is not None: + data = self.pca_projection + pca_components = None + else: + data = self.preprocessed_data + pca_components = self.pca_components + else: + data = self.preprocessed_data + pca_components = None + + if self.graph is not None: + # run on graph only; no need to do PCA and k-nn search ... + graph = self.graph + k_neighbors = metric = None + else: + k_neighbors, metric = self.k_neighbors, METRICS[self.metric_idx][1] + graph = None + + if graph is None: + task = partial( + run_on_data, data, pca_components=pca_components, + k_neighbors=k_neighbors, metric=metric, + resolution=self.resolution, state=state + ) + else: + task = partial( + run_on_graph, graph, resolution=self.resolution, state=state + ) + state.run = task + self.__set_state_busy() + self.__start_task(task, state) + + @Slot(object) + def __set_partial_results(self, result): + # type: (Tuple[str, Any]) -> None + which, res = result + if which == "pca_projection": + assert isinstance(res, Table) and len(res) == len(self.data) + self.pca_projection = res + elif which == "graph": + assert isinstance(res, nx.Graph) + self.graph = res + elif which == "partition": + assert isinstance(res, np.ndarray) + self.partition = res + else: + assert False, which + + @Slot(object) + def __on_done(self, future): + # type: (Future['Results']) -> None + assert future.done() + assert self.__task is not None + assert self.__task.future is future + assert self.__task.watcher.future() is future + self.__task = None + self.__set_state_ready() + try: + result = future.result() + except Exception as err: + self.Error.general_error(str(err), exc_info=True) + else: + self.__set_results(result) + + @Slot(str) + def setStatusMessage(self, text): + super().setStatusMessage(text) + + @Slot(float) + def progressBarSet(self, value, *a, **kw): + super().progressBarSet(value, *a, **kw) + + def __set_state_ready(self): + self.progressBarFinished() + self.setBlocking(False) + self.setStatusMessage("") - # Run the task queue + def __set_state_busy(self): self.progressBarInit() self.setBlocking(True) - self.__future = self.__executor.submit(queue.start) - def _on_complete(self): - self._processing_complete() + def __start_task(self, task, state): + # type: (Callable[[], Any], TaskState) -> None + assert self.__task is None + state.status_changed.connect(self.setStatusMessage) + state.progress_changed.connect(self.progressBarSet) + state.partial_result_ready.connect(self.__set_partial_results) + state.watcher.done.connect(self.__on_done) + state.run = task + state.start(self.__executor, task) + self.__task = state + + def __cancel_task(self, wait=True): + # Cancel and dispose of the current task + if self.__task is not None: + state, self.__task = self.__task, None + state.cancel() + state.partial_result_ready.disconnect(self.__set_partial_results) + state.status_changed.disconnect(self.setStatusMessage) + state.progress_changed.disconnect(self.progressBarSet) + state.watcher.done.disconnect(self.__on_done) + if state.parent() is self: + state.setParent(None) + + if wait and state.future is not None: + futures.wait([state.future]) + + def __set_results(self, results): + # type: ('Results') -> None + # NOTE: All of these have already been set by __set_partial_results, + # we double check that they are aliases + if results.pca_projection is not None: + assert self.pca_components == results.pca_components + assert self.pca_projection is results.pca_projection + self.pca_projection = results.pca_projection + if results.graph is not None: + assert results.metric == METRICS[self.metric_idx][1] + assert results.k_neighbors == self.k_neighbors + assert self.graph is results.graph + self.graph = results.graph + if results.partition is not None: + assert results.resolution == self.resolution + assert self.partition is results.partition + self.partition = results.partition self._send_data() def _send_data(self): @@ -303,17 +367,6 @@ def _send_data(self): graph.set_items(new_table) self.Outputs.graph.send(graph) - def _invalidate_pca_projection(self): - self.pca_projection = None - self._invalidate_graph() - - def _invalidate_graph(self): - self.graph = None - self._invalidate_partition() - - def _invalidate_partition(self): - self.partition = None - @Inputs.data def set_data(self, data): self.closeContext() @@ -334,7 +387,7 @@ def set_data(self, data): self.Outputs.graph.send(None) # Clear internal state - self.preprocessed_data = None + self.clear() self._invalidate_pca_projection() if self.data is None: return @@ -349,8 +402,18 @@ def set_data(self, data): self.commit() + def clear(self): + self.__cancel_task(wait=False) + self.preprocessed_data = None + self.pca_projection = None + self.graph = None + self.partition = None + self.Error.clear() + self.Information.modified.clear() + def onDeleteWidget(self): - self.cancel() + self.__cancel_task(wait=True) + self.__executor.shutdown(True) super().onDeleteWidget() def send_report(self): @@ -366,6 +429,189 @@ def send_report(self): )) +class TaskState(QObject): + + status_changed = Signal(str) + _p_status_changed = Signal(str) + + progress_changed = Signal(float) + _p_progress_changed = Signal(float) + + partial_result_ready = Signal(object) + _p_partial_result_ready = Signal(object) + + def __init__(self, *args): + super().__init__(*args) + self.__future = None + self.watcher = FutureWatcher() + self.__interuption_requested = False + self.__progress = 0 + # Helpers to route the signal emits via a this object's queue. + # This ensures 'atomic' disconnect from signals for targets/slots + # in the same thread. Requires that the event loop is running in this + # object's thread. + self._p_status_changed.connect( + self.status_changed, Qt.QueuedConnection) + self._p_progress_changed.connect( + self.progress_changed, Qt.QueuedConnection) + self._p_partial_result_ready.connect( + self.partial_result_ready, Qt.QueuedConnection) + + @property + def future(self): + # type: () -> Future + return self.__future + + def set_status(self, text): + self._p_status_changed.emit(text) + + def set_progress_value(self, value): + if round(value, 1) > round(self.__progress, 1): + # Only emit progress when it has changed sufficiently + self._p_progress_changed.emit(value) + self.__progress = value + + def set_partial_results(self, value): + self._p_partial_result_ready.emit(value) + + def is_interuption_requested(self): + return self.__interuption_requested + + def start(self, executor, func=None): + # type: (futures.Executor, Callable[[], Any]) -> Future + assert self.future is None + assert not self.__interuption_requested + self.__future = executor.submit(func) + self.watcher.setFuture(self.future) + return self.future + + def cancel(self): + assert not self.__interuption_requested + self.__interuption_requested = True + if self.future is not None: + rval = self.future.cancel() + else: + # not event scheduled + rval = True + return rval + + +class InteruptRequested(BaseException): + pass + + +class Results(namespace): + pca_projection = None # type: Optional[Table] + pca_components = None # type: Optional[int] + k_neighbors = None # type: Optional[int] + metric = None # type: Optional[str] + graph = None # type: Optional[nx.Graph] + resolution = None # type: Optional[float] + partition = None # type: Optional[np.ndarray] + + +def run_on_data(data, pca_components, k_neighbors, metric, resolution, state): + # type: (Table, Optional[int], int, str, float, TaskState) -> Results + """ + Run the louvain clustering on `data`. + + state is used to report progress and partial results. Returns early if + `task.is_interuption_requested()` returns true. + + Parameters + ---------- + data : Table + Data table + pca_components : Optional[int] + If not `None` then the data is first projected onto first + `pca_components` principal components. + k_neighbors : int + metric : str + resolution : float + state : TaskState + + Returns + ------- + res : Results + """ + state = state # type: TaskState + res = Results( + pca_components=pca_components, k_neighbors=k_neighbors, metric=metric, + resolution=resolution, + ) + step = 0 + if state.is_interuption_requested(): + return res + if pca_components is not None: + steps = 3 + state.set_status("Computing PCA...") + pca = PCA(n_components=pca_components, random_state=0) + data = res.pca_projection = pca(data)(data) + assert isinstance(data, Table) + state.set_partial_results(("pca_projection", res.pca_projection)) + step += 1 + else: + steps = 2 + + if state.is_interuption_requested(): + return res + + state.set_progress_value(100. * step / steps) + state.set_status("Building graph...") + + def pcallback(val): + state.set_progress_value((100. * step + 100 * val) / steps) + if state.is_interuption_requested(): + raise InteruptRequested() + + try: + res.graph = graph = table_to_knn_graph( + data, k_neighbors=k_neighbors, metric=metric, + progress_callback=pcallback + ) + except InteruptRequested: + return res + + state.set_partial_results(("graph", res.graph)) + + step += 1 + state.set_progress_value(100 * step / steps) + state.set_status("Detecting communities...") + if state.is_interuption_requested(): + return res + + louvain = Louvain(resolution=resolution) + res.partition = louvain.fit_predict(graph) + state.set_partial_results(("partition", res.partition)) + return res + + +def run_on_graph(graph, resolution, state): + # type: (nx.Graph, float, TaskState) -> Results + """ + Run the louvain clustering on `graph`. + Parameters + ---------- + graph + resolution + state + + Returns + ------- + + """ + state = state # type: TaskState + res = Results(resolution=resolution) + louvain = Louvain(resolution=resolution) + state.set_status("Detecting communities...") + if state.is_interuption_requested(): + return res + partition = louvain.fit_predict(graph) + res.partition = partition + state.set_partial_results(("partition", res.partition)) + return res + + if __name__ == '__main__': from AnyQt.QtWidgets import QApplication # pylint: disable=ungrouped-imports import sys diff --git a/Orange/widgets/unsupervised/tests/test_owlouvain.py b/Orange/widgets/unsupervised/tests/test_owlouvain.py index f73832a8903..1bfe289fefa 100644 --- a/Orange/widgets/unsupervised/tests/test_owlouvain.py +++ b/Orange/widgets/unsupervised/tests/test_owlouvain.py @@ -21,11 +21,16 @@ def tearDown(self): self.widget.onDeleteWidget() super().tearDown() + def commit_and_wait(self, widget=None): + widget = self.widget if widget is None else widget + widget.commit() + self.wait_until_stop_blocking(widget) + def test_removing_data(self): self.send_signal(self.widget.Inputs.data, self.iris) - self.commit_and_wait() + self.commit_and_wait(self.widget) self.send_signal(self.widget.Inputs.data, None) - self.commit_and_wait() + self.commit_and_wait(self.widget) def test_clusters_ordered_by_size(self): """Cluster names should be sorted based on the number of instances.""" @@ -80,7 +85,7 @@ def test_do_not_recluster_on_same_data(self): table3 = table1.copy() table3.X[:, 0] = 1 - with patch.object(self.widget, 'commit') as commit: + with patch.object(self.widget, '_invalidate_output') as commit: self.send_signal(self.widget.Inputs.data, table1) self.commit_and_wait() call_count = commit.call_count diff --git a/Orange/widgets/widget.py b/Orange/widgets/widget.py index 43a32656706..9ef9a4c4231 100644 --- a/Orange/widgets/widget.py +++ b/Orange/widgets/widget.py @@ -13,7 +13,7 @@ ) from AnyQt.QtCore import ( Qt, QRect, QMargins, QByteArray, QDataStream, QBuffer, QSettings, - QUrl, pyqtSignal as Signal + QUrl, QThread, pyqtSignal as Signal ) from AnyQt.QtGui import QIcon, QKeySequence, QDesktopServices @@ -769,6 +769,7 @@ def setStatusMessage(self, text): This is a short status string to be displayed inline next to the instantiated widget icon in the canvas. """ + assert QThread.currentThread() == self.thread() if self.__statusMessage != text: self.__statusMessage = text self.statusMessageChanged.emit(text) @@ -815,6 +816,7 @@ def setBlocking(self, state=True): .. note:: Failure to clear this flag will block dependent nodes forever. """ + assert QThread.currentThread() is self.thread() if self.__blocking != state: self.__blocking = state self.blockingStateChanged.emit(state)