From 962f4c28b23c5b4e97c8633afc17fbfc8da93358 Mon Sep 17 00:00:00 2001 From: janezd Date: Fri, 13 Oct 2017 16:03:25 +0200 Subject: [PATCH 1/3] annotated_data.get_next_name: Accept domain as an argument --- Orange/widgets/utils/annotated_data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Orange/widgets/utils/annotated_data.py b/Orange/widgets/utils/annotated_data.py index 6057973c4ed..9416c50e954 100644 --- a/Orange/widgets/utils/annotated_data.py +++ b/Orange/widgets/utils/annotated_data.py @@ -1,4 +1,5 @@ import re +from itertools import chain import numpy as np from Orange.data import Domain, DiscreteVariable @@ -52,6 +53,11 @@ def get_next_name(names, name): :param name: str :return: str """ + if isinstance(names, Domain): + names = [ + var.name + for var in chain(names.attributes, names.class_vars, names.metas) + ] indexes = get_indices(names, name) if name not in names and not indexes: return name From e5f5128b39e664cc88712e249b9a2202fd997b2d Mon Sep 17 00:00:00 2001 From: janezd Date: Fri, 13 Oct 2017 16:04:01 +0200 Subject: [PATCH 2/3] OWScatterPlot: Add column with groups to selected data output --- Orange/widgets/utils/annotated_data.py | 25 +++++++------ Orange/widgets/visualize/owscatterplot.py | 36 +++++++++---------- .../visualize/tests/test_owscatterplot.py | 15 ++++++++ 3 files changed, 47 insertions(+), 29 deletions(-) diff --git a/Orange/widgets/utils/annotated_data.py b/Orange/widgets/utils/annotated_data.py index 9416c50e954..c3e5adf2e4b 100644 --- a/Orange/widgets/utils/annotated_data.py +++ b/Orange/widgets/utils/annotated_data.py @@ -103,19 +103,22 @@ def create_annotated_table(data, selected_indices): return table -def create_groups_table(data, selection): +def create_groups_table(data, selection, + include_unselected=True, + var_name=ANNOTATED_DATA_FEATURE_NAME): if data is None: return None - names = [var.name for var in data.domain.variables + data.domain.metas] - name = get_next_name(names, ANNOTATED_DATA_FEATURE_NAME) - metas = data.domain.metas + ( - DiscreteVariable( - name, - ["Unselected"] + ["G{}".format(i + 1) - for i in range(np.max(selection))]), - ) + values = ["G{}".format(i + 1) for i in range(np.max(selection))] + if include_unselected: + values.insert(0, "Unselected") + else: + mask = np.flatnonzero(selection) + data = data[mask] + selection = selection[mask] - 1 + + var_name = get_next_name(data.domain, var_name) + metas = data.domain.metas + (DiscreteVariable(var_name, values), ) domain = Domain(data.domain.attributes, data.domain.class_vars, metas) table = data.transform(domain) - table.metas[:, len(data.domain.metas):] = \ - selection.reshape(len(data), 1) + table.metas[:, len(data.domain.metas):] = selection.reshape(len(data), 1) return table diff --git a/Orange/widgets/visualize/owscatterplot.py b/Orange/widgets/visualize/owscatterplot.py index 0a29ee556fd..15150fa70fd 100644 --- a/Orange/widgets/visualize/owscatterplot.py +++ b/Orange/widgets/visualize/owscatterplot.py @@ -21,9 +21,8 @@ from Orange.widgets.visualize.owscatterplotgraph import OWScatterPlotGraph from Orange.widgets.visualize.utils import VizRankDialogAttrPair from Orange.widgets.widget import OWWidget, AttributeList, Msg, Input, Output -from Orange.widgets.utils.annotated_data import (create_annotated_table, - ANNOTATED_DATA_SIGNAL_NAME, - create_groups_table) +from Orange.widgets.utils.annotated_data import ( + create_annotated_table, create_groups_table, ANNOTATED_DATA_SIGNAL_NAME) class ScatterPlotVizRank(VizRankDialogAttrPair): @@ -428,25 +427,26 @@ def selection_changed(self): self.commit() def send_data(self): - selected = None - selection = None # TODO: Implement selection for sql data + def _get_selected(): + if not len(selection): + return None + return create_groups_table(data, graph.selection, False, "Group") + + def _get_annotated(): + if graph.selection is not None and np.max(graph.selection) > 1: + return create_groups_table(data, graph.selection) + else: + return create_annotated_table(data, selection) + graph = self.graph - if isinstance(self.data, SqlTable): - selected = self.data - elif self.data is not None: - selection = graph.get_selection() - if len(selection) > 0: - selected = self.data[selection] - if graph.selection is not None and np.max(graph.selection) > 1: - annotated = create_groups_table(self.data, graph.selection) - else: - annotated = create_annotated_table(self.data, selection) - self.Outputs.selected_data.send(selected) - self.Outputs.annotated_data.send(annotated) + data = self.data + selection = graph.get_selection() + self.Outputs.annotated_data.send(_get_annotated()) + self.Outputs.selected_data.send(_get_selected()) # Store current selection in a setting that is stored in workflow - if selection is not None and len(selection): + if len(selection): self.selection_group = list(zip(selection, graph.selection[selection])) else: self.selection_group = None diff --git a/Orange/widgets/visualize/tests/test_owscatterplot.py b/Orange/widgets/visualize/tests/test_owscatterplot.py index f00cf2c6e35..3d737eeed13 100644 --- a/Orange/widgets/visualize/tests/test_owscatterplot.py +++ b/Orange/widgets/visualize/tests/test_owscatterplot.py @@ -18,6 +18,7 @@ class TestOWScatterPlot(WidgetTest, WidgetOutputsTestMixin): def setUpClass(cls): super().setUpClass() WidgetOutputsTestMixin.init(cls) + cls.same_input_output_domain = False cls.signal_name = "Data" cls.signal_data = cls.data @@ -25,6 +26,11 @@ def setUpClass(cls): def setUp(self): self.widget = self.create_widget(OWScatterPlot) + def _compare_selected_annotated_domains(self, selected, annotated): + # Base class tests that selected.domain is a subset of annotated.domain + # In scatter plot, the two domains are unrelated, so we disable the test + pass + def test_set_data(self): # Connect iris to scatter plot self.send_signal(self.widget.Inputs.data, self.data) @@ -154,6 +160,9 @@ def test_group_selections(self): def selectedx(): return self.get_output(self.widget.Outputs.selected_data).X + def selected_groups(): + return self.get_output(self.widget.Outputs.selected_data).metas[:, 0] + def annotated(): return self.get_output(self.widget.Outputs.annotated_data).metas @@ -163,6 +172,7 @@ def annotations(): # Select 0:5 graph.select(points[:5]) np.testing.assert_equal(selectedx(), x[:5]) + np.testing.assert_equal(selected_groups(), np.zeros(5)) sel_column[:5] = 1 np.testing.assert_equal(annotated(), sel_column) self.assertEqual(annotations(), ["No", "Yes"]) @@ -171,6 +181,7 @@ def annotations(): with self.modifiers(Qt.ShiftModifier): graph.select(points[5:10]) np.testing.assert_equal(selectedx(), x[:10]) + np.testing.assert_equal(selected_groups(), np.array([0] * 5 + [1] * 5)) sel_column[5:10] = 2 np.testing.assert_equal(annotated(), sel_column) self.assertEqual(len(annotations()), 3) @@ -180,12 +191,14 @@ def annotations(): sel_column = np.zeros((len(self.data), 1)) sel_column[15:20] = 1 np.testing.assert_equal(selectedx(), x[15:20]) + np.testing.assert_equal(selected_groups(), np.zeros(5)) self.assertEqual(annotations(), ["No", "Yes"]) # Alt-select (remove) 10:17; we have 17:20 with self.modifiers(Qt.AltModifier): graph.select(points[10:17]) np.testing.assert_equal(selectedx(), x[17:20]) + np.testing.assert_equal(selected_groups(), np.zeros(3)) sel_column[15:17] = 0 np.testing.assert_equal(annotated(), sel_column) self.assertEqual(annotations(), ["No", "Yes"]) @@ -194,6 +207,7 @@ def annotations(): with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier): graph.select(points[20:25]) np.testing.assert_equal(selectedx(), x[17:25]) + np.testing.assert_equal(selected_groups(), np.zeros(8)) sel_column[20:25] = 1 np.testing.assert_equal(annotated(), sel_column) self.assertEqual(annotations(), ["No", "Yes"]) @@ -205,6 +219,7 @@ def annotations(): with self.modifiers(Qt.ShiftModifier | Qt.ControlModifier): graph.select(points[35:40]) sel_column[30:40] = 2 + np.testing.assert_equal(selected_groups(), np.array([0] * 8 + [1] * 10)) np.testing.assert_equal(annotated(), sel_column) self.assertEqual(len(annotations()), 3) From b3c0a0dc72318abc3d8eb9a545ddb083f4a3e138 Mon Sep 17 00:00:00 2001 From: janezd Date: Mon, 16 Oct 2017 22:07:03 +0200 Subject: [PATCH 3/3] create_(groups|annotated)_table: Add new column as class if the data has none --- Orange/widgets/utils/annotated_data.py | 29 ++++++++++++++------------ 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/Orange/widgets/utils/annotated_data.py b/Orange/widgets/utils/annotated_data.py index c3e5adf2e4b..c1a4f5865da 100644 --- a/Orange/widgets/utils/annotated_data.py +++ b/Orange/widgets/utils/annotated_data.py @@ -80,6 +80,19 @@ def get_unique_names(names, proposed): return proposed +def _table_with_annotation_column(data, values, column_data, var_name): + var = DiscreteVariable(get_next_name(data.domain, var_name), values) + class_vars, metas = data.domain.class_vars, data.domain.metas + if not data.domain.class_vars: + class_vars += (var, ) + else: + metas += (var, ) + domain = Domain(data.domain.attributes, class_vars, metas) + table = data.transform(domain) + table[:, var] = column_data.reshape((len(data), 1)) + return table + + def create_annotated_table(data, selected_indices): """ Returns data with concatenated flag column. Flag column represents @@ -92,15 +105,11 @@ def create_annotated_table(data, selected_indices): """ if data is None: return None - names = [var.name for var in data.domain.variables + data.domain.metas] - name = get_next_name(names, ANNOTATED_DATA_FEATURE_NAME) - domain = add_columns(data.domain, metas=[DiscreteVariable(name, ("No", "Yes"))]) annotated = np.zeros((len(data), 1)) if selected_indices is not None: annotated[selected_indices] = 1 - table = data.transform(domain) - table[:, name] = annotated - return table + return _table_with_annotation_column( + data, ("No", "Yes"), annotated, ANNOTATED_DATA_FEATURE_NAME) def create_groups_table(data, selection, @@ -115,10 +124,4 @@ def create_groups_table(data, selection, mask = np.flatnonzero(selection) data = data[mask] selection = selection[mask] - 1 - - var_name = get_next_name(data.domain, var_name) - metas = data.domain.metas + (DiscreteVariable(var_name, values), ) - domain = Domain(data.domain.attributes, data.domain.class_vars, metas) - table = data.transform(domain) - table.metas[:, len(data.domain.metas):] = selection.reshape(len(data), 1) - return table + return _table_with_annotation_column(data, values, selection, var_name)