From a56c18d03ae8043fd21d1cb1ddfddf1b2689d4a3 Mon Sep 17 00:00:00 2001 From: janezd Date: Fri, 13 Oct 2017 16:04:01 +0200 Subject: [PATCH] OWScatterPlot: Add column with groups to selected data output --- Orange/widgets/utils/annotated_data.py | 25 +++++++------ Orange/widgets/visualize/owscatterplot.py | 36 +++++++++---------- .../visualize/tests/test_owscatterplot.py | 13 +++++++ 3 files changed, 45 insertions(+), 29 deletions(-) diff --git a/Orange/widgets/utils/annotated_data.py b/Orange/widgets/utils/annotated_data.py index 9416c50e954..c3e5adf2e4b 100644 --- a/Orange/widgets/utils/annotated_data.py +++ b/Orange/widgets/utils/annotated_data.py @@ -103,19 +103,22 @@ def create_annotated_table(data, selected_indices): return table -def create_groups_table(data, selection): +def create_groups_table(data, selection, + include_unselected=True, + var_name=ANNOTATED_DATA_FEATURE_NAME): if data is None: return None - names = [var.name for var in data.domain.variables + data.domain.metas] - name = get_next_name(names, ANNOTATED_DATA_FEATURE_NAME) - metas = data.domain.metas + ( - DiscreteVariable( - name, - ["Unselected"] + ["G{}".format(i + 1) - for i in range(np.max(selection))]), - ) + values = ["G{}".format(i + 1) for i in range(np.max(selection))] + if include_unselected: + values.insert(0, "Unselected") + else: + mask = np.flatnonzero(selection) + data = data[mask] + selection = selection[mask] - 1 + + var_name = get_next_name(data.domain, var_name) + metas = data.domain.metas + (DiscreteVariable(var_name, values), ) domain = Domain(data.domain.attributes, data.domain.class_vars, metas) table = data.transform(domain) - table.metas[:, len(data.domain.metas):] = \ - selection.reshape(len(data), 1) + table.metas[:, len(data.domain.metas):] = selection.reshape(len(data), 1) return table diff --git a/Orange/widgets/visualize/owscatterplot.py b/Orange/widgets/visualize/owscatterplot.py index 0a29ee556fd..15150fa70fd 100644 --- a/Orange/widgets/visualize/owscatterplot.py +++ b/Orange/widgets/visualize/owscatterplot.py @@ -21,9 +21,8 @@ from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotGraph from Orange.widgets.visualize.utils import VizRankDialogAttrPair from Orange.widgets.widget import OWWidget, AttributeList, Msg, Input, Output -from Orange.widgets.utils.annotated_data import (create_annotated_table, - ANNOTATED_DATA_SIGNAL_NAME, - create_groups_table) +from Orange.widgets.utils.annotated_data import ( + create_annotated_table, create_groups_table, ANNOTATED_DATA_SIGNAL_NAME) class ScatterPlotVizRank(VizRankDialogAttrPair): @@ -428,25 +427,26 @@ def selection_changed(self): self.commit() def send_data(self): - selected = None - selection = None # TODO: Implement selection for sql data + def _get_selected(): + if not len(selection): + return None + return create_groups_table(data, graph.selection, False, "Group") + + def _get_annotated(): + if graph.selection is not None and np.max(graph.selection) > 1: + return create_groups_table(data, graph.selection) + else: + return create_annotated_table(data, selection) + graph = self.graph - if isinstance(self.data, SqlTable): - selected = self.data - elif self.data is not None: - selection = graph.get_selection() - if len(selection) > 0: - selected = self.data[selection] - if graph.selection is not None and np.max(graph.selection) > 1: - annotated = create_groups_table(self.data, graph.selection) - else: - annotated = create_annotated_table(self.data, selection) - self.Outputs.selected_data.send(selected) - self.Outputs.annotated_data.send(annotated) + data = self.data + selection = graph.get_selection() + self.Outputs.annotated_data.send(_get_annotated()) + self.Outputs.selected_data.send(_get_selected()) # Store current selection in a setting that is stored in workflow - if selection is not None and len(selection): + if len(selection): self.selection_group = list(zip(selection, graph.selection[selection])) else: self.selection_group = None diff --git a/Orange/widgets/visualize/tests/test_owscatterplot.py b/Orange/widgets/visualize/tests/test_owscatterplot.py index f00cf2c6e35..e84e201307e 100644 --- a/Orange/widgets/visualize/tests/test_owscatterplot.py +++ b/Orange/widgets/visualize/tests/test_owscatterplot.py @@ -18,6 +18,7 @@ class TestOWScatterPlot(WidgetTest, WidgetOutputsTestMixin): def setUpClass(cls): super().setUpClass() WidgetOutputsTestMixin.init(cls) + cls.same_input_output_domain = False cls.signal_name = "Data" cls.signal_data = cls.data @@ -25,6 +26,9 @@ def setUpClass(cls): def setUp(self): self.widget = self.create_widget(OWScatterPlot) + def _compare_selected_annotated_domains(self, selected, annotated): + pass + def test_set_data(self): # Connect iris to scatter plot self.send_signal(self.widget.Inputs.data, self.data) @@ -154,6 +158,9 @@ def test_group_selections(self): def selectedx(): return self.get_output(self.widget.Outputs.selected_data).X + def selected_groups(): + return self.get_output(self.widget.Outputs.selected_data).metas[:, 0] + def annotated(): return self.get_output(self.widget.Outputs.annotated_data).metas @@ -163,6 +170,7 @@ def annotations(): # Select 0:5 graph.select(points[:5]) np.testing.assert_equal(selectedx(), x[:5]) + np.testing.assert_equal(selected_groups(), np.zeros(5)) sel_column[:5] = 1 np.testing.assert_equal(annotated(), sel_column) self.assertEqual(annotations(), ["No", "Yes"]) @@ -171,6 +179,7 @@ def annotations(): with self.modifiers(Qt.ShiftModifier): graph.select(points[5:10]) np.testing.assert_equal(selectedx(), x[:10]) + np.testing.assert_equal(selected_groups(), np.array([0] * 5 + [1] * 5)) sel_column[5:10] = 2 np.testing.assert_equal(annotated(), sel_column) self.assertEqual(len(annotations()), 3) @@ -180,12 +189,14 @@ def annotations(): sel_column = np.zeros((len(self.data), 1)) sel_column[15:20] = 1 np.testing.assert_equal(selectedx(), x[15:20]) + np.testing.assert_equal(selected_groups(), np.zeros(5)) self.assertEqual(annotations(), ["No", "Yes"]) # Alt-select (remove) 10:17; we have 17:20 with self.modifiers(Qt.AltModifier): graph.select(points[10:17]) np.testing.assert_equal(selectedx(), x[17:20]) + np.testing.assert_equal(selected_groups(), np.zeros(3)) sel_column[15:17] = 0 np.testing.assert_equal(annotated(), sel_column) self.assertEqual(annotations(), ["No", "Yes"]) @@ -194,6 +205,7 @@ def annotations(): with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier): graph.select(points[20:25]) np.testing.assert_equal(selectedx(), x[17:25]) + np.testing.assert_equal(selected_groups(), np.zeros(8)) sel_column[20:25] = 1 np.testing.assert_equal(annotated(), sel_column) self.assertEqual(annotations(), ["No", "Yes"]) @@ -205,6 +217,7 @@ def annotations(): with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier): graph.select(points[35:40]) sel_column[30:40] = 2 + np.testing.assert_equal(selected_groups(), np.array([0] * 8 + [1] * 10)) np.testing.assert_equal(annotated(), sel_column) self.assertEqual(len(annotations()), 3)