From 00522349321146231ed7509c8f6f55e0df17c1dc Mon Sep 17 00:00:00 2001 From: janezd Date: Fri, 20 Sep 2019 20:48:03 +0200 Subject: [PATCH] OWColor: Refactor to avoid setting Variable.name --- Orange/widgets/data/owcolor.py | 171 +++++++++++++++++---------------- 1 file changed, 87 insertions(+), 84 deletions(-) diff --git a/Orange/widgets/data/owcolor.py b/Orange/widgets/data/owcolor.py index fa015fce96f..8215286558a 100644 --- a/Orange/widgets/data/owcolor.py +++ b/Orange/widgets/data/owcolor.py @@ -1,6 +1,7 @@ """ Widget for assigning colors to variables """ +from itertools import chain import numpy as np from AnyQt.QtCore import Qt, QSize, QAbstractTableModel @@ -18,6 +19,20 @@ ColorRole = next(gui.OrangeUserRole) +class AttrDesc: + def __init__(self, var, new_name, colors, values=None): + self.var = var + self.new_name = new_name + self.colors = colors + self.values = values + + @classmethod + def from_var(cls, var): + return cls(var, var.name, None, var.values if var.is_discrete else None) + + def get_colors(self): + return self.colors or self.var.colors + # noinspection PyMethodOverriding class ColorTableModel(QAbstractTableModel): """Base color model for discrete and continuous attributes. The model @@ -55,7 +70,7 @@ def data(self, index, role=Qt.DisplayRole): # Only valid for the first column row = index.row() if role in (Qt.DisplayRole, Qt.EditRole): - return self.variables[row].name + return self.variables[row].new_name if role == Qt.FontRole: font = QFont() font.setBold(True) @@ -68,7 +83,7 @@ def setData(self, index, value, role): # pylint: disable=missing-docstring # Only valid for the first column if role == Qt.EditRole: - self.variables[index.row()].name = value + self.variables[index.row()].new_name = value else: return False self.dataChanged.emit(index, index) @@ -83,28 +98,26 @@ class DiscColorTableModel(ColorTableModel): # pylint: disable=missing-docstring def n_columns(self): return bool(self.variables) and \ - 1 + max(len(var.values) for var in self.variables) + 1 + max(len(row.values) for row in self.variables) def data(self, index, role=Qt.DisplayRole): # pylint: disable=too-many-return-statements row, col = index.row(), index.column() if col == 0: return ColorTableModel.data(self, index, role) - var = self.variables[row] - if col > len(var.values): + values = self.variables[row].values + if col > len(values): return None if role in (Qt.DisplayRole, Qt.EditRole): - return var.values[col - 1] - try: - color = var.colors[col - 1] - except (AttributeError, IndexError): - return None + return values[col - 1] + desc = self.variables[row] + color = desc.get_colors()[col - 1] if role == Qt.DecorationRole: return QColor(*color) if role == Qt.ToolTipRole: return self._encode_color(color) if role == ColorRole: - return var.colors[col - 1] + return color return None # noinspection PyMethodOverriding @@ -112,10 +125,13 @@ def setData(self, index, value, role): row, col = index.row(), index.column() if col == 0: return ColorTableModel.setData(self, index, value, role) + desc = self.variables[row] if role == ColorRole: - self.variables[row].set_color(col - 1, value[:3]) + if not desc.colors: + desc.colors = desc.var.colors.tolist() + desc.colors[col - 1] = value[:3] elif role == Qt.EditRole: - self.variables[row].values[col - 1] = value + desc.values[col - 1] = value else: return False self.dataChanged.emit(index, index) @@ -139,7 +155,8 @@ def _column0(): def _column1(): if role == Qt.DecorationRole: - continuous_palette = ContinuousPaletteGenerator(*var.colors) + continuous_palette = \ + ContinuousPaletteGenerator(*desc.get_colors()) line = continuous_palette.getRGB(np.arange(0, 1, 1 / 256)) data = np.arange(0, 256, dtype=np.int8). \ reshape((1, 256)). \ @@ -150,10 +167,11 @@ def _column1(): img.data = data return img if role == Qt.ToolTipRole: - return "{} - {}".format(self._encode_color(var.colors[0]), - self._encode_color(var.colors[1])) + colors = desc.get_colors() + return f"{self._encode_color(colors[0])} " \ + f"- {self._encode_color(colors[1])}" if role == ColorRole: - return var.colors + return desc.get_colors() return None def _column2(): @@ -166,7 +184,7 @@ def _column2(): return None row, col = index.row(), index.column() - var = self.variables[row] + desc = self.variables[row] if 0 <= col <= 2: return [_column0, _column1, _column2][col]() @@ -183,9 +201,9 @@ def setData(self, index, value, role): return True def copy_to_all(self, index): - colors = self.variables[index.row()].colors - for row in range(self.n_rows()): - self.variables[row].colors = colors + colors = self.variables[index.row()].get_colors() + for desc in self.variables: + desc.colors = colors self.dataChanged.emit(self.index(0, 1), self.index(self.n_rows(), 1)) @@ -296,8 +314,8 @@ class Outputs: settingsHandler = settings.PerfectDomainContextHandler( match_values=settings.PerfectDomainContextHandler.MATCH_VALUES_ALL) - disc_data = settings.ContextSetting([]) - cont_data = settings.ContextSetting([]) + disc_colors = settings.ContextSetting([]) + cont_colors = settings.ContextSetting([]) color_settings = settings.Setting(None) selected_schema_index = settings.Setting(0) auto_apply = settings.Setting(True) @@ -308,8 +326,8 @@ def __init__(self): super().__init__() self.data = None self.orig_domain = self.domain = None - self.disc_colors = [] - self.cont_colors = [] + self.disc_dict = {} + self.cont_dict = {} box = gui.hBox(self.controlArea, "Discrete Variables") self.disc_model = DiscColorTableModel() @@ -334,19 +352,6 @@ def __init__(self): def sizeHint(): return QSize(500, 570) - def _create_proxies(self, variables): - part_vars = [] - for var in variables: - if var.is_discrete or var.is_continuous: - var = var.make_proxy() - if var.is_discrete: - var.values = var.values[:] - self.disc_colors.append(var) - else: - self.cont_colors.append(var) - part_vars.append(var) - return part_vars - @Inputs.data def set_data(self, data): """Handle data input signal""" @@ -356,53 +361,51 @@ def set_data(self, data): if data is None: self.data = self.domain = None else: - domain = self.orig_domain = data.domain - domain = Orange.data.Domain(self._create_proxies(domain.attributes), - self._create_proxies(domain.class_vars), - self._create_proxies(domain.metas)) - self.openContext(data) - self.data = data.transform(domain) + self.data = data + for var in chain(data.domain.variables, data.domain.metas): + if var.is_discrete: + self.disc_colors.append(AttrDesc.from_var(var)) + elif var.is_continuous: + self.cont_colors.append(AttrDesc.from_var(var)) + self.disc_model.set_data(self.disc_colors) self.cont_model.set_data(self.cont_colors) self.disc_view.resizeColumnsToContents() self.cont_view.resizeColumnsToContents() + self.openContext(data) + self.disc_dict = {k.var.name: k for k in self.disc_colors} + self.cont_dict = {k.var.name: k for k in self.cont_colors} self.unconditional_commit() - def storeSpecificSettings(self): - # Store the colors that were changed -- but not others - self.current_context.disc_data = \ - [(var.name, var.values, "colors" in var.attributes and var.colors) - for var in self.disc_colors] - self.current_context.cont_data = \ - [(var.name, "colors" in var.attributes and var.colors) - for var in self.cont_colors] - - def retrieveSpecificSettings(self): - disc_data = getattr(self.current_context, "disc_data", ()) - for var, (name, values, colors) in zip(self.disc_colors, disc_data): - var.name = name - var.values = values[:] - if colors is not False: - var.colors = colors - cont_data = getattr(self.current_context, "cont_data", ()) - for var, (name, colors) in zip(self.cont_colors, cont_data): - var.name = name - if colors is not False: - var.colors = colors - def _on_data_changed(self, *args): self.commit() def commit(self): - self.Outputs.data.send(self.data) + def make(var): + source = self.disc_dict if var.is_discrete else self.cont_dict + desc = source.get(var.name) + if desc: + var = var.copy( + name=desc.new_name, + **(dict(values=desc.values) if var.is_discrete else {})) + var.colors = desc.colors + return var + + domain = self.data.domain + new_domain = Orange.data.Domain( + [make(var) for var in domain.attributes], + [make(var) for var in domain.class_vars], + [make(var) for var in domain.metas]) + new_data = self.data.transform(new_domain) + self.Outputs.data.send(new_data) def send_report(self): """Send report""" - def _report_variables(variables, orig_variables): + def _report_variables(variables): from Orange.widgets.report import colored_square as square def was(n, o): - return n if n == o else "{} (was: {})".format(n, o) + return n if n == o else f"{n} (was: {o})" # definition of td element for continuous gradient # with support for pre-standard css (needed at least for Qt 4.8) @@ -420,36 +423,36 @@ def was(n, o): '">' rows = "" - for var, ovar in zip(variables, orig_variables): + for var in variables: if var.is_discrete: + desc = self.disc_dict[var.name] values = " \n".join( "{} {}". - format(square(*var.colors[i]), was(value, ovalue)) - for i, (value, ovalue) in - enumerate(zip(var.values, ovar.values))) + format(square(*color), was(value, old_value)) + for color, value, old_value in + zip(desc.colors, desc.values, var.values)) elif var.is_continuous: - col = var.colors + desc = self.cont_dict[var.name] + col = desc.colors colors = col[0][:3] + ("black, " * col[2], ) + col[1][:3] values = cont_tpl.format(*colors * len(defs)) else: continue - name = was(var.name, ovar.name) + names = was(desc.new_name, desc.name) rows += '\n' \ ' {}{}\n\n'. \ - format(name, values) + format(names, values) return rows if not self.data: return - domain = self.data.domain - orig_domain = self.orig_domain + dom = self.data.domain sections = ( - (name, _report_variables(vars, ovars)) - for name, vars, ovars in ( - ("Features", domain.attributes, orig_domain.attributes), - ("Outcome" + "s" * (len(domain.class_vars) > 1), - domain.class_vars, orig_domain.class_vars), - ("Meta attributes", domain.metas, orig_domain.metas))) + (name, _report_variables(vars)) + for name, vars in ( + ("Features", dom.attributes), + ("Outcome" + "s" * (len(dom.class_vars) > 1), dom.class_vars), + ("Meta attributes", dom.metas))) table = "".join("{}{}".format(name, rows) for name, rows in sections if rows) if table: