From 97d355a3a8e7052f0faf2a7f018ee1f2235444ec Mon Sep 17 00:00:00 2001 From: Ales Erjavec Date: Mon, 1 Oct 2018 12:42:00 +0200 Subject: [PATCH] owhierarchicalclustering: Use selection indices for selection restore Change selection serialization and add dendrogram structure validation on restore. I.e. only restore if the linkage matrix matches (small leeway for inexact FP math is allowed). --- .../unsupervised/owhierarchicalclustering.py | 181 ++++++++++++++++-- .../tests/test_owhierarchicalclustering.py | 12 ++ 2 files changed, 175 insertions(+), 18 deletions(-) diff --git a/Orange/widgets/unsupervised/owhierarchicalclustering.py b/Orange/widgets/unsupervised/owhierarchicalclustering.py index 5b51a325589..923ab761d68 100644 --- a/Orange/widgets/unsupervised/owhierarchicalclustering.py +++ b/Orange/widgets/unsupervised/owhierarchicalclustering.py @@ -4,7 +4,9 @@ from collections import namedtuple, OrderedDict from itertools import chain from contextlib import contextmanager -from typing import List, Tuple, Dict, Optional # pylint: disable=unused-import + +import typing +from typing import Any, List, Tuple, Dict, Optional, Set import numpy as np @@ -752,6 +754,43 @@ def mousePressEvent(self, event): self.set_selected_clusters([]) +class SaveStateSettingsHandler(settings.SettingsHandler): + """ + A settings handler that delegates session data store/restore to the + OWWidget instance. + + The OWWidget subclass must implement `save_state() -> Dict[str, Any]` and + `set_restore_state(state: Dict[str, Any])` methods. + """ + def initialize(self, instance, data=None): + super().initialize(instance, data) + if data is not None and "__session_state_data" in data: + session_data = data["__session_state_data"] + instance.set_restore_state(session_data) + + def pack_data(self, widget): + # type: (widget.OWWidget) -> dict + res = super().pack_data(widget) + state = widget.save_state() + if state: + assert "__session_state_data" not in res + res["__session_state_data"] = state + return res + + +class _DomainContextHandler(settings.DomainContextHandler, + SaveStateSettingsHandler): + pass + + +if typing.TYPE_CHECKING: + #: Encoded selection state for persistent storage. + #: This is a list of tuples of leaf indices in the selection and + #: a (N, 3) linkage matrix for validation (the 4-th column from scipy + #: is omitted). + SelectionState = Tuple[List[Tuple[int]], List[Tuple[int, int, float]]] + + class OWHierarchicalClustering(widget.OWWidget): name = "Hierarchical Clustering" description = "Display a dendrogram of a hierarchical clustering " \ @@ -767,7 +806,7 @@ class Outputs: selected_data = Output("Selected Data", Orange.data.Table, default=True) annotated_data = Output(ANNOTATED_DATA_SIGNAL_NAME, Orange.data.Table) - settingsHandler = settings.DomainContextHandler() + settingsHandler = _DomainContextHandler() #: Selected linkage linkage = settings.Setting(1) @@ -788,9 +827,6 @@ class Outputs: cut_ratio = settings.Setting(75.0) #: Number of top clusters to select top_n = settings.Setting(3) - #: Selected trees - selected_trees = settings.ContextSetting([]) - #: Dendrogram zoom factor zoom_factor = settings.Setting(0) @@ -810,6 +846,9 @@ class Outputs: class Error(widget.OWWidget.Error): not_finite_distances = Msg("Some distances are infinite") + #: Stored (manual) selection state (from a saved workflow) to restore. + __pending_selection_restore = None # type: Optional[SelectionState] + def __init__(self): super().__init__() @@ -819,7 +858,6 @@ def __init__(self): self.root = None self._displayed_root = None self.cutoff_height = 0.0 - self._open_context = False gui.comboBox( self.controlArea, self, "linkage", items=LINKAGE, box="Linkage", @@ -1019,6 +1057,12 @@ def axis_view(orientation): @Inputs.distances def set_distances(self, matrix): + if self.__pending_selection_restore is not None: + selection_state = self.__pending_selection_restore + else: + # save the current selection to (possibly) restore later + selection_state = self._save_selection() + self.error() self.Error.clear() if matrix is not None: @@ -1038,6 +1082,11 @@ def set_distances(self, matrix): self._set_items(None) self._invalidate_clustering() + # Can now attempt to restore session state from a saved workflow. + if self.root and selection_state is not None: + self._restore_selection(selection_state) + self.__pending_selection_restore = None + self.unconditional_commit() def _set_items(self, items, axis=1): @@ -1063,7 +1112,6 @@ def _set_items(self, items, axis=1): else: self.annotation = "Enumeration" self.openContext(items.domain) - self._open_context = True self.label_cb.setCurrentIndex(model.indexOf(self.annotation)) else: name_option = bool( @@ -1146,20 +1194,70 @@ def _update_labels(self): self.labels.set_labels(labels) self.labels.setMinimumWidth(1 if labels else -1) - def _restore_selection(self): - if self.selection_method == 0 and self.matrix is not None \ - and isinstance(self.matrix.row_items, Orange.data.Table): - if not self._open_context: - self.openContext(self.matrix.row_items.domain) - select_items = [self.dendrogram.item(t) - for t in self.dendrogram._items - if hash(t) in self.selected_trees] - self.dendrogram.set_selected_items(select_items) + def _restore_selection(self, state): + # type: (SelectionState) -> bool + """ + Restore the (manual) node selection state. + + Return True if successful; False otherwise. + """ + linkmatrix = self.linkmatrix + if self.selection_method == 0 and self.root: + selected, linksaved = state + linkstruct = np.array(linksaved, dtype=float) + selected = set(selected) # type: Set[Tuple[int]] + if not selected: + return False + if linkmatrix.shape[0] != linkstruct.shape[0]: + return False + # check that the linkage matrix structure matches. Use isclose for + # the height column to account for inexact floating point math + # (e.g. summation order in different ?gemm implementations for + # euclidean distances, ...) + if np.any(linkstruct[:, :2] != linkmatrix[:, :2]) or \ + not np.all(np.isclose(linkstruct[:, 2], linkstruct[:, 2])): + return False + selection = [] + indices = np.array([n.value.index for n in leaves(self.root)], + dtype=int) + # mapping from ranges to display (pruned) nodes + mapping = {node.value.range: node + for node in postorder(self._displayed_root)} + for node in postorder(self.root): # type: Tree + r = tuple(indices[node.value.first: node.value.last]) + if r in selected: + if node.value.range not in mapping: + # the node was pruned from display and cannot be + # selected + break + selection.append(mapping[node.value.range]) + selected.remove(r) + if not selected: + break # found all, nothing more to do + if selection and selected: + # Could not restore all selected nodes (only partial match) + return False + + self._set_selected_nodes(selection) + return True + return False + + def _set_selected_nodes(self, selection): + # type: (List[Tree]) -> None + """ + Set the nodes in `selection` to be the current selected nodes. + + The selection nodes must be subtrees of the current `_displayed_root`. + """ + self.dendrogram.selectionChanged.disconnect(self._invalidate_output) + try: + self.dendrogram.set_selected_clusters(selection) + finally: + self.dendrogram.selectionChanged.connect(self._invalidate_output) def _invalidate_clustering(self): self._update() self._update_labels() - self._restore_selection() self._invalidate_output() def _invalidate_output(self): @@ -1189,7 +1287,6 @@ def commit(self): return selection = self.dendrogram.selected_nodes() - self.selected_trees = list(map(hash, selection)) selection = sorted(selection, key=lambda c: c.value.first) indices = [leaf.value.index for leaf in leaves(self.root)] @@ -1413,6 +1510,54 @@ def _selection_edited(self): self._selection_method_changed() self._invalidate_output() + def _save_selection(self): + # Save the current manual node selection state + selection_state = None + if self.selection_method == 0 and self.root: + assert self.linkmatrix is not None + linkmat = [(int(_0), int(_1), _2) + for _0, _1, _2 in self.linkmatrix[:, :3].tolist()] + nodes_ = self.dendrogram.selected_nodes() + # match the display (pruned) nodes back (by ranges) + mapping = {node.value.range: node for node in postorder(self.root)} + nodes = [mapping[node.value.range] for node in nodes_] + indices = [tuple(node.value.index for node in leaves(node)) + for node in nodes] + if nodes: + selection_state = (indices, linkmat) + return selection_state + + def save_state(self): + # type: () -> Dict[str, Any] + """ + Save state for `set_restore_state` + """ + selection = self._save_selection() + res = {"version": (0, 0, 0)} + if selection is not None: + res["selection_state"] = selection + return res + + def set_restore_state(self, state): + # type: (Dict[str, Any]) -> bool + """ + Restore session data from a saved state. + + Parameters + ---------- + state : Dict[str, Any] + + NOTE + ---- + This is method called while the instance (self) is being constructed, + even before its `__init__` is called. Consider `self` to be only a + `QObject` at this stage. + """ + if "selection_state" in state: + selection = state["selection_state"] + self.__pending_selection_restore = selection + return True + def __zoom_in(self): def clip(minval, maxval, val): return min(max(val, minval), maxval) diff --git a/Orange/widgets/unsupervised/tests/test_owhierarchicalclustering.py b/Orange/widgets/unsupervised/tests/test_owhierarchicalclustering.py index 1f97b9ee7d5..839222cc37c 100644 --- a/Orange/widgets/unsupervised/tests/test_owhierarchicalclustering.py +++ b/Orange/widgets/unsupervised/tests/test_owhierarchicalclustering.py @@ -131,3 +131,15 @@ def test_retain_selection(self): self.assertIsNotNone(self.get_output(self.widget.Outputs.selected_data)) self.send_signal(self.widget.Inputs.distances, self.distances) self.assertIsNotNone(self.get_output(self.widget.Outputs.selected_data)) + + def test_restore_state(self): + self.send_signal(self.widget.Inputs.distances, self.distances) + self._select_data() + ids_1 = self.get_output(self.widget.Outputs.selected_data).ids + state = self.widget.settingsHandler.pack_data(self.widget) + w = self.create_widget( + OWHierarchicalClustering, stored_settings=state + ) + self.send_signal(self.widget.Inputs.distances, self.distances) + ids_2 = self.get_output(self.widget.Outputs.selected_data).ids + self.assertSequenceEqual(list(ids_1), list(ids_2))