Skip to content

Commit

Permalink
Merge pull request #3777 from pavlin-policar/pythagoras-enh
Browse files Browse the repository at this point in the history
[FIX] Minor improvements to pythagorean trees
  • Loading branch information
thocevar authored May 24, 2019
2 parents d7b49f6 + 57d4ecc commit c90d9ce
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 50 deletions.
74 changes: 38 additions & 36 deletions Orange/widgets/visualize/owpythagorastree.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class Outputs:
graph_name = 'scene'

# Settings
settingsHandler = settings.DomainContextHandler()

depth_limit = settings.ContextSetting(10)
target_class_index = settings.ContextSetting(0)
size_calc_idx = settings.Setting(0)
Expand All @@ -73,8 +75,7 @@ def __init__(self):
super().__init__()
# Instance variables
self.model = None
self.instances = None
self.clf_dataset = None
self.data = None
# The tree adapter instance which is passed from the outside
self.tree_adapter = None
self.legend = None
Expand Down Expand Up @@ -147,18 +148,12 @@ def __init__(self):
@Inputs.tree
def set_tree(self, model=None):
"""When a different tree is given."""
self.closeContext()
self.clear()
self.model = model

if model is not None:
self.instances = model.instances
# this bit is important for the regression classifier
if self.instances is not None and \
self.instances.domain != model.domain:
self.clf_dataset = self.instances.transform(self.model.domain)
else:
self.clf_dataset = self.instances

self.data = model.instances
self.tree_adapter = self._get_tree_adapter(self.model)
self.ptree.clear()

Expand All @@ -177,30 +172,30 @@ def set_tree(self, model=None):

self._update_main_area()

# The target class can also be passed from the meta properties
# This must be set after `_update_target_class_combo`
if hasattr(model, 'meta_target_class_index'):
self.target_class_index = model.meta_target_class_index
self.update_colors()
self.openContext(self.model)

# Get meta variables describing what the settings should look like
# if the tree is passed from the Pythagorean forest widget.
if hasattr(model, 'meta_size_calc_idx'):
self.size_calc_idx = model.meta_size_calc_idx
self.update_size_calc()
self.update_depth()

# TODO There is still something wrong with this
# if hasattr(model, 'meta_depth_limit'):
# self.depth_limit = model.meta_depth_limit
# self.update_depth()
# The forest widget sets the following attributes on the tree,
# describing the settings on the forest widget. To keep the tree
# looking the same as on the forest widget, we prefer these settings to
# context settings, if set.
if hasattr(model, "meta_target_class_index"):
self.target_class_index = model.meta_target_class_index
self.update_colors()
if hasattr(model, "meta_size_calc_idx"):
self.size_calc_idx = model.meta_size_calc_idx
self.update_size_calc()
if hasattr(model, "meta_depth_limit"):
self.depth_limit = model.meta_depth_limit
self.update_depth()

self.Outputs.annotated_data.send(create_annotated_table(self.instances, None))
self.Outputs.annotated_data.send(create_annotated_table(self.data, None))

def clear(self):
"""Clear all relevant data from the widget."""
self.model = None
self.instances = None
self.clf_dataset = None
self.data = None
self.tree_adapter = None

if self.legend is not None:
Expand Down Expand Up @@ -228,6 +223,8 @@ def update_size_calc(self):
self.invalidate_tree()

def redraw(self):
if self.data is None:
return
self.tree_adapter.shuffle_children()
self.invalidate_tree()

Expand Down Expand Up @@ -307,16 +304,21 @@ def onDeleteWidget(self):

def commit(self):
"""Commit the selected data to output."""
if self.instances is None:
if self.data is None:
self.Outputs.selected_data.send(None)
self.Outputs.annotated_data.send(None)
return
nodes = [i.tree_node.label for i in self.scene.selectedItems()
if isinstance(i, SquareGraphicsItem)]

nodes = [
i.tree_node.label for i in self.scene.selectedItems()
if isinstance(i, SquareGraphicsItem)
]
data = self.tree_adapter.get_instances_in_nodes(nodes)
self.Outputs.selected_data.send(data)
selected_indices = self.tree_adapter.get_indices(nodes)
self.Outputs.annotated_data.send(create_annotated_table(self.instances, selected_indices))
self.Outputs.annotated_data.send(
create_annotated_table(self.data, selected_indices)
)

def send_report(self):
"""Send report."""
Expand All @@ -327,9 +329,9 @@ def _update_target_class_combo(self):
label = [x for x in self.target_class_combo.parent().children()
if isinstance(x, QLabel)][0]

if self.instances.domain.has_discrete_class:
if self.data.domain.has_discrete_class:
label_text = 'Target class'
values = [c.title() for c in self.instances.domain.class_vars[0].values]
values = [c.title() for c in self.data.domain.class_vars[0].values]
values.insert(0, 'None')
else:
label_text = 'Node color'
Expand All @@ -342,7 +344,7 @@ def _update_legend_colors(self):
if self.legend is not None:
self.scene.removeItem(self.legend)

if self.instances.domain.has_discrete_class:
if self.data.domain.has_discrete_class:
self._classification_update_legend_colors()
else:
self._regression_update_legend_colors()
Expand Down Expand Up @@ -375,14 +377,14 @@ def _get_colors_domain(domain):

