Skip to content

Commit

Permalink
Merge pull request #5963 from janezd/silhouette-contexts
Browse files Browse the repository at this point in the history
Silhouette: Use variables instead of indices
  • Loading branch information
VesnaT authored May 20, 2022
2 parents 1b044ca + 3617003 commit 25ba55a
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 41 deletions.
98 changes: 57 additions & 41 deletions Orange/widgets/visualize/owsilhouetteplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,15 @@ class Outputs:
"Orange.widgets.unsupervised.owsilhouetteplot.OWSilhouettePlot"
]

settingsHandler = settings.PerfectDomainContextHandler()
settingsHandler = settings.DomainContextHandler()
settings_version = 2

#: Distance metric index
distance_idx = settings.Setting(0)
#: Group/cluster variable index
cluster_var_idx = settings.ContextSetting(0)
#: Annotation variable index
annotation_var_idx = settings.ContextSetting(0)
#: Group/cluster variable
cluster_var = settings.ContextSetting(None)
#: Annotation variable
annotation_var = settings.ContextSetting(None)
#: Group the (displayed) silhouettes by cluster
group_by_cluster = settings.Setting(True)
#: A fixed size for an instance bar
Expand Down Expand Up @@ -145,29 +146,30 @@ def __init__(self):
orientation=Qt.Horizontal, callback=self._invalidate_distances)
controllayout.addWidget(distbox)

box = gui.vBox(self.controlArea, "Cluster Label")
box = gui.vBox(self.controlArea, "Grouping")
self.cluster_var_model = itemmodels.VariableListModel(
parent=self, placeholder="(None)")
self.cluster_var_cb = gui.comboBox(
box, self, "cluster_var_idx", contentsLength=14,
searchable=True, callback=self._invalidate_scores
box, self, "cluster_var", contentsLength=14,
searchable=True, callback=self._invalidate_scores,
model=self.cluster_var_model
)
gui.checkBox(
box, self, "group_by_cluster", "Group by cluster",
box, self, "group_by_cluster", "Show in groups",
callback=self._replot)
self.cluster_var_model = itemmodels.VariableListModel(parent=self)
self.cluster_var_cb.setModel(self.cluster_var_model)

box = gui.vBox(self.controlArea, "Bars")
gui.widgetLabel(box, "Bar width:")
gui.hSlider(
box, self, "bar_size", minValue=1, maxValue=10, step=1,
callback=self._update_bar_size)
gui.widgetLabel(box, "Annotations:")
self.annotation_cb = gui.comboBox(
box, self, "annotation_var_idx", contentsLength=14,
callback=self._update_annotations)
self.annotation_var_model = itemmodels.VariableListModel(parent=self)
self.annotation_var_model[:] = ["None"]
self.annotation_cb.setModel(self.annotation_var_model)
self.annotation_var_model[:] = [None]
self.annotation_cb = gui.comboBox(
box, self, "annotation_var", contentsLength=14,
callback=self._update_annotations,
model=self.annotation_var_model)
ibox = gui.indentedBox(box, 5)
self.ann_hidden_warning = warning = gui.widgetLabel(
ibox, "(increase the width to show)")
Expand Down Expand Up @@ -258,14 +260,14 @@ def _setup_control_models(self, domain: Domain):
raise NoGroupVariable()
self.cluster_var_model[:] = groupvars
if domain.class_var in groupvars:
self.cluster_var_idx = groupvars.index(domain.class_var)
self.cluster_var = domain.class_var
else:
self.cluster_var_idx = 0
self.cluster_var = groupvars[0]
annotvars = [var for var in domain.variables + domain.metas
if var.is_string or var.is_discrete]
self.annotation_var_model[:] = ["None"] + annotvars
self.annotation_var_idx = 1 if annotvars else 0
self.openContext(Orange.data.Domain(groupvars))
self.annotation_var_model[:] = [None] + annotvars
self.annotation_var = annotvars[0] if annotvars else None
self.openContext(domain)

def _is_empty(self) -> bool:
# Is empty (does not have any input).
Expand All @@ -283,7 +285,7 @@ def clear(self):
self._silhouette = None
self._labels = None
self.cluster_var_model[:] = []
self.annotation_var_model[:] = ["None"]
self.annotation_var_model[:] = [None]
self._clear_scene()
self.Error.clear()
self.Warning.clear()
Expand Down Expand Up @@ -346,8 +348,7 @@ def _update(self):
if self._matrix is None:
return

labelvar = self.cluster_var_model[self.cluster_var_idx]
labels, _ = self.data.get_column_view(labelvar)
labels, _ = self.data.get_column_view(self.cluster_var)
labels = np.asarray(labels, dtype=float)
cluster_mask = np.isnan(labels)
dist_mask = np.isnan(self._matrix).all(axis=0)
Expand Down Expand Up @@ -396,19 +397,19 @@ def _set_bar_height(self):
self._silplot.setBarHeight(self.bar_size)
self._silplot.setRowNamesVisible(visible)
self.ann_hidden_warning.setVisible(
not visible and self.annotation_var_idx > 0)
not visible and self.annotation_var is not None)

def _replot(self):
# Clear and replot/initialize the scene
self._clear_scene()
if self._silhouette is not None and self._labels is not None:
var = self.cluster_var_model[self.cluster_var_idx]
self._silplot = silplot = SilhouettePlot()
self._set_bar_height()

