Skip to content

Commit

Permalink
Select Rows: Fix incorrectly stored values in settings
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed May 21, 2020
1 parent 56a106f commit 783013c
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 74 deletions.
142 changes: 77 additions & 65 deletions Orange/widgets/data/owselectrows.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import enum
from collections import OrderedDict
from itertools import chain

import numpy as np

Expand All @@ -13,16 +12,16 @@
QFontMetrics, QPalette
)
from AnyQt.QtCore import Qt, QPoint, QRegExp, QPersistentModelIndex, QLocale

from Orange.widgets.utils.itemmodels import DomainModel
from orangewidget.utils.combobox import ComboBoxSearch

from Orange.data import (
Variable, ContinuousVariable, DiscreteVariable, StringVariable,
TimeVariable,
ContinuousVariable, DiscreteVariable, StringVariable, TimeVariable,
Table
)
import Orange.data.filter as data_filter
from Orange.data.filter import FilterContinuous, FilterString
from Orange.data.domain import filter_visible
from Orange.data.sql.table import SqlTable
from Orange.preprocess import Remove
from Orange.widgets import widget, gui
Expand Down Expand Up @@ -52,24 +51,47 @@ def encode_setting(self, context, setting, value):
encoded = []
CONTINUOUS = vartype(ContinuousVariable("x"))
for attr, op, values in value:
vtype = context.attributes.get(attr)
if vtype == CONTINUOUS and values and isinstance(values[0], str):
values = [QLocale().toDouble(v)[0] for v in values]
encoded.append((attr, vtype, op, values))
if isinstance(attr, str):
if OWSelectRows.AllTypes.get(attr) == CONTINUOUS:
values = [QLocale().toDouble(v)[0] for v in values]
# None will match the value returned by all_vars.get
encoded.append((attr, None, op, values))
else:
if type(attr) is ContinuousVariable \
and values and isinstance(values[0], str):
values = [QLocale().toDouble(v)[0] for v in values]
elif isinstance(attr, DiscreteVariable):
values = [attr.values[i - 1] if i else "" for i in values]
encoded.append(
(attr.name, context.attributes.get(attr.name), op, values))
return encoded

def decode_setting(self, setting, value, domain=None):
value = super().decode_setting(setting, value, domain)
if setting.name == 'conditions':
CONTINUOUS = vartype(ContinuousVariable("x"))
# Use this after 2022/2/2:
# for i, (attr, _, op, values) in enumerate(value):
for i, condition in enumerate(value):
attr = condition[0]
op, values = condition[-2:]

var = attr in domain and domain[attr]
if var and var.is_continuous and not isinstance(var, TimeVariable):
# for i, (attr, tpe, op, values) in enumerate(value):
# if tpe is not None:
for i, (attr, *tpe, op, values) in enumerate(value):
if tpe != (None, ) \
or not tpe and attr not in OWSelectRows.AllTypes:
attr = domain[attr]
if type(attr) is ContinuousVariable \
or OWSelectRows.AllTypes.get(attr) == CONTINUOUS:
values = [QLocale().toString(float(i), 'f') for i in values]
elif isinstance(attr, DiscreteVariable):
# After 2022/2/2, use just the expression in else clause
if values and isinstance(values[0], int):
# Backwards compatibility. Reset setting if we detect
# that the number of values decreased. Still broken if
# they're reordered or we don't detect the decrease.
if max(values) > len(attr.values):
values = [0]
else:
values = [attr.to_val(val) + 1 if val else 0
for val in values if val in attr.values] \
or [0]
value[i] = (attr, op, values)
return value

