diff --git a/Orange/widgets/visualize/owpythagorastree.py b/Orange/widgets/visualize/owpythagorastree.py index 5440b7c7966..5eb1c63d10f 100644 --- a/Orange/widgets/visualize/owpythagorastree.py +++ b/Orange/widgets/visualize/owpythagorastree.py @@ -57,6 +57,8 @@ class Outputs: graph_name = 'scene' # Settings + settingsHandler = settings.DomainContextHandler() + depth_limit = settings.ContextSetting(10) target_class_index = settings.ContextSetting(0) size_calc_idx = settings.Setting(0) @@ -73,8 +75,7 @@ def __init__(self): super().__init__() # Instance variables self.model = None - self.instances = None - self.clf_dataset = None + self.data = None # The tree adapter instance which is passed from the outside self.tree_adapter = None self.legend = None @@ -147,18 +148,12 @@ def __init__(self): @Inputs.tree def set_tree(self, model=None): """When a different tree is given.""" + self.closeContext() self.clear() self.model = model if model is not None: - self.instances = model.instances - # this bit is important for the regression classifier - if self.instances is not None and \ - self.instances.domain != model.domain: - self.clf_dataset = self.instances.transform(self.model.domain) - else: - self.clf_dataset = self.instances - + self.data = model.instances self.tree_adapter = self._get_tree_adapter(self.model) self.ptree.clear() @@ -177,30 +172,30 @@ def set_tree(self, model=None): self._update_main_area() - # The target class can also be passed from the meta properties - # This must be set after `_update_target_class_combo` - if hasattr(model, 'meta_target_class_index'): - self.target_class_index = model.meta_target_class_index - self.update_colors() + self.openContext(self.model) - # Get meta variables describing what the settings should look like - # if the tree is passed from the Pythagorean forest widget. - if hasattr(model, 'meta_size_calc_idx'): - self.size_calc_idx = model.meta_size_calc_idx - self.update_size_calc() + self.update_depth() - # TODO There is still something wrong with this - # if hasattr(model, 'meta_depth_limit'): - # self.depth_limit = model.meta_depth_limit - # self.update_depth() + # The forest widget sets the following attributes on the tree, + # describing the settings on the forest widget. To keep the tree + # looking the same as on the forest widget, we prefer these settings to + # context settings, if set. + if hasattr(model, "meta_target_class_index"): + self.target_class_index = model.meta_target_class_index + self.update_colors() + if hasattr(model, "meta_size_calc_idx"): + self.size_calc_idx = model.meta_size_calc_idx + self.update_size_calc() + if hasattr(model, "meta_depth_limit"): + self.depth_limit = model.meta_depth_limit + self.update_depth() - self.Outputs.annotated_data.send(create_annotated_table(self.instances, None)) + self.Outputs.annotated_data.send(create_annotated_table(self.data, None)) def clear(self): """Clear all relevant data from the widget.""" self.model = None - self.instances = None - self.clf_dataset = None + self.data = None self.tree_adapter = None if self.legend is not None: @@ -228,6 +223,8 @@ def update_size_calc(self): self.invalidate_tree() def redraw(self): + if self.data is None: + return self.tree_adapter.shuffle_children() self.invalidate_tree() @@ -307,16 +304,21 @@ def onDeleteWidget(self): def commit(self): """Commit the selected data to output.""" - if self.instances is None: + if self.data is None: self.Outputs.selected_data.send(None) self.Outputs.annotated_data.send(None) return - nodes = [i.tree_node.label for i in self.scene.selectedItems() - if isinstance(i, SquareGraphicsItem)] + + nodes = [ + i.tree_node.label for i in self.scene.selectedItems() + if isinstance(i, SquareGraphicsItem) + ] data = self.tree_adapter.get_instances_in_nodes(nodes) self.Outputs.selected_data.send(data) selected_indices = self.tree_adapter.get_indices(nodes) - self.Outputs.annotated_data.send(create_annotated_table(self.instances, selected_indices)) + self.Outputs.annotated_data.send( + create_annotated_table(self.data, selected_indices) + ) def send_report(self): """Send report.""" @@ -327,9 +329,9 @@ def _update_target_class_combo(self): label = [x for x in self.target_class_combo.parent().children() if isinstance(x, QLabel)][0] - if self.instances.domain.has_discrete_class: + if self.data.domain.has_discrete_class: label_text = 'Target class' - values = [c.title() for c in self.instances.domain.class_vars[0].values] + values = [c.title() for c in self.data.domain.class_vars[0].values] values.insert(0, 'None') else: label_text = 'Node color' @@ -342,7 +344,7 @@ def _update_legend_colors(self): if self.legend is not None: self.scene.removeItem(self.legend) - if self.instances.domain.has_discrete_class: + if self.data.domain.has_discrete_class: self._classification_update_legend_colors() else: self._regression_update_legend_colors() @@ -375,14 +377,14 @@ def _get_colors_domain(domain): # The colors are the class mean if self.target_class_index == 1: - values = (np.min(self.clf_dataset.Y), np.max(self.clf_dataset.Y)) + values = (np.min(self.data.Y), np.max(self.data.Y)) colors = _get_colors_domain(self.model.domain) while len(values) != len(colors): values.insert(1, -1) items = list(zip(values, colors)) # Colors are the stddev elif self.target_class_index == 2: - values = (0, np.std(self.clf_dataset.Y)) + values = (0, np.std(self.data.Y)) colors = _get_colors_domain(self.model.domain) while len(values) != len(colors): values.insert(1, -1) diff --git a/Orange/widgets/visualize/owpythagoreanforest.py b/Orange/widgets/visualize/owpythagoreanforest.py index 3cfe7d041d0..25183aeca0c 100644 --- a/Orange/widgets/visualize/owpythagoreanforest.py +++ b/Orange/widgets/visualize/owpythagoreanforest.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Optional from AnyQt.QtCore import Qt, QRectF, QSize, QPointF, QSizeF, QModelIndex, \ - QItemSelection, QT_VERSION + QItemSelection, QItemSelectionModel, QT_VERSION from AnyQt.QtGui import QPainter, QPen, QColor, QBrush, QMouseEvent from AnyQt.QtWidgets import QSizePolicy, QGraphicsScene, QLabel, QSlider, \ QListView, QStyledItemDelegate, QStyleOptionViewItem, QStyle @@ -174,11 +174,15 @@ class Outputs: graph_name = 'scene' # Settings + settingsHandler = settings.DomainContextHandler() + depth_limit = settings.ContextSetting(10) target_class_index = settings.ContextSetting(0) size_calc_idx = settings.Setting(0) zoom = settings.Setting(200) + selected_index = settings.ContextSetting(None) + SIZE_CALCULATION = [ ('Normal', lambda x: x), ('Square root', lambda x: sqrt(x)), @@ -199,7 +203,6 @@ def __init__(self): self.rf_model = None self.forest = None self.instances = None - self.clf_dataset = None self.color_palette = None @@ -265,29 +268,32 @@ def __init__(self): @Inputs.random_forest def set_rf(self, model=None): """When a different forest is given.""" + self.closeContext() self.clear() self.rf_model = model if model is not None: self.forest = self._get_forest_adapter(self.rf_model) self.forest_model[:] = self.forest.trees - self.instances = model.instances - # This bit is important for the regression classifier - if self.instances is not None and self.instances.domain != model.domain: - self.clf_dataset = self.instances.transform(self.rf_model.domain) - else: - self.clf_dataset = self.instances self._update_info_box() self._update_target_class_combo() self._update_depth_slider() + self.openContext(model) + # Restore item selection + if self.selected_index is not None: + index = self.list_view.model().index(self.selected_index) + selection = QItemSelection(index, index) + self.list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect) + def clear(self): """Clear all relevant data from the widget.""" self.rf_model = None self.forest = None self.forest_model.clear() + self.selected_index = None self._clear_info_box() self._clear_target_class_combo() @@ -342,19 +348,19 @@ def onDeleteWidget(self): super().onDeleteWidget() self.clear() - def commit(self, selection): - # type: (QItemSelection) -> None + def commit(self, selection: QItemSelection) -> None: """Commit the selected tree to output.""" selected_indices = selection.indexes() if not len(selected_indices): + self.selected_index = None self.Outputs.tree.send(None) return - selected_index, = selection.indexes() + # We only allow selecting a single tree so there will always be one index + self.selected_index = selected_indices[0].row() - idx = selected_index.row() - tree = self.rf_model.trees[idx] + tree = self.rf_model.trees[self.selected_index] tree.instances = self.instances tree.meta_target_class_index = self.target_class_index tree.meta_size_calc_idx = self.size_calc_idx diff --git a/Orange/widgets/visualize/tests/test_owpythagorastree.py b/Orange/widgets/visualize/tests/test_owpythagorastree.py index 301d6c697f5..f05dd53b7ff 100644 --- a/Orange/widgets/visualize/tests/test_owpythagorastree.py +++ b/Orange/widgets/visualize/tests/test_owpythagorastree.py @@ -383,3 +383,15 @@ def test_forest_tree_table(self): square.setSelected(True) tab = self.get_output(tree_w.Outputs.selected_data, widget=tree_w) self.assertGreater(len(tab), 0) + + def test_changing_data_restores_depth_from_previous_settings(self): + titanic_data = Table("titanic")[::50] + forest = RandomForestLearner(n_estimators=3)(titanic_data) + forest.instances = titanic_data + + self.send_signal(self.widget.Inputs.tree, forest.trees[0]) + self.widget.controls.depth_limit.setValue(1) + + # The domain is still the same, so restore the depth limit from before + self.send_signal(self.widget.Inputs.tree, forest.trees[1]) + self.assertEqual(self.widget.ptree._depth_limit, 1) diff --git a/Orange/widgets/visualize/tests/test_owpythagoreanforest.py b/Orange/widgets/visualize/tests/test_owpythagoreanforest.py index 49ac389e16d..c1c62f2e344 100644 --- a/Orange/widgets/visualize/tests/test_owpythagoreanforest.py +++ b/Orange/widgets/visualize/tests/test_owpythagoreanforest.py @@ -2,7 +2,7 @@ from unittest.mock import Mock -from AnyQt.QtCore import Qt +from AnyQt.QtCore import Qt, QItemSelection, QItemSelectionModel from Orange.classification.random_forest import RandomForestLearner from Orange.data import Table @@ -201,3 +201,23 @@ def _callback(): # Check that individual squares all have the same color colors_same = [self._check_all_same(x) for x in zip(*colors)] self.assertTrue(all(colors_same)) + + def select_tree(self, idx: int) -> None: + list_view = self.widget.list_view + index = list_view.model().index(idx) + selection = QItemSelection(index, index) + list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect) + + def test_storing_selection(self): + # Select one of the trees + idx = 1 + self.send_signal(self.widget.Inputs.random_forest, self.titanic) + self.select_tree(idx) + # Clear input + self.send_signal(self.widget.Inputs.random_forest, None) + # Restore previous data; context settings should be restored + self.send_signal(self.widget.Inputs.random_forest, self.titanic) + + output = self.get_output(self.widget.Outputs.tree) + self.assertIsNotNone(output) + self.assertIs(output.skl_model, self.titanic.trees[idx].skl_model)