if self.group_by_cluster:
silplot.setScores(self._silhouette, self._labels, var.values,
var.colors)
silplot.setScores(
self._silhouette, self._labels,
self.cluster_var.values, self.cluster_var.colors)
else:
silplot.setScores(
self._silhouette,
Expand All @@ -428,10 +429,7 @@ def _update_bar_size(self):
self._set_bar_height()

def _update_annotations(self):
if 0 < self.annotation_var_idx < len(self.annotation_var_model):
annot_var = self.annotation_var_model[self.annotation_var_idx]
else:
annot_var = None
annot_var = self.annotation_var
self.ann_hidden_warning.setVisible(
self.bar_size < 5 and annot_var is not None)

Expand Down Expand Up @@ -494,10 +492,8 @@ def commit(self):
else:
scores = self._silhouette

var = self.cluster_var_model[self.cluster_var_idx]

domain = self.data.domain
proposed = "Silhouette ({})".format(escape(var.name))
proposed = "Silhouette ({})".format(escape(self.cluster_var.name))
names = [var.name for var in itertools.chain(domain.attributes,
domain.class_vars,
domain.metas)]
Expand Down Expand Up @@ -528,18 +524,38 @@ def send_report(self):
return

self.report_plot()
caption = "Silhouette plot ({} distance), clustered by '{}'".format(
self.Distances[self.distance_idx][0],
self.cluster_var_model[self.cluster_var_idx])
if self.annotation_var_idx and self._silplot.rowNamesVisible():
caption += ", annotated with '{}'".format(
self.annotation_var_model[self.annotation_var_idx])
caption = "Silhouette plot " \
f"({self.Distances[self.distance_idx][0]} distance), " \
f"clustered by '{self.cluster_var.name}'"
if self.annotation_var and self._silplot.rowNamesVisible():
caption += f", annotated with '{self.annotation_var.name}'"
self.report_caption(caption)

def onDeleteWidget(self):
self.clear()
super().onDeleteWidget()

@classmethod
def migrate_context(cls, context, version):
values = context.values
if version < 2:
# contexts were constructed from Domain containing vars shown in
# the list view, context.class_vars and context.metas were always
# empty, and context.attributes contained discrete attributes
index, _ = values.pop("cluster_var_idx")
values["cluster_var"] = (context.attributes[index][0], 101)

index = values.pop("annotation_var_idx")[0] - 1
if index == -1:
values["annotation_var"] = None
elif index < len(context.attributes):
name, _ = context.attributes[index]
values["annotation_var"] = (name, 101)
# else we cannot migrate
# Even this migration can be erroneous if metas contained a mixture
# of discrete and string attributes; the latter were not stored in
# context, so indices in context could have been wrong


class SelectAction(enum.IntEnum):
NoUpdate, Clear, Select, Deselect, Toogle, Current = 1, 2, 4, 8, 16, 32
Expand Down
69 changes: 69 additions & 0 deletions Orange/widgets/visualize/tests/test_owsilhouetteplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
# pylint: disable=missing-docstring
import random
import unittest
from unittest.mock import Mock

import numpy as np

from orangewidget.settings import Context

import Orange.distance
from Orange.data import (
Table, Domain, ContinuousVariable, DiscreteVariable, StringVariable)
Expand Down Expand Up @@ -226,6 +229,72 @@ def test_unique_output_domain(self):
output = self.get_output(self.widget.Outputs.annotated_data)
self.assertEqual(output.domain.metas[0].name, 'Silhouette (iris) (1)')

def test_report(self):
widget = self.widget
widget.report_plot = Mock()
widget.report_caption = Mock()

widget.send_report()
widget.report_plot.assert_not_called()
widget.report_caption.assert_not_called()

data = Table("zoo")
self.send_signal(widget.Inputs.data, data)

widget.annotation_var = None
widget.send_report()
widget.report_plot.assert_called()
widget.report_caption.assert_called()
text = widget.report_caption.call_args[0][0]
self.assertIn(data.domain.class_var.name, text)
self.assertNotIn("nnotated", text)

widget.annotation_var = data.domain.metas[0]
widget._silplot.rowNamesVisible = lambda: True
widget.send_report()
text = widget.report_caption.call_args[0][0]
self.assertIn(data.domain.class_var.name, text)
self.assertIn("nnotated", text)
self.assertIn(data.domain.metas[0].name, text)

def test_migration(self):
enc_domain = dict(
attributes=(('foo', 1), ('bar', 1), ('baz', 1), ('bax', 1),
('cfoo', 1), ('mbaz', 1)))

# No annotation
context = Context(
values=dict(cluster_var_idx=(0, -2), annotation_var_idx=(0, -2)),
**enc_domain
)
OWSilhouettePlot.migrate_context(context, 1)
values = context.values
self.assertNotIn("cluster_var_idx", values)
self.assertNotIn("annotation_var_idx", values)
self.assertEqual(values["cluster_var"], ("foo", 101))
self.assertEqual(values["annotation_var"], None)

# We have annotation
context = Context(
values=dict(cluster_var_idx=(2, -2), annotation_var_idx=(4, -2)),
**enc_domain
)
OWSilhouettePlot.migrate_context(context, 1)
self.assertNotIn("cluster_var_idx", values)
self.assertNotIn("annotation_var_idx", values)
self.assertEqual(context.values["cluster_var"], ("baz", 101))
self.assertEqual(context.values["annotation_var"], ("bax", 101))

# We thought was had annotation, but the index is wrong due to
# incorrect domain
context = Context(
values=dict(cluster_var_idx=(4, -2), annotation_var_idx=(7, -2)),
**enc_domain
)
OWSilhouettePlot.migrate_context(context, 1)
self.assertEqual(context.values["cluster_var"], ("cfoo", 101))
self.assertNotIn("annotation_var_idx", values)


if __name__ == "__main__":
unittest.main()

0 comments on commit 25ba55a

Please sign in to comment.