diff --git a/Orange/widgets/data/owselectrows.py b/Orange/widgets/data/owselectrows.py index 4022330dee5..a07876d6461 100644 --- a/Orange/widgets/data/owselectrows.py +++ b/Orange/widgets/data/owselectrows.py @@ -1,6 +1,5 @@ import enum from collections import OrderedDict -from itertools import chain import numpy as np @@ -13,16 +12,16 @@ QFontMetrics, QPalette ) from AnyQt.QtCore import Qt, QPoint, QRegExp, QPersistentModelIndex, QLocale + +from Orange.widgets.utils.itemmodels import DomainModel from orangewidget.utils.combobox import ComboBoxSearch from Orange.data import ( - Variable, ContinuousVariable, DiscreteVariable, StringVariable, - TimeVariable, + ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable, Table ) import Orange.data.filter as data_filter from Orange.data.filter import FilterContinuous, FilterString -from Orange.data.domain import filter_visible from Orange.data.sql.table import SqlTable from Orange.preprocess import Remove from Orange.widgets import widget, gui @@ -52,24 +51,47 @@ def encode_setting(self, context, setting, value): encoded = [] CONTINUOUS = vartype(ContinuousVariable("x")) for attr, op, values in value: - vtype = context.attributes.get(attr) - if vtype == CONTINUOUS and values and isinstance(values[0], str): - values = [QLocale().toDouble(v)[0] for v in values] - encoded.append((attr, vtype, op, values)) + if isinstance(attr, str): + if OWSelectRows.AllTypes.get(attr) == CONTINUOUS: + values = [QLocale().toDouble(v)[0] for v in values] + # None will match the value returned by all_vars.get + encoded.append((attr, None, op, values)) + else: + if type(attr) is ContinuousVariable \ + and values and isinstance(values[0], str): + values = [QLocale().toDouble(v)[0] for v in values] + elif isinstance(attr, DiscreteVariable): + values = [attr.values[i - 1] if i else "" for i in values] + encoded.append( + (attr.name, context.attributes.get(attr.name), op, values)) return encoded def decode_setting(self, setting, value, domain=None): value = super().decode_setting(setting, value, domain) if setting.name == 'conditions': + CONTINUOUS = vartype(ContinuousVariable("x")) # Use this after 2022/2/2: - # for i, (attr, _, op, values) in enumerate(value): - for i, condition in enumerate(value): - attr = condition[0] - op, values = condition[-2:] - - var = attr in domain and domain[attr] - if var and var.is_continuous and not isinstance(var, TimeVariable): + # for i, (attr, tpe, op, values) in enumerate(value): + # if tpe is not None: + for i, (attr, *tpe, op, values) in enumerate(value): + if tpe != (None, ) \ + or not tpe and attr not in OWSelectRows.AllTypes: + attr = domain[attr] + if type(attr) is ContinuousVariable \ + or OWSelectRows.AllTypes.get(attr) == CONTINUOUS: values = [QLocale().toString(float(i), 'f') for i in values] + elif isinstance(attr, DiscreteVariable): + # After 2022/2/2, use just the expression in else clause + if values and isinstance(values[0], int): + # Backwards compatibility. Reset setting if we detect + # that the number of values decreased. Still broken if + # they're reordered or we don't detect the decrease. + if max(values) > len(attr.values): + values = [0] + else: + values = [attr.to_val(val) + 1 if val else 0 + for val in values if val in attr.values] \ + or [0] value[i] = (attr, op, values) return value @@ -80,7 +102,7 @@ def match(self, context, domain, attrs, metas): conditions = context.values["conditions"] all_vars = attrs.copy() all_vars.update(metas) - matched = [all_vars.get(name) == tpe + matched = [all_vars.get(name) == tpe # also matches "all (...)" strings # After 2022/2/2 remove this line: if len(rest) == 2 else name in all_vars for name, tpe, *rest in conditions] @@ -101,6 +123,8 @@ def filter_value(self, setting, data, domain, attrs, metas): # if all_vars.get(name) == tpe] conditions[:] = [ (name, tpe, *rest) for name, tpe, *rest in conditions + # all_vars.get(name) == tpe also matches "all (...)" which are + # encoded with type `None` if (all_vars.get(name) == tpe if len(rest) == 2 else name in all_vars)] @@ -209,6 +233,9 @@ def __init__(self): self.last_output_conditions = None self.data = None self.data_desc = self.match_desc = self.nonmatch_desc = None + self.variable_model = DomainModel( + [list(self.AllTypes), DomainModel.Separator, + DomainModel.CLASSES, DomainModel.ATTRIBUTES, DomainModel.METAS]) box = gui.vBox(self.controlArea, 'Conditions', stretch=100) self.cond_list = QTableWidget( @@ -268,18 +295,11 @@ def add_row(self, attr=None, condition_type=None, condition_value=None): attr_combo = ComboBoxSearch( minimumContentsLength=12, sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon) + attr_combo.setModel(self.variable_model) attr_combo.row = row - for var in self._visible_variables(self.data.domain): - if isinstance(var, Variable): - attr_combo.addItem(*gui.attributeItem(var)) - else: - attr_combo.addItem(var) - if isinstance(attr, str): - attr_combo.setCurrentText(attr) - else: - attr_combo.setCurrentIndex( - attr or - len(self.AllTypes) - (attr_combo.count() == len(self.AllTypes))) + attr_combo.setCurrentIndex(self.variable_model.indexOf(attr) if attr + else len(self.AllTypes) + 1) + self.cond_list.setCellWidget(row, 0, attr_combo) index = QPersistentModelIndex(model.index(row, 3)) @@ -297,15 +317,6 @@ def add_row(self, attr=None, condition_type=None, condition_value=None): self.cond_list.resizeRowToContents(row) - @classmethod - def _visible_variables(cls, domain): - """Generate variables in order they should be presented in in combos.""" - return chain( - cls.AllTypes, - filter_visible(chain(domain.class_vars, - domain.metas, - domain.attributes))) - def add_all(self): if self.cond_list.rowCount(): Mb = QMessageBox @@ -315,9 +326,8 @@ def add_all(self): "filters for all variables.", Mb.Ok | Mb.Cancel) != Mb.Ok: return self.remove_all() - domain = self.data.domain - for i in range(len(domain.variables) + len(domain.metas)): - self.add_row(i) + for attr in self.variable_model[len(self.AllTypes) + 1:]: + self.add_row(attr) def remove_one(self, rownum): self.remove_one_row(rownum) @@ -333,6 +343,12 @@ def remove_one_row(self, rownum): self.remove_all_button.setDisabled(True) def remove_all_rows(self): + # Disconnect signals to avoid stray emits when changing variable_model + for row in range(self.cond_list.rowCount()): + for col in (0, 1): + widget = self.cond_list.cellWidget(row, col) + if widget: + widget.currentIndexChanged.disconnect() self.cond_list.clear() self.cond_list.setRowCount(0) self.remove_all_button.setDisabled(True) @@ -495,24 +511,18 @@ def set_data(self, data): if not data: self.info.set_input_summary(self.info.NoInput) self.data_desc = None + self.variable_model.set_domain(None) self.commit() return self.data_desc = report.describe_data_brief(data) - self.conditions = [] - try: - self.openContext(data) - except Exception: - pass + self.variable_model.set_domain(data.domain) - variables = list(self._visible_variables(self.data.domain)) - varnames = [v.name if isinstance(v, Variable) else v for v in variables] - if self.conditions: - for attr, cond_type, cond_value in self.conditions: - if attr in varnames: - self.add_row(varnames.index(attr), cond_type, cond_value) - elif attr in self.AllTypes: - self.add_row(attr, cond_type, cond_value) - else: + self.conditions = [] + self.openContext(data) + for attr, cond_type, cond_value in self.conditions: + if attr in self.variable_model: + self.add_row(attr, cond_type, cond_value) + if not self.cond_list.model().rowCount(): self.add_row() self.info.set_input_summary(data.approx_len(), @@ -521,12 +531,15 @@ def set_data(self, data): def conditions_changed(self): try: - self.conditions = [] + cells_by_rows = ( + [self.cond_list.cellWidget(row, col) for col in range(3)] + for row in range(self.cond_list.rowCount()) + ) self.conditions = [ - (self.cond_list.cellWidget(row, 0).currentText(), - self.cond_list.cellWidget(row, 1).currentIndex(), - self._get_value_contents(self.cond_list.cellWidget(row, 2))) - for row in range(self.cond_list.rowCount())] + (var_cell.currentData(gui.TableVariable) or var_cell.currentText(), + oper_cell.currentIndex(), + self._get_value_contents(val_cell)) + for var_cell, oper_cell, val_cell in cells_by_rows] if self.update_on_change and ( self.last_output_conditions is None or self.last_output_conditions != self.conditions): @@ -674,19 +687,18 @@ def send_report(self): pdesc = ndesc conditions = [] - domain = self.data.domain - for attr_name, oper, values in self.conditions: - if attr_name in self.AllTypes: - attr = attr_name + for attr, oper, values in self.conditions: + if isinstance(attr, str): + attr_name = attr + var_type = self.AllTypes[attr] names = self.operator_names[attr_name] - var_type = self.AllTypes[attr_name] else: - attr = domain[attr_name] + attr_name = attr.name var_type = vartype(attr) names = self.operator_names[type(attr)] name = names[oper] if oper == len(names) - 1: - conditions.append("{} {}".format(attr, name)) + conditions.append("{} {}".format(attr_name, name)) elif var_type == 1: # discrete if name == "is one of": valnames = [attr.values[v - 1] for v in values] diff --git a/Orange/widgets/data/tests/test_owselectrows.py b/Orange/widgets/data/tests/test_owselectrows.py index df0d6b4d9dd..f190eb2acef 100644 --- a/Orange/widgets/data/tests/test_owselectrows.py +++ b/Orange/widgets/data/tests/test_owselectrows.py @@ -70,7 +70,7 @@ def test_filter_cont(self): for i, (op, _) in enumerate(OWSelectRows.Operators[ContinuousVariable]): self.widget.remove_all() - self.widget.add_row(1, i, CFValues[op]) + self.widget.add_row(iris.domain[0], i, CFValues[op]) self.widget.conditions_changed() self.widget.unconditional_commit() @@ -80,7 +80,7 @@ def test_filter_str(self): self.widget.set_data(zoo) for i, (op, _) in enumerate(OWSelectRows.Operators[StringVariable]): self.widget.remove_all() - self.widget.add_row(1, i, SFValues[op]) + self.widget.add_row(zoo.domain.metas[0], i, SFValues[op]) self.widget.conditions_changed() self.widget.unconditional_commit() @@ -183,11 +183,11 @@ def test_partial_matches(self): iris = Table("iris") domain = iris.domain self.widget = self.widget_with_context( - domain, [[domain[0].name, 2, ("5.2",)]]) + domain, [[domain[0].name, 2, 2, ("5.2",)]]) iris2 = iris.transform(Domain(domain.attributes[:2], None)) self.send_signal(self.widget.Inputs.data, iris2) condition = self.widget.conditions[0] - self.assertEqual(condition[0], "sepal length") + self.assertEqual(condition[0], iris.domain[0]) self.assertEqual(condition[1], 2) self.assertTrue(condition[2][0].startswith("5.2")) @@ -196,12 +196,12 @@ def test_partial_matches_with_missing_vars(self): iris = Table("iris") domain = iris.domain self.widget = self.widget_with_context( - domain, [[domain[0].name, 2, ("5.2",)], - [domain[2].name, 2, ("4.2",)]]) + domain, [[domain[0].name, 2, 2, ("5.2",)], + [domain[2].name, 2, 2, ("4.2",)]]) iris2 = iris.transform(Domain(domain.attributes[2:], None)) self.send_signal(self.widget.Inputs.data, iris2) condition = self.widget.conditions[0] - self.assertEqual(condition[0], domain[2].name) + self.assertEqual(condition[0], domain[2]) self.assertEqual(condition[1], 2) self.assertTrue(condition[2][0].startswith("4.2")) @@ -354,7 +354,7 @@ def test_support_old_settings(self): iris.domain, [["sepal length", 2, ("5.2",)]]) self.send_signal(self.widget.Inputs.data, iris) condition = self.widget.conditions[0] - self.assertEqual(condition[0], "sepal length") + self.assertEqual(condition[0], iris.domain["sepal length"]) self.assertEqual(condition[1], 2) self.assertTrue(condition[2][0].startswith("5.2")) @@ -380,7 +380,7 @@ def test_purge_discretized(self): discretize_class=True, method=method) domain = discretizer(housing) data = housing.transform(domain) - widget = self.widget_with_context(domain, [["MEDV", 2, (2, 3)]]) + widget = self.widget_with_context(domain, [["MEDV", 101, 2, (2, 3)]]) widget.purge_classes = True self.send_signal(widget.Inputs.data, data) out = self.get_output(widget.Outputs.matching_data)