Expand All @@ -80,7 +102,7 @@ def match(self, context, domain, attrs, metas):
conditions = context.values["conditions"]
all_vars = attrs.copy()
all_vars.update(metas)
matched = [all_vars.get(name) == tpe
matched = [all_vars.get(name) == tpe # also matches "all (...)" strings
# After 2022/2/2 remove this line:
if len(rest) == 2 else name in all_vars
for name, tpe, *rest in conditions]
Expand All @@ -101,6 +123,8 @@ def filter_value(self, setting, data, domain, attrs, metas):
# if all_vars.get(name) == tpe]
conditions[:] = [
(name, tpe, *rest) for name, tpe, *rest in conditions
# all_vars.get(name) == tpe also matches "all (...)" which are
# encoded with type `None`
if (all_vars.get(name) == tpe if len(rest) == 2
else name in all_vars)]

Expand Down Expand Up @@ -209,6 +233,9 @@ def __init__(self):
self.last_output_conditions = None
self.data = None
self.data_desc = self.match_desc = self.nonmatch_desc = None
self.variable_model = DomainModel(
[list(self.AllTypes), DomainModel.Separator,
DomainModel.CLASSES, DomainModel.ATTRIBUTES, DomainModel.METAS])

box = gui.vBox(self.controlArea, 'Conditions', stretch=100)
self.cond_list = QTableWidget(
Expand Down Expand Up @@ -268,18 +295,11 @@ def add_row(self, attr=None, condition_type=None, condition_value=None):
attr_combo = ComboBoxSearch(
minimumContentsLength=12,
sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon)
attr_combo.setModel(self.variable_model)
attr_combo.row = row
for var in self._visible_variables(self.data.domain):
if isinstance(var, Variable):
attr_combo.addItem(*gui.attributeItem(var))
else:
attr_combo.addItem(var)
if isinstance(attr, str):
attr_combo.setCurrentText(attr)
else:
attr_combo.setCurrentIndex(
attr or
len(self.AllTypes) - (attr_combo.count() == len(self.AllTypes)))
attr_combo.setCurrentIndex(self.variable_model.indexOf(attr) if attr
else len(self.AllTypes) + 1)

self.cond_list.setCellWidget(row, 0, attr_combo)

index = QPersistentModelIndex(model.index(row, 3))
Expand All @@ -297,15 +317,6 @@ def add_row(self, attr=None, condition_type=None, condition_value=None):

self.cond_list.resizeRowToContents(row)

@classmethod
def _visible_variables(cls, domain):
"""Generate variables in order they should be presented in in combos."""
return chain(
cls.AllTypes,
filter_visible(chain(domain.class_vars,
domain.metas,
domain.attributes)))

def add_all(self):
if self.cond_list.rowCount():
Mb = QMessageBox
Expand All @@ -315,9 +326,8 @@ def add_all(self):
"filters for all variables.", Mb.Ok | Mb.Cancel) != Mb.Ok:
return
self.remove_all()
domain = self.data.domain
for i in range(len(domain.variables) + len(domain.metas)):
self.add_row(i)
for attr in self.variable_model[len(self.AllTypes) + 1:]:
self.add_row(attr)

def remove_one(self, rownum):
self.remove_one_row(rownum)
Expand All @@ -333,6 +343,12 @@ def remove_one_row(self, rownum):
self.remove_all_button.setDisabled(True)

def remove_all_rows(self):
# Disconnect signals to avoid stray emits when changing variable_model
for row in range(self.cond_list.rowCount()):
for col in (0, 1):
widget = self.cond_list.cellWidget(row, col)
if widget:
widget.currentIndexChanged.disconnect()
self.cond_list.clear()
self.cond_list.setRowCount(0)
self.remove_all_button.setDisabled(True)
Expand Down Expand Up @@ -495,24 +511,18 @@ def set_data(self, data):
if not data:
self.info.set_input_summary(self.info.NoInput)
self.data_desc = None
self.variable_model.set_domain(None)
self.commit()
return
self.data_desc = report.describe_data_brief(data)
self.conditions = []
try:
self.openContext(data)
except Exception:
pass
self.variable_model.set_domain(data.domain)

variables = list(self._visible_variables(self.data.domain))
varnames = [v.name if isinstance(v, Variable) else v for v in variables]
if self.conditions:
for attr, cond_type, cond_value in self.conditions:
if attr in varnames:
self.add_row(varnames.index(attr), cond_type, cond_value)
elif attr in self.AllTypes:
self.add_row(attr, cond_type, cond_value)
else:
self.conditions = []
self.openContext(data)
for attr, cond_type, cond_value in self.conditions:
if attr in self.variable_model:
self.add_row(attr, cond_type, cond_value)
if not self.cond_list.model().rowCount():
self.add_row()

self.info.set_input_summary(data.approx_len(),
Expand All @@ -521,12 +531,15 @@ def set_data(self, data):

def conditions_changed(self):
try:
self.conditions = []
cells_by_rows = (
[self.cond_list.cellWidget(row, col) for col in range(3)]
for row in range(self.cond_list.rowCount())
)
self.conditions = [
(self.cond_list.cellWidget(row, 0).currentText(),
self.cond_list.cellWidget(row, 1).currentIndex(),
self._get_value_contents(self.cond_list.cellWidget(row, 2)))
for row in range(self.cond_list.rowCount())]
(var_cell.currentData(gui.TableVariable) or var_cell.currentText(),
oper_cell.currentIndex(),
self._get_value_contents(val_cell))
for var_cell, oper_cell, val_cell in cells_by_rows]
if self.update_on_change and (
self.last_output_conditions is None or
self.last_output_conditions != self.conditions):
Expand Down Expand Up @@ -674,19 +687,18 @@ def send_report(self):
pdesc = ndesc

conditions = []
domain = self.data.domain
for attr_name, oper, values in self.conditions:
if attr_name in self.AllTypes:
attr = attr_name
for attr, oper, values in self.conditions:
if isinstance(attr, str):
attr_name = attr
var_type = self.AllTypes[attr]
names = self.operator_names[attr_name]
var_type = self.AllTypes[attr_name]
else:
attr = domain[attr_name]
attr_name = attr.name
var_type = vartype(attr)
names = self.operator_names[type(attr)]
name = names[oper]
if oper == len(names) - 1:
conditions.append("{} {}".format(attr, name))
conditions.append("{} {}".format(attr_name, name))
elif var_type == 1: # discrete
if name == "is one of":
valnames = [attr.values[v - 1] for v in values]
Expand Down
18 changes: 9 additions & 9 deletions Orange/widgets/data/tests/test_owselectrows.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_filter_cont(self):

for i, (op, _) in enumerate(OWSelectRows.Operators[ContinuousVariable]):
self.widget.remove_all()
self.widget.add_row(1, i, CFValues[op])
self.widget.add_row(iris.domain[0], i, CFValues[op])
self.widget.conditions_changed()
self.widget.unconditional_commit()

Expand All @@ -80,7 +80,7 @@ def test_filter_str(self):
self.widget.set_data(zoo)
for i, (op, _) in enumerate(OWSelectRows.Operators[StringVariable]):
self.widget.remove_all()
self.widget.add_row(1, i, SFValues[op])
self.widget.add_row(zoo.domain.metas[0], i, SFValues[op])
self.widget.conditions_changed()
self.widget.unconditional_commit()

Expand Down Expand Up @@ -183,11 +183,11 @@ def test_partial_matches(self):
iris = Table("iris")
domain = iris.domain
self.widget = self.widget_with_context(
domain, [[domain[0].name, 2, ("5.2",)]])
domain, [[domain[0].name, 2, 2, ("5.2",)]])
iris2 = iris.transform(Domain(domain.attributes[:2], None))
self.send_signal(self.widget.Inputs.data, iris2)
condition = self.widget.conditions[0]
self.assertEqual(condition[0], "sepal length")
self.assertEqual(condition[0], iris.domain[0])
self.assertEqual(condition[1], 2)
self.assertTrue(condition[2][0].startswith("5.2"))

