From f30e2439e6eaa94108daca3148132ec2a01e4330 Mon Sep 17 00:00:00 2001 From: janezd Date: Sun, 10 Jul 2022 16:04:39 +0200 Subject: [PATCH] Aggregate Columns: Add additional options for selection --- Orange/widgets/data/owaggregatecolumns.py | 203 ++++++++++++++---- .../data/tests/test_owaggregatecolumns.py | 149 ++++++++++++- 2 files changed, 303 insertions(+), 49 deletions(-) diff --git a/Orange/widgets/data/owaggregatecolumns.py b/Orange/widgets/data/owaggregatecolumns.py index d709977ee52..7e7979170f7 100644 --- a/Orange/widgets/data/owaggregatecolumns.py +++ b/Orange/widgets/data/owaggregatecolumns.py @@ -1,20 +1,30 @@ -from typing import List +from itertools import chain +from typing import List, NamedTuple, Callable import numpy as np -from AnyQt.QtWidgets import QSizePolicy +from AnyQt.QtWidgets import QSizePolicy, QStyle, \ + QButtonGroup, QRadioButton, QComboBox from AnyQt.QtCore import Qt + from Orange.data import Variable, Table, ContinuousVariable, TimeVariable from Orange.data.util import get_unique_names from Orange.widgets import gui, widget from Orange.widgets.settings import ( ContextSetting, Setting, DomainContextHandler ) +from Orange.widgets.utils.signals import AttributeList from Orange.widgets.utils.widgetpreview import WidgetPreview from Orange.widgets.widget import Input, Output from Orange.widgets.utils.itemmodels import DomainModel +class OpDesc(NamedTuple): + name: str + func: Callable[[np.ndarray], np.ndarray] + time_preserving: bool = False + + class OWAggregateColumns(widget.OWWidget): name = "Aggregate Columns" description = "Compute a sum, max, min ... of selected columns." @@ -26,53 +36,84 @@ class OWAggregateColumns(widget.OWWidget): class Inputs: data = Input("Data", Table, default=True) + features = Input("Features", AttributeList) class Outputs: data = Output("Data", Table) + class Warning(widget.OWWidget.Warning): + discrete_features = widget.Msg("Some input features are categorical:\n{}") + missing_features = widget.Msg("Some input features are missing:\n{}") + want_main_area = False + Operations = {"Sum": OpDesc("Sum", np.nansum), + "Product": OpDesc("Product", np.nanprod), + "Min": OpDesc("Minimal value", np.nanmin, True), + "Max": OpDesc("Maximal value", np.nanmax, True), + "Mean": OpDesc("Mean value", np.nanmean, True), + "Variance": OpDesc("Variance", np.nanvar), + "Median": OpDesc("Median", np.nanmedian, True)} + KeyFromDesc = {op.name: key for key, op in Operations.items()} + + SelectAll, SelectAllAndMeta, InputFeatures, SelectManually = range(4) + settingsHandler = DomainContextHandler() variables: List[Variable] = ContextSetting([]) - operation = Setting("Sum") - var_name = Setting("agg") + selection_method: int = Setting(SelectManually, schema_only=True) + operation = ContextSetting("Sum") + var_name = Setting("agg", schema_only=True) auto_apply = Setting(True) - Operations = {"Sum": np.nansum, "Product": np.nanprod, - "Min": np.nanmin, "Max": np.nanmax, - "Mean": np.nanmean, "Variance": np.nanvar, - "Median": np.nanmedian} - TimePreserving = ("Min", "Max", "Mean", "Median") - def __init__(self): super().__init__() self.data = None + self.features = None - box = gui.vBox(self.controlArea, box=True) + self.selection_box = gui.vBox(self.controlArea, "Variable selection") + self.selection_group = QButtonGroup(self.selection_box) + for i, label in enumerate(("All", + "All, including meta attributes", + "Features from separate input signal", + "Selected variables")): + button = QRadioButton(label) + if i == self.selection_method: + button.setChecked(True) + self.selection_group.addButton(button, id=i) + self.selection_box.layout().addWidget(button) + self.selection_group.idClicked.connect(self._on_sel_method_changed) self.variable_model = DomainModel( - order=DomainModel.MIXED, valid_types=(ContinuousVariable, )) + order=(DomainModel.ATTRIBUTES, DomainModel.METAS), + valid_types=ContinuousVariable) + pixm: QStyle = self.style().pixelMetric + ind_width = pixm(QStyle.PM_ExclusiveIndicatorWidth) + \ + pixm(QStyle.PM_RadioButtonLabelSpacing) var_list = gui.listView( - box, self, "variables", model=self.variable_model, + gui.indentedBox(self.selection_box, ind_width), self, "variables", + model=self.variable_model, callback=self.commit.deferred ) var_list.setSelectionMode(var_list.ExtendedSelection) - combo = gui.comboBox( - box, self, "operation", - label="Operator: ", orientation=Qt.Horizontal, - items=list(self.Operations), sendSelectedValue=True, - callback=self.commit.deferred - ) - combo.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed) + box = gui.vBox(self.controlArea, box="Operation") + combo = self.operation_combo = QComboBox() + combo.addItems([op.name for op in self.Operations.values()]) + combo.textActivated[str].connect(self._on_operation_changed) + combo.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Fixed) + combo.setCurrentText(self.Operations[self.operation].name) + box.layout().addWidget(combo) gui.lineEdit( box, self, "var_name", - label="Variable name: ", orientation=Qt.Horizontal, + label="Output variable name: ", orientation=Qt.Horizontal, callback=self.commit.deferred ) - gui.auto_apply(self.controlArea, self) + gui.auto_apply(self.buttonsArea, self) + + self._update_selection_buttons() + @Inputs.data def set_data(self, data: Table = None): @@ -82,56 +123,138 @@ def set_data(self, data: Table = None): if self.data: self.variable_model.set_domain(data.domain) self.openContext(data) + self.operation_combo.setCurrentText(self.Operations[self.operation].name) else: self.variable_model.set_domain(None) + + @Inputs.features + def set_features(self, features): + if features is None: + self.features = None + missing = [] + else: + self.features = [attr for attr in features if attr.is_continuous] + missing = self._missing(features, self.features) + self.Warning.discrete_features(missing, shown=bool(missing)) + + def _update_selection_buttons(self): + if self.features: + for i, button in enumerate(self.selection_group.buttons()): + button.setChecked(i == self.InputFeatures) + button.setEnabled(i == self.InputFeatures) + self.controls.variables.setEnabled(False) + else: + for i, button in enumerate(self.selection_group.buttons()): + button.setChecked(i == self.selection_method) + button.setEnabled(i != self.InputFeatures) + self.controls.variables.setEnabled( + self.selection_method == self.SelectManually) + + def handleNewSignals(self): + self._update_selection_buttons() self.commit.now() + def _on_sel_method_changed(self, i): + self.selection_method = i + self._update_selection_buttons() + self.commit.deferred() + + def _on_operation_changed(self, oper): + self.operation = self.KeyFromDesc[oper] + self.commit.deferred() + @gui.deferred def commit(self): augmented = self._compute_data() self.Outputs.data.send(augmented) def _compute_data(self): - if not self.data or not self.variables: + self.Warning.missing_features.clear() + if not self.data: + return self.data + + variables = self._variables() + if not self.data or not variables: return self.data - new_col = self._compute_column() - new_var = self._new_var() + new_col = self._compute_column(variables) + new_var = self._new_var(variables) return self.data.add_column(new_var, new_col) - def _compute_column(self): - arr = np.empty((len(self.data), len(self.variables))) - for i, var in enumerate(self.variables): + def _variables(self): + self.Warning.missing_features.clear() + if self.features: + selected = [attr for attr in self.features + if attr in self.data.domain] + missing = self._missing(self.features, selected) + self.Warning.missing_features(missing, shown=bool(missing)) + return selected + + assert self.data + + domain = self.data.domain + if self.selection_method == self.SelectAll: + return [attr for attr in domain.attributes + if attr.is_continuous] + if self.selection_method == self.SelectAllAndMeta: + # skip separators + return [attr for attr in chain(domain.attributes, domain.metas) + if attr.is_continuous] + + assert self.selection_method == self.SelectManually + return self.variables + + def _compute_column(self, variables): + arr = np.empty((len(self.data), len(variables))) + for i, var in enumerate(variables): arr[:, i] = self.data.get_column_view(var)[0].astype(float) - func = self.Operations[self.operation] + func = self.Operations[self.operation].func return func(arr, axis=1) def _new_var_name(self): return get_unique_names(self.data.domain, self.var_name) - def _new_var(self): + def _new_var(self, variables): name = self._new_var_name() - if self.operation in self.TimePreserving \ - and all(isinstance(var, TimeVariable) for var in self.variables): + if self.Operations[self.operation].time_preserving \ + and all(isinstance(var, TimeVariable) for var in variables): return TimeVariable(name) return ContinuousVariable(name) def send_report(self): - # fp for self.variables, pylint: disable=unsubscriptable-object - if not self.data or not self.variables: + if not self.data: return - var_list = ", ".join(f"'{var.name}'" - for var in self.variables[:31][:-1]) - if len(self.variables) > 30: - var_list += f" and {len(self.variables) - 30} others" - else: - var_list += f" and '{self.variables[-1].name}'" + variables = self._variables() + if not variables: + return + var_list = self._and_others(variables, 30) self.report_items(( ("Output:", f"'{self._new_var_name()}' as {self.operation.lower()} of {var_list}" ), )) + @staticmethod + def _and_others(variables, limit): + if len(variables) == 1: + return f"'{variables[0].name}'" + var_list = ", ".join(f"'{var.name}'" + for var in variables[:limit + 1][:-1]) + if len(variables) > limit: + var_list += f" and {len(variables) - limit} more" + else: + var_list += f" and '{variables[-1].name}'" + return var_list + + @classmethod + def _missing(cls, given, used): + if len(given) == len(used): + return "" + used = set(used) + # Don't use set difference because it loses order + missing = [attr for attr in given if attr not in used] + return cls._and_others(missing, 5) + if __name__ == "__main__": # pragma: no cover brown = Table("brown-selected") diff --git a/Orange/widgets/data/tests/test_owaggregatecolumns.py b/Orange/widgets/data/tests/test_owaggregatecolumns.py index f0f89eecabb..58ca793c737 100644 --- a/Orange/widgets/data/tests/test_owaggregatecolumns.py +++ b/Orange/widgets/data/tests/test_owaggregatecolumns.py @@ -13,6 +13,7 @@ ) from Orange.widgets.data.owaggregatecolumns import OWAggregateColumns from Orange.widgets.tests.base import WidgetTest +from Orange.widgets.utils.signals import AttributeList class TestOWAggregateColumn(WidgetTest): @@ -69,7 +70,7 @@ def test_compute_data(self): def test_var_name(self): domain = self.data1.domain self.send_signal(self.widget.Inputs.data, self.data1) - self.widget.variables = self.widget.variable_model[:] + self.widget.selection_method = self.widget.SelectAllAndMeta self.widget.var_name = "test" output = self.widget._compute_data() @@ -84,15 +85,16 @@ def test_var_name(self): def test_var_types(self): domain = self.data1.domain self.send_signal(self.widget.Inputs.data, self.data1) + variables = [domain[n] for n in "t1 c2 t2".split()] - self.widget.variables = [domain[n] for n in "t1 c2 t2".split()] for self.widget.operation in self.widget.Operations: - self.assertIsInstance(self.widget._new_var(), ContinuousVariable) + self.assertIsInstance(self.widget._new_var(variables), + ContinuousVariable) - self.widget.variables = [domain[n] for n in "t1 t2".split()] + variables = [domain[n] for n in "t1 t2".split()] for self.widget.operation in self.widget.Operations: self.assertIsInstance( - self.widget._new_var(), + self.widget._new_var(variables), TimeVariable if self.widget.operation in ("Min", "Max", "Mean", "Median") else ContinuousVariable) @@ -100,7 +102,7 @@ def test_var_types(self): def test_operations(self): domain = self.data1.domain self.send_signal(self.widget.Inputs.data, self.data1) - self.widget.variables = [domain[n] for n in "c1 c2 t2".split()] + variables = [domain[n] for n in "c1 c2 t2".split()] m1, m2 = 4 / 3, 8 / 3 for self.widget.operation, expected in { @@ -111,7 +113,7 @@ def test_operations(self): ((m2 - 3) ** 2 + (m2 - 1) ** 2 + (m2 - 4) ** 2) / 3], "Median": [1, 3]}.items(): np.testing.assert_equal( - self.widget._compute_column(), expected, + self.widget._compute_column(variables), expected, err_msg=f"error in '{self.widget.operation}'") def test_operations_with_nan(self): @@ -119,7 +121,7 @@ def test_operations_with_nan(self): self.send_signal(self.widget.Inputs.data, self.data1) with self.data1.unlocked(): self.data1.X[1, 0] = np.nan - self.widget.variables = [domain[n] for n in "c1 c2 t2".split()] + variables = [domain[n] for n in "c1 c2 t2".split()] m1, m2 = 4 / 3, 5 / 2 for self.widget.operation, expected in { @@ -130,7 +132,7 @@ def test_operations_with_nan(self): ((m2 - 1) ** 2 + (m2 - 4) ** 2) / 2], "Median": [1, 2.5]}.items(): np.testing.assert_equal( - self.widget._compute_column(), expected, + self.widget._compute_column(variables), expected, err_msg=f"error in '{self.widget.operation}'") def test_contexts(self): @@ -158,6 +160,135 @@ def test_selection_in_context(self): self.assertSequenceEqual(self.widget.variables[:], self.data1.domain.variables[1:3]) + def test_features_signal(self): + widget = self.widget + widget.selection_method = widget.SelectAll + self.send_signal(widget.Inputs.data, self.data1) + + self.assertEqual([attr.name for attr in widget._variables()], + "c1 c2 t1".split()) + + attr_list = [self.data1.domain[attr] for attr in "c1 t2".split()] + self.send_signal(widget.Inputs.features, AttributeList(attr_list)) + self.assertEqual(widget._variables(), attr_list) + self.assertFalse(widget.Warning.missing_features.is_shown()) + self.assertFalse(widget.Warning.discrete_features.is_shown()) + np.testing.assert_equal( + self.get_output(widget.Outputs.data).get_column_view("agg")[0], + [3, 7]) + + attr_list = [self.data1.domain[attr] for attr in "c1 t2 d1".split()] + self.send_signal(widget.Inputs.features, AttributeList(attr_list)) + self.assertEqual(widget._variables(), attr_list[:2]) + self.assertFalse(widget.Warning.missing_features.is_shown()) + self.assertTrue(widget.Warning.discrete_features.is_shown()) + np.testing.assert_equal( + self.get_output(widget.Outputs.data).get_column_view("agg")[0], + [3, 7]) + + attr_list.append(ContinuousVariable("foo")) + self.send_signal(widget.Inputs.features, AttributeList(attr_list)) + self.assertEqual(widget._variables(), attr_list[:2]) + self.assertTrue(widget.Warning.missing_features.is_shown()) + self.assertTrue(widget.Warning.discrete_features.is_shown()) + np.testing.assert_equal( + self.get_output(widget.Outputs.data).get_column_view("agg")[0], + [3, 7]) + + self.send_signal(widget.Inputs.features, None) + self.assertFalse(widget.Warning.missing_features.is_shown()) + self.assertFalse(widget.Warning.discrete_features.is_shown()) + + del attr_list[2] # discrete variable + attr_list.append(ContinuousVariable("foo")) + self.send_signal(widget.Inputs.features, AttributeList(attr_list)) + self.assertEqual(widget._variables(), attr_list[:2]) + self.assertTrue(widget.Warning.missing_features.is_shown()) + self.assertFalse(widget.Warning.discrete_features.is_shown()) + np.testing.assert_equal( + self.get_output(widget.Outputs.data).get_column_view("agg")[0], + [3, 7]) + + self.assertEqual(widget.selection_group.checkedId(), + widget.InputFeatures) + self.assertTrue(all( + button.isEnabled() is (i == widget.InputFeatures) + for i, button in enumerate(widget.selection_group.buttons()))) + self.assertFalse(widget.controls.variables.isEnabled()) + + self.send_signal(widget.Inputs.features, None) + self.assertEqual([attr.name for attr in widget._variables()], + "c1 c2 t1".split()) + self.assertEqual(widget.selection_group.checkedId(), widget.SelectAll) + self.assertTrue(all( + button.isEnabled() is (i != widget.InputFeatures) + for i, button in enumerate(widget.selection_group.buttons()))) + self.assertFalse(widget.controls.variables.isEnabled()) + + def test_selection_radios(self): + widget = self.widget + self.send_signal(widget.Inputs.data, self.data1) + widget.variables = [self.data1.domain[attr] for attr in "c1 t2".split()] + + widget.selection_group.button(widget.SelectAll).click() + np.testing.assert_equal( + self.get_output(widget.Outputs.data).get_column_view("agg")[0], + [3, 46]) + + widget.selection_group.button(widget.SelectAllAndMeta).click() + np.testing.assert_equal( + self.get_output(widget.Outputs.data).get_column_view("agg")[0], + [6, 50]) + + widget.selection_group.button(widget.SelectManually).click() + np.testing.assert_equal( + self.get_output(widget.Outputs.data).get_column_view("agg")[0], + [3, 7]) + + def test_operation_changed(self): + widget = self.widget + self.send_signal(widget.Inputs.data, self.data1) + widget.selection_group.button(widget.SelectAll).click() + np.testing.assert_equal( + self.get_output(widget.Outputs.data).get_column_view("agg")[0], + [3, 46]) + + oper = widget.Operations["Max"].name + widget.operation_combo.setCurrentText(oper) + widget.operation_combo.textActivated[str].emit(oper) + np.testing.assert_equal( + self.get_output(widget.Outputs.data).get_column_view("agg")[0], + [2, 42]) + + def test_and_others(self): + self.assertEqual( + self.widget._and_others(self.data1.domain.variables[:1], 1), + "'c1'") + self.assertEqual( + self.widget._and_others(self.data1.domain.variables[:1], 10), + "'c1'") + self.assertEqual( + self.widget._and_others(self.data1.domain.variables, 20), + "'c1', 'c2', 'd1', 'd2', 't1' and 'd3'") + self.assertEqual( + self.widget._and_others(self.data1.domain.variables, 6), + "'c1', 'c2', 'd1', 'd2', 't1' and 'd3'") + self.assertEqual( + self.widget._and_others(self.data1.domain.variables, 5), + "'c1', 'c2', 'd1', 'd2', 't1' and 1 more") + self.assertEqual( + self.widget._and_others(self.data1.domain.variables, 2), + "'c1', 'c2' and 4 more") + + def test_missing(self): + attrs = self.data1.domain.attributes + self.assertEqual(self.widget._missing(attrs, attrs), "") + + self.assertEqual(self.widget._missing(attrs, attrs[1:]), + f"'{attrs[0].name}'") + self.assertEqual(self.widget._missing(attrs, attrs[2:]), + f"'{attrs[0].name}' and '{attrs[1].name}'") + def test_report(self): self.widget.send_report()