Skip to content

Commit

Permalink
OWScatterPlot: Add column with groups to selected data output
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Oct 13, 2017
1 parent 6afb4cf commit a56c18d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 29 deletions.
25 changes: 14 additions & 11 deletions Orange/widgets/utils/annotated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,22 @@ def create_annotated_table(data, selected_indices):
return table


def create_groups_table(data, selection):
def create_groups_table(data, selection,
include_unselected=True,
var_name=ANNOTATED_DATA_FEATURE_NAME):
if data is None:
return None
names = [var.name for var in data.domain.variables + data.domain.metas]
name = get_next_name(names, ANNOTATED_DATA_FEATURE_NAME)
metas = data.domain.metas + (
DiscreteVariable(
name,
["Unselected"] + ["G{}".format(i + 1)
for i in range(np.max(selection))]),
)
values = ["G{}".format(i + 1) for i in range(np.max(selection))]
if include_unselected:
values.insert(0, "Unselected")
else:
mask = np.flatnonzero(selection)
data = data[mask]
selection = selection[mask] - 1

var_name = get_next_name(data.domain, var_name)
metas = data.domain.metas + (DiscreteVariable(var_name, values), )
domain = Domain(data.domain.attributes, data.domain.class_vars, metas)
table = data.transform(domain)
table.metas[:, len(data.domain.metas):] = \
selection.reshape(len(data), 1)
table.metas[:, len(data.domain.metas):] = selection.reshape(len(data), 1)
return table
36 changes: 18 additions & 18 deletions Orange/widgets/visualize/owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotGraph
from Orange.widgets.visualize.utils import VizRankDialogAttrPair
from Orange.widgets.widget import OWWidget, AttributeList, Msg, Input, Output
from Orange.widgets.utils.annotated_data import (create_annotated_table,
ANNOTATED_DATA_SIGNAL_NAME,
create_groups_table)
from Orange.widgets.utils.annotated_data import (
create_annotated_table, create_groups_table, ANNOTATED_DATA_SIGNAL_NAME)


class ScatterPlotVizRank(VizRankDialogAttrPair):
Expand Down Expand Up @@ -428,25 +427,26 @@ def selection_changed(self):
self.commit()

def send_data(self):
selected = None
selection = None
# TODO: Implement selection for sql data
def _get_selected():
if not len(selection):
return None
return create_groups_table(data, graph.selection, False, "Group")

def _get_annotated():
if graph.selection is not None and np.max(graph.selection) > 1:
return create_groups_table(data, graph.selection)
else:
return create_annotated_table(data, selection)

graph = self.graph
if isinstance(self.data, SqlTable):
selected = self.data
elif self.data is not None:
selection = graph.get_selection()
if len(selection) > 0:
selected = self.data[selection]
if graph.selection is not None and np.max(graph.selection) > 1:
annotated = create_groups_table(self.data, graph.selection)
else:
annotated = create_annotated_table(self.data, selection)
self.Outputs.selected_data.send(selected)
self.Outputs.annotated_data.send(annotated)
data = self.data
selection = graph.get_selection()
self.Outputs.annotated_data.send(_get_annotated())
self.Outputs.selected_data.send(_get_selected())

# Store current selection in a setting that is stored in workflow
if selection is not None and len(selection):
if len(selection):
self.selection_group = list(zip(selection, graph.selection[selection]))
else:
self.selection_group = None
Expand Down
13 changes: 13 additions & 0 deletions Orange/widgets/visualize/tests/test_owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ class TestOWScatterPlot(WidgetTest, WidgetOutputsTestMixin):
def setUpClass(cls):
super().setUpClass()
WidgetOutputsTestMixin.init(cls)
cls.same_input_output_domain = False

cls.signal_name = "Data"
cls.signal_data = cls.data

def setUp(self):
self.widget = self.create_widget(OWScatterPlot)

def _compare_selected_annotated_domains(self, selected, annotated):
pass

def test_set_data(self):
# Connect iris to scatter plot
self.send_signal(self.widget.Inputs.data, self.data)
Expand Down Expand Up @@ -154,6 +158,9 @@ def test_group_selections(self):
def selectedx():
return self.get_output(self.widget.Outputs.selected_data).X

def selected_groups():
return self.get_output(self.widget.Outputs.selected_data).metas[:, 0]

def annotated():
return self.get_output(self.widget.Outputs.annotated_data).metas

Expand All @@ -163,6 +170,7 @@ def annotations():
# Select 0:5
graph.select(points[:5])
np.testing.assert_equal(selectedx(), x[:5])
np.testing.assert_equal(selected_groups(), np.zeros(5))
sel_column[:5] = 1
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(annotations(), ["No", "Yes"])
Expand All @@ -171,6 +179,7 @@ def annotations():
with self.modifiers(Qt.ShiftModifier):
graph.select(points[5:10])
np.testing.assert_equal(selectedx(), x[:10])
np.testing.assert_equal(selected_groups(), np.array([0] * 5 + [1] * 5))
sel_column[5:10] = 2
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(len(annotations()), 3)
Expand All @@ -180,12 +189,14 @@ def annotations():
sel_column = np.zeros((len(self.data), 1))
sel_column[15:20] = 1
np.testing.assert_equal(selectedx(), x[15:20])
np.testing.assert_equal(selected_groups(), np.zeros(5))
self.assertEqual(annotations(), ["No", "Yes"])

# Alt-select (remove) 10:17; we have 17:20
with self.modifiers(Qt.AltModifier):
graph.select(points[10:17])
np.testing.assert_equal(selectedx(), x[17:20])
np.testing.assert_equal(selected_groups(), np.zeros(3))
sel_column[15:17] = 0
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(annotations(), ["No", "Yes"])
Expand All @@ -194,6 +205,7 @@ def annotations():
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
graph.select(points[20:25])
np.testing.assert_equal(selectedx(), x[17:25])
np.testing.assert_equal(selected_groups(), np.zeros(8))
sel_column[20:25] = 1
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(annotations(), ["No", "Yes"])
Expand All @@ -205,6 +217,7 @@ def annotations():
with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier):
graph.select(points[35:40])
sel_column[30:40] = 2
np.testing.assert_equal(selected_groups(), np.array([0] * 8 + [1] * 10))
np.testing.assert_equal(annotated(), sel_column)
self.assertEqual(len(annotations()), 3)

Expand Down

0 comments on commit a56c18d

Please sign in to comment.