Skip to content

Commit

Permalink
Merge pull request #2678 from janezd/scatterplot-selection-groups
Browse files Browse the repository at this point in the history
[ENH] Add Groups column to Selected Data in Scatter plot output
  • Loading branch information
jerneju authored Oct 20, 2017
2 parents 8cbe98a + b3c0a0d commit ea87577
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 38 deletions.
52 changes: 32 additions & 20 deletions Orange/widgets/utils/annotated_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from itertools import chain

import numpy as np
from Orange.data import Domain, DiscreteVariable
Expand Down Expand Up @@ -52,6 +53,11 @@ def get_next_name(names, name):
:param name: str
:return: str
"""
if isinstance(names, Domain):
names = [
var.name
for var in chain(names.attributes, names.class_vars, names.metas)
]
indexes = get_indices(names, name)
if name not in names and not indexes:
return name
Expand All @@ -74,6 +80,19 @@ def get_unique_names(names, proposed):
return proposed


def _table_with_annotation_column(data, values, column_data, var_name):
    """
    Return `data` transformed into a domain that has one additional
    discrete column, filled with `column_data`.

    The name of the new variable is obtained by passing `var_name` and the
    domain of `data` to `get_next_name`. The variable is added as a class
    variable if the data has no class variables; otherwise it is appended
    to the meta attributes.

    :param data: Table
    :param values: values of the new discrete variable
    :param column_data: np.ndarray with one element per row of `data`
    :param var_name: str, proposed name of the new variable
    :return: Table
    """
    source_domain = data.domain
    annotation = DiscreteVariable(
        get_next_name(source_domain, var_name), values)
    if source_domain.class_vars:
        # The data already has a class: keep it and store the column as meta
        new_class_vars = source_domain.class_vars
        new_metas = source_domain.metas + (annotation, )
    else:
        new_class_vars = source_domain.class_vars + (annotation, )
        new_metas = source_domain.metas
    extended = data.transform(
        Domain(source_domain.attributes, new_class_vars, new_metas))
    # Assignment expects a column, hence the reshape to (n_rows, 1)
    extended[:, annotation] = column_data.reshape((len(data), 1))
    return extended


def create_annotated_table(data, selected_indices):
"""
Returns data with concatenated flag column. Flag column represents
Expand All @@ -86,30 +105,23 @@ def create_annotated_table(data, selected_indices):
"""
if data is None:
return None
names = [var.name for var in data.domain.variables + data.domain.metas]
name = get_next_name(names, ANNOTATED_DATA_FEATURE_NAME)
domain = add_columns(data.domain, metas=[DiscreteVariable(name, ("No", "Yes"))])
annotated = np.zeros((len(data), 1))
if selected_indices is not None:
annotated[selected_indices] = 1
table = data.transform(domain)
table[:, name] = annotated
return table
return _table_with_annotation_column(
data, ("No", "Yes"), annotated, ANNOTATED_DATA_FEATURE_NAME)


def create_groups_table(data, selection,
                        include_unselected=True,
                        var_name=ANNOTATED_DATA_FEATURE_NAME):
    """
    Return data with an added discrete column that tells the selection
    group of each instance.

    `selection` holds one integer per row of `data`: 0 marks an unselected
    row and k > 0 marks a row that belongs to the k-th group. Groups are
    named "G1", "G2", ... up to the largest index found in `selection`.

    :param data: Table or None
    :param selection: np.ndarray of group indices, one per row of `data`
    :param include_unselected: bool; if True, all rows are kept and the
        column gets an extra first value "Unselected". If False, rows
        whose selection is 0 are dropped and the remaining group indices
        are shifted down by one so that "G1" is value 0.
    :param var_name: str, proposed name of the new column
    :return: Table, or None if `data` is None
    """
    if data is None:
        return None
    # np.max raises ValueError on a zero-size array; an empty selection
    # simply means that there are no groups.
    group_count = int(np.max(selection)) if len(selection) else 0
    values = ["G{}".format(i + 1) for i in range(group_count)]
    if include_unselected:
        values.insert(0, "Unselected")
    else:
        mask = np.flatnonzero(selection)
        data = data[mask]
        # Without "Unselected" at index 0, group k maps to value k - 1
        selection = selection[mask] - 1
    return _table_with_annotation_column(data, values, selection, var_name)
36 changes: 18 additions & 18 deletions Orange/widgets/visualize/owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotGraph
from Orange.widgets.visualize.utils import VizRankDialogAttrPair
from Orange.widgets.widget import OWWidget, AttributeList, Msg, Input, Output
from Orange.widgets.utils.annotated_data import (create_annotated_table,
ANNOTATED_DATA_SIGNAL_NAME,
create_groups_table)
from Orange.widgets.utils.annotated_data import (
create_annotated_table, create_groups_table, ANNOTATED_DATA_SIGNAL_NAME)


class ScatterPlotVizRank(VizRankDialogAttrPair):
Expand Down Expand Up @@ -427,25 +426,26 @@ def selection_changed(self):
self.commit()

def send_data(self):
selected = None
selection = None
# TODO: Implement selection for sql data
def _get_selected():
if not len(selection):
return None
return create_groups_table(data, graph.selection, False, "Group")

def _get_annotated():
if graph.selection is not None and np.max(graph.selection) > 1:
return create_groups_table(data, graph.selection)
else:
return create_annotated_table(data, selection)

graph = self.graph
if isinstance(self.data, SqlTable):
selected = self.data
elif self.data is not None:
selection = graph.get_selection()
if len(selection) > 0:
selected = self.data[selection]
if graph.selection is not None and np.max(graph.selection) > 1:
annotated = create_groups_table(self.data, graph.selection)
else:
annotated = create_annotated_table(self.data, selection)
self.Outputs.selected_data.send(selected)
self.Outputs.annotated_data.send(annotated)
data = self.data
selection = graph.get_selection()
self.Outputs.annotated_data.send(_get_annotated())
self.Outputs.selected_data.send(_get_selected())

# Store current selection in a setting that is stored in workflow
if selection is not None and len(selection):
if len(selection):
self.selection_group = list(zip(selection, graph.selection[selection]))
else:
self.selection_group = None
Expand Down
15 changes: 15 additions & 0 deletions Orange/widgets/visualize/tests/test_owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ class TestOWScatterPlot(WidgetTest, WidgetOutputsTestMixin):
def setUpClass(cls):
# Class-level fixture: runs the parent setup, initialises the shared
# output-test mixin and sets the attributes configured below.
super().setUpClass()
WidgetOutputsTestMixin.init(cls)
# NOTE(review): presumably tells the mixin's tests that the widget's
# outputs do not share the input's domain (this commit adds a "Group"
# column to Selected Data) — confirm against WidgetOutputsTestMixin.
cls.same_input_output_domain = False

# NOTE(review): cls.data is not assigned in this method; presumably it is
# provided by WidgetOutputsTestMixin.init — confirm.
cls.signal_name = "Data"
cls.signal_data = cls.data

def setUp(self):
    """Create the OWScatterPlot widget under test and keep it in
    `self.widget`."""
    scatter_plot = self.create_widget(OWScatterPlot)
    self.widget = scatter_plot

def _compare_selected_annotated_domains(self, selected, annotated):
    """
    Do nothing; the domain comparison is disabled for this widget.

    The base class tests that selected.domain is a subset of
    annotated.domain. In scatter plot, the two domains are unrelated, so
    the check is overridden with a no-op.
    """
    return None

def test_set_data(self):
# Connect iris to scatter plot
self.send_signal(self.widget.Inputs.data, self.data)
Expand Down Expand Up @@ -154,6 +160,9 @@ def test_group_selections(self):
def selectedx():
return self.get_output(self.widget.Outputs.selected_data).X

def selected_groups():
return self.get_output(self.widget.Outputs.selected_data).metas[:, 0]

def annotated():
return self.get_output(self.widget.Outputs.annotated_data).metas

Expand All @@ -163,6 +172,7 @@ def annotations():
# Select 0:5
graph.select(points[:5])
np.testing.assert_equal(selectedx(), x[:5])
np.testing.assert_equal(selected_groups(), np.zeros(5))
sel_column[:5] = 1
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(annotations(), ["No", "Yes"])
Expand All @@ -171,6 +181,7 @@ def annotations():
with self.modifiers(Qt.ShiftModifier):
graph.select(points[5:10])
np.testing.assert_equal(selectedx(), x[:10])
np.testing.assert_equal(selected_groups(), np.array([0] * 5 + [1] * 5))
sel_column[5:10] = 2
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(len(annotations()), 3)
Expand All @@ -180,12 +191,14 @@ def annotations():
sel_column = np.zeros((len(self.data), 1))
sel_column[15:20] = 1
np.testing.assert_equal(selectedx(), x[15:20])
np.testing.assert_equal(selected_groups(), np.zeros(5))
self.assertEqual(annotations(), ["No", "Yes"])

# Alt-select (remove) 10:17; we have 17:20
with self.modifiers(Qt.AltModifier):
graph.select(points[10:17])
np.testing.assert_equal(selectedx(), x[17:20])
np.testing.assert_equal(selected_groups(), np.zeros(3))
sel_column[15:17] = 0
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(annotations(), ["No", "Yes"])
Expand All @@ -194,6 +207,7 @@ def annotations():
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
graph.select(points[20:25])
np.testing.assert_equal(selectedx(), x[17:25])
np.testing.assert_equal(selected_groups(), np.zeros(8))
sel_column[20:25] = 1
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(annotations(), ["No", "Yes"])
Expand All @@ -205,6 +219,7 @@ def annotations():
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
graph.select(points[35:40])
sel_column[30:40] = 2
np.testing.assert_equal(selected_groups(), np.array([0] * 8 + [1] * 10))
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(len(annotations()), 3)

Expand Down

0 comments on commit ea87577

Please sign in to comment.