# The colors are the class mean
if self.target_class_index == 1:
values = (np.min(self.clf_dataset.Y), np.max(self.clf_dataset.Y))
values = (np.min(self.data.Y), np.max(self.data.Y))
colors = _get_colors_domain(self.model.domain)
while len(values) != len(colors):
values.insert(1, -1)
items = list(zip(values, colors))
# Colors are the stddev
elif self.target_class_index == 2:
values = (0, np.std(self.clf_dataset.Y))
values = (0, np.std(self.data.Y))
colors = _get_colors_domain(self.model.domain)
while len(values) != len(colors):
values.insert(1, -1)
Expand Down
32 changes: 19 additions & 13 deletions Orange/widgets/visualize/owpythagoreanforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Callable, Optional

from AnyQt.QtCore import Qt, QRectF, QSize, QPointF, QSizeF, QModelIndex, \
QItemSelection, QT_VERSION
QItemSelection, QItemSelectionModel, QT_VERSION
from AnyQt.QtGui import QPainter, QPen, QColor, QBrush, QMouseEvent
from AnyQt.QtWidgets import QSizePolicy, QGraphicsScene, QLabel, QSlider, \
QListView, QStyledItemDelegate, QStyleOptionViewItem, QStyle
Expand Down Expand Up @@ -174,11 +174,15 @@ class Outputs:
graph_name = 'scene'

# Settings
settingsHandler = settings.DomainContextHandler()

depth_limit = settings.ContextSetting(10)
target_class_index = settings.ContextSetting(0)
size_calc_idx = settings.Setting(0)
zoom = settings.Setting(200)

selected_index = settings.ContextSetting(None)

SIZE_CALCULATION = [
('Normal', lambda x: x),
('Square root', lambda x: sqrt(x)),
Expand All @@ -199,7 +203,6 @@ def __init__(self):
self.rf_model = None
self.forest = None
self.instances = None
self.clf_dataset = None

self.color_palette = None

Expand Down Expand Up @@ -265,29 +268,32 @@ def __init__(self):
@Inputs.random_forest
def set_rf(self, model=None):
"""When a different forest is given."""
self.closeContext()
self.clear()
self.rf_model = model

if model is not None:
self.forest = self._get_forest_adapter(self.rf_model)
self.forest_model[:] = self.forest.trees

self.instances = model.instances
# This bit is important for the regression classifier
if self.instances is not None and self.instances.domain != model.domain:
self.clf_dataset = self.instances.transform(self.rf_model.domain)
else:
self.clf_dataset = self.instances

self._update_info_box()
self._update_target_class_combo()
self._update_depth_slider()

self.openContext(model)
# Restore item selection
if self.selected_index is not None:
index = self.list_view.model().index(self.selected_index)
selection = QItemSelection(index, index)
self.list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)

def clear(self):
"""Clear all relevant data from the widget."""
self.rf_model = None
self.forest = None
self.forest_model.clear()
self.selected_index = None

self._clear_info_box()
self._clear_target_class_combo()
Expand Down Expand Up @@ -342,19 +348,19 @@ def onDeleteWidget(self):
super().onDeleteWidget()
self.clear()

def commit(self, selection):
# type: (QItemSelection) -> None
def commit(self, selection: QItemSelection) -> None:
"""Commit the selected tree to output."""
selected_indices = selection.indexes()

if not len(selected_indices):
self.selected_index = None
self.Outputs.tree.send(None)
return

selected_index, = selection.indexes()
# We only allow selecting a single tree so there will always be one index
self.selected_index = selected_indices[0].row()

idx = selected_index.row()
tree = self.rf_model.trees[idx]
tree = self.rf_model.trees[self.selected_index]
tree.instances = self.instances
tree.meta_target_class_index = self.target_class_index
tree.meta_size_calc_idx = self.size_calc_idx
Expand Down
12 changes: 12 additions & 0 deletions Orange/widgets/visualize/tests/test_owpythagorastree.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,15 @@ def test_forest_tree_table(self):
square.setSelected(True)
tab = self.get_output(tree_w.Outputs.selected_data, widget=tree_w)
self.assertGreater(len(tab), 0)

def test_changing_data_restores_depth_from_previous_settings(self):
titanic_data = Table("titanic")[::50]
forest = RandomForestLearner(n_estimators=3)(titanic_data)
forest.instances = titanic_data

self.send_signal(self.widget.Inputs.tree, forest.trees[0])
self.widget.controls.depth_limit.setValue(1)

# The domain is still the same, so restore the depth limit from before
self.send_signal(self.widget.Inputs.tree, forest.trees[1])
self.assertEqual(self.widget.ptree._depth_limit, 1)
22 changes: 21 additions & 1 deletion Orange/widgets/visualize/tests/test_owpythagoreanforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from unittest.mock import Mock

from AnyQt.QtCore import Qt
from AnyQt.QtCore import Qt, QItemSelection, QItemSelectionModel

from Orange.classification.random_forest import RandomForestLearner
from Orange.data import Table
Expand Down Expand Up @@ -201,3 +201,23 @@ def _callback():
# Check that individual squares all have the same color
colors_same = [self._check_all_same(x) for x in zip(*colors)]
self.assertTrue(all(colors_same))

def select_tree(self, idx: int) -> None:
list_view = self.widget.list_view
index = list_view.model().index(idx)
selection = QItemSelection(index, index)
list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)

def test_storing_selection(self):
# Select one of the trees
idx = 1
self.send_signal(self.widget.Inputs.random_forest, self.titanic)
self.select_tree(idx)
# Clear input
self.send_signal(self.widget.Inputs.random_forest, None)
# Restore previous data; context settings should be restored
self.send_signal(self.widget.Inputs.random_forest, self.titanic)

output = self.get_output(self.widget.Outputs.tree)
self.assertIsNotNone(output)
self.assertIs(output.skl_model, self.titanic.trees[idx].skl_model)

0 comments on commit c90d9ce

Please sign in to comment.