Expand All @@ -196,12 +196,12 @@ def test_partial_matches_with_missing_vars(self):
iris = Table("iris")
domain = iris.domain
self.widget = self.widget_with_context(
domain, [[domain[0].name, 2, ("5.2",)],
[domain[2].name, 2, ("4.2",)]])
domain, [[domain[0].name, 2, 2, ("5.2",)],
[domain[2].name, 2, 2, ("4.2",)]])
iris2 = iris.transform(Domain(domain.attributes[2:], None))
self.send_signal(self.widget.Inputs.data, iris2)
condition = self.widget.conditions[0]
self.assertEqual(condition[0], domain[2].name)
self.assertEqual(condition[0], domain[2])
self.assertEqual(condition[1], 2)
self.assertTrue(condition[2][0].startswith("4.2"))

Expand Down Expand Up @@ -354,7 +354,7 @@ def test_support_old_settings(self):
iris.domain, [["sepal length", 2, ("5.2",)]])
self.send_signal(self.widget.Inputs.data, iris)
condition = self.widget.conditions[0]
self.assertEqual(condition[0], "sepal length")
self.assertEqual(condition[0], iris.domain["sepal length"])
self.assertEqual(condition[1], 2)
self.assertTrue(condition[2][0].startswith("5.2"))

Expand All @@ -380,7 +380,7 @@ def test_purge_discretized(self):
discretize_class=True, method=method)
domain = discretizer(housing)
data = housing.transform(domain)
widget = self.widget_with_context(domain, [["MEDV", 2, (2, 3)]])
widget = self.widget_with_context(domain, [["MEDV", 101, 2, (2, 3)]])
widget.purge_classes = True
self.send_signal(widget.Inputs.data, data)
out = self.get_output(widget.Outputs.matching_data)
Expand Down

0 comments on commit 783013c

Please sign in to comment.