From 33cd68c804210b42396ac3a55884750917832e28 Mon Sep 17 00:00:00 2001 From: noahnovsak Date: Fri, 30 Sep 2022 12:29:38 +0200 Subject: [PATCH 1/6] owinteractions: new widget and table model --- orangecontrib/prototypes/ranktablemodel.py | 306 +++++++++++++ .../prototypes/widgets/owinteractions_new.py | 409 ++++++++++++++++++ .../widgets/tests/test_owinteractions_new.py | 26 ++ 3 files changed, 741 insertions(+) create mode 100644 orangecontrib/prototypes/ranktablemodel.py create mode 100644 orangecontrib/prototypes/widgets/owinteractions_new.py create mode 100644 orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py diff --git a/orangecontrib/prototypes/ranktablemodel.py b/orangecontrib/prototypes/ranktablemodel.py new file mode 100644 index 00000000..365ee6bd --- /dev/null +++ b/orangecontrib/prototypes/ranktablemodel.py @@ -0,0 +1,306 @@ +from numbers import Number, Integral +from typing import Iterable, Union +import numpy as np + +from AnyQt.QtCore import QModelIndex, Qt, QAbstractTableModel + +from Orange.data import Variable +from Orange.data.domain import Domain + +from Orange.widgets import gui +from Orange.widgets.utils.itemmodels import DomainModel + + +MAX_ROWS = int(1e9) # limits how many rows model will display + + +def _argsort(data: np.ndarray, order: Qt.SortOrder): + # same as ``_argsortData`` in AbstractSortModel, might combine? + if data.ndim == 1: + indices = np.argsort(data, kind="mergesort") + else: + indices = np.lexsort(data.T[::-1]) + if order == Qt.DescendingOrder: + indices = indices[::-1] + return indices + + +class ArrayTableModel(QAbstractTableModel): + """ + A proxy table model that stores and sorts its data with `numpy`, + thus providing higher speeds and better scaling. + + TODO: Could extend ``AbstractSortTableModel`` or ``PyTableModel``? + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.__sortInd = ... # type: np.ndarray + self.__sortColumn = -1 + self.__sortOrder = Qt.AscendingOrder + + self._data = None # type: np.ndarray + self._columns = 0 + self._rows = 0 + self._max_display_rows = self._max_data_rows = MAX_ROWS + self._headers = {} + + def columnData(self, column: Union[int, slice], apply_sort=False): + if apply_sort: + return self._data[:self._rows, column][self.__sortInd] + return self._data[:self._rows, column] + + def sortColumn(self): + return self.__sortColumn + + def sortOrder(self): + return self.__sortOrder + + def mapToSourceRows(self, rows: Union[int, slice, list, np.ndarray]): + if isinstance(self.__sortInd, np.ndarray) \ + and (isinstance(rows, (Integral, type(Ellipsis))) + or len(rows)): + rows = self.__sortInd[rows] + return rows + + def resetSorting(self): + self.sort(-1) + + def sort(self, column: int, order: Qt.SortOrder = Qt.AscendingOrder): + if self._data is None: + return + + indices = self._sort(column, order) + self.__sortColumn = column + self.__sortOrder = order + + self.setSortIndices(indices) + + def setSortIndices(self, indices: np.ndarray): + self.layoutAboutToBeChanged.emit([], QAbstractTableModel.VerticalSortHint) + self.__sortInd = indices + self.layoutChanged.emit([], QAbstractTableModel.VerticalSortHint) + + def _sort(self, column: int, order: Qt.SortOrder): + if column < 0: + return ... + + data = self.columnData(column) + return _argsort(data, order) + + def extendSortFrom(self, sorted_rows: int): + data = self.columnData(self.__sortColumn) + ind = np.arange(sorted_rows, self._rows) + order = 1 if self.__sortOrder == Qt.AscendingOrder else -1 + loc = np.searchsorted(data[:sorted_rows], + data[sorted_rows:self._rows], + sorter=self.__sortInd[::order]) + indices = np.insert(self.__sortInd[::order], loc, ind)[::order] + self.setSortIndices(indices) + + def rowCount(self, parent=QModelIndex(), *args, **kwargs): + return 0 if parent.isValid() else min(self._rows, self._max_display_rows) + + def columnCount(self, parent=QModelIndex(), *args, **kwargs): + return 0 if parent.isValid() else self._columns + + def data(self, index: QModelIndex, role=Qt.DisplayRole): + if not index.isValid(): + return + + row, column = self.mapToSourceRows(index.row()), index.column() + + try: + value = self._data[row, column] + except IndexError: + return + match role: + case Qt.EditRole: + return value + case Qt.DisplayRole: + if isinstance(value, Number) and not \ + (np.isnan(value) or np.isinf(value) or + isinstance(value, Integral)): + absval = abs(value) + strlen = len(str(int(absval))) + value = '{:.{}{}}'.format(value, + 2 if absval < .001 else + 3 if strlen < 2 else + 1 if strlen < 5 else + 0 if strlen < 6 else + 3, + 'f' if (absval == 0 or + absval >= .001 and + strlen < 6) + else 'e') + return str(value) + case Qt.DecorationRole if isinstance(value, Variable): + return gui.attributeIconDict[value] + case Qt.ToolTipRole: + return str(value) + + def setHorizontalHeaderLabels(self, labels: Iterable[str]): + self._headers[Qt.Horizontal] = tuple(labels) + + def setVertcalHeaderLabels(self, labels: Iterable[str]): + self._headers[Qt.Vertical] = tuple(labels) + + def headerData(self, section: int, orientation: Qt.Orientation, role=Qt.DisplayRole): + headers = self._headers.get(orientation) + + if headers and section < len(headers): + if orientation == Qt.Vertical: + section = self.mapToSourceRows(section) + if role in {Qt.DisplayRole, Qt.ToolTipRole}: + return headers[section] + + return super().headerData(section, orientation, role) + + def __len__(self): + return self._rows + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, item): + return self._data[item] + + def initialize(self, data: list[list[float]]): + self.beginResetModel() + self._data = np.array(data) + self._rows, self._columns = self._data.shape + self.resetSorting() + self.endResetModel() + + def clear(self): + self.beginResetModel() + self._data = None + self._rows = self._columns = 0 + self.resetSorting() + self.endResetModel() + + def append(self, rows: list[list[float]]): + if not isinstance(self._data, np.ndarray): + return self.initialize(rows) + + n_rows = len(rows) + if n_rows == 0: + print("nothing to add") + return + n_data = len(self._data) + insert = self._rows < self._max_display_rows + + if insert: + self.beginInsertRows(QModelIndex(), self._rows, min(self._max_display_rows, self._rows + n_rows) - 1) + + if self._rows + n_rows >= n_data: + n_data = min(max(n_data + n_rows, 2 * n_data), self._max_data_rows) + ar = np.full((n_data, self._columns), np.nan) + ar[:self._rows] = self._data[:self._rows] + self._data = ar + + self._data[self._rows:self._rows + n_rows] = rows + self._rows += n_rows + + if self.__sortColumn >= 0: + old_rows = self._rows - n_rows + self.extendSortFrom(old_rows) + + if insert: + self.endInsertRows() + + +class RankModel(ArrayTableModel): + """ + Extends ``ArrayTableModel`` with filtering and other specific + features for ``VizRankDialog`` type widgets, to display scores for + combinations of attributes. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.__filterInd = ... # type: np.ndarray + self.__filterStr = "" + + self.domain = None # type: Domain + self.domain_model = DomainModel(DomainModel.ATTRIBUTES) + + def set_domain(self, domain: Domain, **kwargs): + self.__dict__.update(kwargs) + self.domain = domain + self.domain_model.set_domain(domain) + n_attrs = len(domain.attributes) + self._max_data_rows = n_attrs * (n_attrs - 1) // 2 + + def mapToSourceRows(self, rows): + if isinstance(self.__filterInd, np.ndarray) \ + and (isinstance(rows, (Integral, type(Ellipsis))) + or len(rows)): + rows = self.__filterInd[rows] + return super().mapToSourceRows(rows) + + def resetFiltering(self): + self.filter("") + + def filter(self, text: str): + if self._data is None: + return + + if not text: + self.__filterInd = indices = ... + self.__filterStr = "" + self._max_display_rows = MAX_ROWS + else: + self.__filterStr = text + indices = self._filter(text) + + self.setFilterIndices(indices) + + def setFilterIndices(self, indices: np.ndarray): + self.layoutAboutToBeChanged.emit([]) + if isinstance(indices, np.ndarray): + self.__filterInd = indices + self._max_display_rows = len(indices) + self.layoutChanged.emit([]) + + def setSortIndices(self, indices: np.ndarray): + super().setSortIndices(indices) + + # sorting messes up the filter indices, so they + # must also be updated + self.layoutAboutToBeChanged.emit([]) + if isinstance(self.__filterInd, np.ndarray): + filter_indices = self._filter(self.__filterStr) + self.__filterInd = filter_indices + self._max_display_rows = len(filter_indices) + self.layoutChanged.emit([]) + + def _filter(self, text: str): + attr = [i for i, attr in enumerate(self.domain.attributes) + if str(text).lower() in attr.name.lower()] + + attr_data = self.columnData(slice(-2, None), apply_sort=True) + valid = np.isin(attr_data, attr).any(axis=1) + + return valid.nonzero()[0] + + def append(self, rows): + super().append(rows) + + if isinstance(self.__filterInd, np.ndarray): + self.resetFiltering() + + def data(self, index: QModelIndex, role=Qt.DisplayRole): + if not index.isValid(): + return + + row, column = self.mapToSourceRows(index.row()), index.column() + try: + value = self._data[row, column] + except IndexError: + return + + if column >= self.columnCount() - 2 and role != Qt.EditRole: + return self.domain_model.data(self.domain_model.index(int(value)), role) + + return super().data(index, role) diff --git a/orangecontrib/prototypes/widgets/owinteractions_new.py b/orangecontrib/prototypes/widgets/owinteractions_new.py new file mode 100644 index 00000000..ae3da4ef --- /dev/null +++ b/orangecontrib/prototypes/widgets/owinteractions_new.py @@ -0,0 +1,409 @@ +import copy +from itertools import chain +from threading import Lock, Timer +from typing import Callable, Optional, Iterable +import numpy as np + +from AnyQt.QtGui import QColor, QPainter, QPen +from AnyQt.QtCore import QModelIndex, Qt, QLineF +from AnyQt.QtWidgets import QTableView, QHeaderView, \ + QStyleOptionViewItem, QApplication, QStyle + +from Orange.data import Table, Variable +from Orange.preprocess import Discretize, Remove +from Orange.widgets import gui +from Orange.widgets.widget import OWWidget, AttributeList, Msg +from Orange.widgets.utils.widgetpreview import WidgetPreview +from Orange.widgets.utils.signals import Input, Output +from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState +from Orange.widgets.utils.itemmodels import DomainModel +from Orange.widgets.settings import Setting, ContextSetting, DomainContextHandler + +from orangecontrib.prototypes.ranktablemodel import RankModel +from orangecontrib.prototypes.interactions import InteractionScorer + + +class ModelQueue: + """ + Another queueing object, similar to ``queue.Queue``. + The main difference is that ``get()`` returns all its + contents at the same time, instead of one by one. + """ + def __init__(self): + self.lock = Lock() + self.model = [] + self.state = None + + def put(self, row, state): + with self.lock: + self.model.append(row) + self.state = state + + def get(self): + with self.lock: + model, self.model = self.model, [] + state, self.state = self.state, None + return model, state + + +def run(compute_score: Callable, row_for_state: Callable, + iterate_states: Callable, saved_state: Optional[Iterable], + progress: int, state_count: int, task: TaskState): + """ + Replaces ``run_vizrank``, with some minor adjustments. + - ``ModelQueue`` replaces ``queue.Queue`` + - `row_for_state` parameter added + - `scores` parameter removed + """ + task.set_status("Getting combinations...") + task.set_progress_value(0.1) + states = iterate_states(saved_state) + + task.set_status("Getting scores...") + queue = ModelQueue() + can_set_partial_result = True + + def do_work(st, next_st): + try: + score = compute_score(st) + if score is not None: + queue.put(row_for_state(score, st), next_st) + except Exception: + pass + + def reset_flag(): + nonlocal can_set_partial_result + can_set_partial_result = True + + state = None + next_state = next(states) + try: + while True: + if task.is_interruption_requested(): + return queue.get() + task.set_progress_value(progress * 100 // state_count) + progress += 1 + state = copy.copy(next_state) + next_state = copy.copy(next(states)) + do_work(state, next_state) + # for simple scores (e.g. correlations widget) and many feature + # combinations, the 'partial_result_ready' signal (emitted by + # invoking 'task.set_partial_result') was emitted too frequently + # for a longer period of time and therefore causing the widget + # being unresponsive + if can_set_partial_result: + task.set_partial_result(queue.get()) + can_set_partial_result = False + Timer(0.05, reset_flag).start() + except StopIteration: + do_work(state, None) + task.set_partial_result(queue.get()) + return queue.get() + + +class Heuristic: + RANDOM, INFO_GAIN = 0, 1 + type = {RANDOM: "Random Search", + INFO_GAIN: "Information Gain Heuristic"} + + def __init__(self, weights, type=None): + self.n_attributes = len(weights) + self.attributes = np.arange(self.n_attributes) + if type == self.RANDOM: + np.random.shuffle(self.attributes) + if type == self.INFO_GAIN: + self.attributes = self.attributes[np.argsort(weights)] + + def generate_states(self): + # prioritize two mid ranked attributes over highest first + for s in range(1, self.n_attributes * (self.n_attributes - 1) // 2): + for i in range(max(s - self.n_attributes + 1, 0), (s + 1) // 2): + yield self.attributes[i], self.attributes[s - i] + + def get_states(self, initial_state): + states = self.generate_states() + if initial_state is not None: + while next(states) != initial_state: + pass + return chain([initial_state], states) + return states + + +class InteractionItemDelegate(gui.TableBarItem): + def paint(self, painter: QPainter, option: QStyleOptionViewItem, + index: QModelIndex) -> None: + opt = QStyleOptionViewItem(option) + self.initStyleOption(opt, index) + widget = option.widget + style = QApplication.style() if widget is None else widget.style() + pen = QPen(QColor("#46befa"), 5, Qt.SolidLine, Qt.RoundCap) + line = QLineF() + self.__style = style + text = opt.text + opt.text = "" + style.drawControl(QStyle.CE_ItemViewItem, opt, painter, widget) + textrect = style.subElementRect( + QStyle.SE_ItemViewItemText, opt, widget) + + interaction = self.cachedData(index, Qt.EditRole) + # only draw bars for first column + if index.column() == 0 and interaction is not None: + rect = option.rect + pw = self.penWidth + textoffset = pw + 2 + baseline = rect.bottom() - textoffset / 2 + origin = rect.left() + 3 + pw / 2 # + half pen width for the round line cap + width = rect.width() - 3 - pw + + def draw_line(start, length): + line.setLine(origin + start, baseline, origin + start + length, baseline) + painter.drawLine(line) + + scorer = index.model().scorer + attr1 = self.cachedData(index.siblingAtColumn(2), Qt.EditRole) + attr2 = self.cachedData(index.siblingAtColumn(3), Qt.EditRole) + l_bar = scorer.normalize(scorer.information_gain[int(attr1)]) + r_bar = scorer.normalize(scorer.information_gain[int(attr2)]) + # negative information gains stem from issues in interaction + # calculation and may cause bars reaching out of intended area + l_bar, r_bar = width * max(l_bar, 0), width * max(r_bar, 0) + interaction *= width + + pen.setWidth(pw) + painter.save() + painter.setRenderHint(QPainter.Antialiasing) + painter.setPen(pen) + draw_line(0, l_bar) + draw_line(l_bar + interaction, r_bar) + pen.setColor(QColor("#aaf22b") if interaction >= 0 else QColor("#ffaa7f")) + painter.setPen(pen) + draw_line(l_bar, interaction) + painter.restore() + textrect.adjust(0, 0, 0, -textoffset) + + opt.text = text + self.drawViewItemText(style, painter, opt, textrect) + + +class OWInteractions(OWWidget, ConcurrentWidgetMixin): + name = "Interactions New" + description = "Compute all pairwise attribute interactions." + icon = "icons/Interactions.svg" + category = "Unsupervised" + + class Inputs: + data = Input("Data", Table) + + class Outputs: + features = Output("Features", AttributeList) + + settingsHandler = DomainContextHandler() + selection = ContextSetting([]) + filter_text: str + filter_text = ContextSetting("") + feature: Variable + feature = ContextSetting(None) + heuristic_type: int + heuristic_type = Setting(0) + + want_main_area = False + want_control_area = True + + class Information(OWWidget.Information): + removed_cons_feat = Msg("Constant features have been removed.") + + class Warning(OWWidget.Warning): + not_enough_vars = Msg("At least two features are needed.") + not_enough_inst = Msg("At least two instances are needed.") + no_class_var = Msg("Target feature missing") + + def __init__(self): + OWWidget.__init__(self) + ConcurrentWidgetMixin.__init__(self) + + self.keep_running = True + self.saved_state = None + self.progress = 0 + + self.data = None # type: Table + self.pp_data = None # type: Table + self.n_attrs = 0 + + self.scorer = None + self.heuristic = None + self.feature_index = None + + gui.comboBox(self.controlArea, self, "heuristic_type", + items=Heuristic.type.values(), + callback=self.on_heuristic_combo_changed,) + + self.feature_model = DomainModel(order=DomainModel.ATTRIBUTES, + separators=False, + placeholder="(All combinations)") + gui.comboBox(self.controlArea, self, "feature", + callback=self.on_feature_combo_changed, + model=self.feature_model, searchable=True) + + self.filter = gui.lineEdit(self.controlArea, self, "filter_text", + callback=self.on_filter_changed, + callbackOnType=True) + self.filter.setPlaceholderText("Filter ...") + + self.model = RankModel() + self.model.setHorizontalHeaderLabels(( + "Interaction", "Information Gain", "Feature 1", "Feature 2" + )) + self.rank_table = view = QTableView(selectionBehavior=QTableView.SelectRows, + selectionMode=QTableView.SingleSelection, + showGrid=False, + editTriggers=gui.TableView.NoEditTriggers) + view.setSortingEnabled(True) + view.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch) + view.setItemDelegate(InteractionItemDelegate()) + view.setModel(self.model) + view.selectionModel().selectionChanged.connect(self.on_selection_changed) + self.controlArea.layout().addWidget(view) + + self.button = gui.button(self.controlArea, self, "Start", callback=self.toggle) + self.button.setEnabled(False) + + @Inputs.data + def set_data(self, data): + self.closeContext() + self.clear_messages() + self.selection = [] + self.data = data + self.pp_data = None + self.n_attrs = 0 + if data is not None: + if len(data) < 2: + self.Warning.not_enough_inst() + elif data.Y.size == 0: + self.Warning.no_class_var() + else: + remover = Remove(Remove.RemoveConstant) + pp_data = Discretize()(remover(data)) + if remover.attr_results["removed"]: + self.Information.removed_cons_feat() + if len(pp_data.domain.attributes) < 2: + self.Warning.not_enough_vars() + else: + self.pp_data = pp_data + self.n_attrs = len(pp_data.domain.attributes) + self.scorer = InteractionScorer(pp_data) + self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_type) + self.model.set_domain(pp_data.domain, scorer=self.scorer) + self.feature_model.set_domain(self.pp_data and self.pp_data.domain) + self.openContext(self.pp_data) + self.initialize() + + def initialize(self): + if self.task is not None: + self.keep_running = False + self.cancel() + self.keep_running = True + self.saved_state = None + self.progress = 0 + self.progressBarFinished() + self.model.clear() + self.filter.setText("") + self.button.setText("Start") + if self.pp_data is not None: + self.toggle() + + def commit(self): + if self.data is None: + self.Outputs.features.send(None) + return + + self.Outputs.features.send(AttributeList( + [self.data.domain[attr] for attr in self.selection])) + + def toggle(self): + self.keep_running = not self.keep_running + if not self.keep_running: + self.button.setText("Pause") + self.button.repaint() + self.progressBarInit() + self.filter.setText("") + self.filter.setEnabled(False) + self.start(run, self.compute_score, self.row_for_state, + self.iterate_states, self.saved_state, + self.progress, self.state_count()) + else: + self.button.setText("Continue") + self.button.repaint() + self.cancel() + self.progressBarFinished() + self.filter.setEnabled(True) + + def on_selection_changed(self, selected): + self.selection = [self.model.data(ind) for ind in selected.indexes()[-2:]] + self.commit() + + def on_filter_changed(self): + self.model.filter(self.filter_text) + + def on_feature_combo_changed(self): + self.feature_index = self.feature and self.pp_data.domain.index(self.feature) + self.initialize() + + def on_heuristic_combo_changed(self): + if self.pp_data is not None: + self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_type) + self.initialize() + + def compute_score(self, state): + scores = (self.scorer(*state), + self.scorer.information_gain[state[0]], + self.scorer.information_gain[state[1]]) + return tuple(self.scorer.normalize(score) for score in scores) + + @staticmethod + def row_for_state(score, state): + return [score[0], sum(score)] + list(state) + + def iterate_states(self, initial_state): + if self.feature is not None: + return self._iterate_by_feature(initial_state) + if self.heuristic is not None: + return self.heuristic.get_states(initial_state) + return self._iterate_all(initial_state) + + def _iterate_all(self, initial_state): + i0, j0 = initial_state or (0, 0) + for i in range(i0, self.n_attrs): + for j in range(j0, i): + yield i, j + j0 = 0 + + def _iterate_by_feature(self, initial_state): + _, j0 = initial_state or (0, 0) + for j in range(j0, self.n_attrs): + if j != self.feature_index: + yield self.feature_index, j + + def state_count(self): + if self.feature is None: + return self.n_attrs * (self.n_attrs - 1) // 2 + return self.n_attrs + + def on_partial_result(self, result): + add_to_model, latest_state = result + if add_to_model: + self.saved_state = latest_state + self.model.append(add_to_model) + self.progress = len(self.model) + self.progressBarSet(self.progress * 100 // self.state_count()) + + def on_done(self, result): + self.button.setText("Finished") + self.button.setEnabled(False) + self.filter.setEnabled(True) + + def send_report(self): + self.report_table("Interactions", self.rank_table) + + +if __name__ == "__main__": # pragma: no cover + WidgetPreview(OWInteractions).run(Table("iris")) diff --git a/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py b/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py new file mode 100644 index 00000000..4cc212ab --- /dev/null +++ b/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py @@ -0,0 +1,26 @@ +from Orange.data import Table +from Orange.widgets.tests.base import WidgetTest + +from orangecontrib.prototypes.widgets.owinteractions_new import OWInteractions + + +class TestOWInteractions(WidgetTest): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.iris = Table("iris") # continuous data + cls.zoo = Table("zoo") # discrete data + + def setUp(self): + self.widget = self.create_widget(OWInteractions) + + def test_input_data(self): + """Check interaction table""" + self.send_signal(self.widget.Inputs.data, None) + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertEqual(self.widget.model.rowCount(), 0) + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 4) + self.assertEqual(self.widget.model.rowCount(), 6) From a3bc25813b608d2e5950707e69ac86a4f4ccda5b Mon Sep 17 00:00:00 2001 From: noahnovsak Date: Wed, 5 Oct 2022 09:56:43 +0200 Subject: [PATCH 2/6] added tests --- .../prototypes/widgets/owinteractions_new.py | 7 +- .../widgets/tests/test_owinteractions_new.py | 115 +++++++++++++++++- 2 files changed, 119 insertions(+), 3 deletions(-) diff --git a/orangecontrib/prototypes/widgets/owinteractions_new.py b/orangecontrib/prototypes/widgets/owinteractions_new.py index ae3da4ef..1d2c3e6d 100644 --- a/orangecontrib/prototypes/widgets/owinteractions_new.py +++ b/orangecontrib/prototypes/widgets/owinteractions_new.py @@ -337,6 +337,10 @@ def toggle(self): self.progressBarFinished() self.filter.setEnabled(True) + def _select_first_if_none(self): + if not self.rank_table.selectedIndexes(): + self.rank_table.selectRow(0) + def on_selection_changed(self, selected): self.selection = [self.model.data(ind) for ind in selected.indexes()[-2:]] self.commit() @@ -366,7 +370,7 @@ def row_for_state(score, state): def iterate_states(self, initial_state): if self.feature is not None: return self._iterate_by_feature(initial_state) - if self.heuristic is not None: + if self.n_attrs > 3 and self.heuristic is not None: return self.heuristic.get_states(initial_state) return self._iterate_all(initial_state) @@ -400,6 +404,7 @@ def on_done(self, result): self.button.setText("Finished") self.button.setEnabled(False) self.filter.setEnabled(True) + self._select_first_if_none() def send_report(self): self.report_table("Interactions", self.rank_table) diff --git a/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py b/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py index 4cc212ab..95c66db4 100644 --- a/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py +++ b/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py @@ -1,5 +1,10 @@ -from Orange.data import Table +from unittest.mock import Mock + +import numpy as np + +from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable from Orange.widgets.tests.base import WidgetTest +from Orange.widgets.widget import AttributeList from orangecontrib.prototypes.widgets.owinteractions_new import OWInteractions @@ -15,7 +20,7 @@ def setUp(self): self.widget = self.create_widget(OWInteractions) def test_input_data(self): - """Check interaction table""" + """Check table on input data""" self.send_signal(self.widget.Inputs.data, None) self.assertEqual(self.widget.model.columnCount(), 0) self.assertEqual(self.widget.model.rowCount(), 0) @@ -24,3 +29,109 @@ def test_input_data(self): self.process_events() self.assertEqual(self.widget.model.columnCount(), 4) self.assertEqual(self.widget.model.rowCount(), 6) + + def test_input_data_one_feature(self): + """Check table on input data with single attribute""" + self.send_signal(self.widget.Inputs.data, self.iris[:, [0, 4]]) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertTrue(self.widget.Warning.not_enough_vars.is_shown()) + self.send_signal(self.widget.Inputs.data, None) + self.assertFalse(self.widget.Warning.not_enough_vars.is_shown()) + + def test_input_data_no_target(self): + """Check table on input data without target""" + self.send_signal(self.widget.Inputs.data, self.iris[:, :-1]) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertTrue(self.widget.Warning.no_class_var.is_shown()) + + def test_input_data_one_instance(self): + """Check table on input data with single instance""" + self.send_signal(self.widget.Inputs.data, self.iris[:1]) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) + self.assertTrue(self.widget.Warning.not_enough_inst.is_shown()) + self.send_signal(self.widget.Inputs.data, None) + self.assertFalse(self.widget.Warning.not_enough_inst.is_shown()) + + def test_input_data_constant_features(self): + """Check table on input data with constant columns""" + x = np.array([[0, 2, 1], + [0, 2, 0], + [0, 0, 1], + [0, 1, 2]]) + y = np.array([1, 2, 1, 0]) + labels = ["a", "b", "c"] + domain_disc = Domain([DiscreteVariable(str(i), labels) for i in range(3)], + DiscreteVariable("cls", labels)) + domain_cont = Domain([ContinuousVariable(str(i)) for i in range(3)], + DiscreteVariable("cls", labels)) + + self.send_signal(self.widget.Inputs.data, Table(domain_disc, x, y)) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 3) + self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) + + self.send_signal(self.widget.Inputs.data, Table(domain_cont, x, y)) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 1) + self.assertTrue(self.widget.Information.removed_cons_feat.is_shown()) + + x = np.ones((4, 3), dtype=float) + self.send_signal(self.widget.Inputs.data, Table(domain_cont, x, y)) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertTrue(self.widget.Warning.not_enough_vars.is_shown()) + self.assertTrue(self.widget.Information.removed_cons_feat.is_shown()) + + self.send_signal(self.widget.Inputs.data, None) + self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) + + def test_output_features(self): + """Check output""" + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + features = self.get_output(self.widget.Outputs.features) + self.assertIsInstance(features, AttributeList) + self.assertEqual(len(features), 2) + + def test_input_changed(self): + """Check commit on input""" + self.widget.commit = Mock() + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + self.widget.commit.assert_called_once() + + x = np.array([[0, 1, 0], + [1, 1, 2], + [0, 2, 0], + [0, 0, 2]]) + y = np.array([1, 2, 2, 0]) + domain = Domain([DiscreteVariable(str(i), ["a", "b", "c"]) for i in range(3)], + DiscreteVariable("cls")) + + self.widget.commit.reset_mock() + self.send_signal(self.widget.Inputs.data, Table(domain, x, y)) + self.wait_until_finished() + self.process_events() + self.widget.commit.assert_called_once() + + def test_feature_combo(self): + """Check feature combobox""" + feature_combo = self.widget.controls.feature + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertEqual(len(feature_combo.model()), 5) + + self.wait_until_stop_blocking() + self.send_signal(self.widget.Inputs.data, self.zoo) + self.assertEqual(len(feature_combo.model()), 17) From 67032361696792ff7adecf1ede258aecc53de7ca Mon Sep 17 00:00:00 2001 From: noahnovsak Date: Thu, 6 Oct 2022 10:52:08 +0200 Subject: [PATCH 3/6] fix filtering --- orangecontrib/prototypes/ranktablemodel.py | 253 ++++-------------- .../prototypes/widgets/owinteractions_new.py | 38 +-- 2 files changed, 73 insertions(+), 218 deletions(-) diff --git a/orangecontrib/prototypes/ranktablemodel.py b/orangecontrib/prototypes/ranktablemodel.py index 365ee6bd..ffbb3984 100644 --- a/orangecontrib/prototypes/ranktablemodel.py +++ b/orangecontrib/prototypes/ranktablemodel.py @@ -1,161 +1,63 @@ -from numbers import Number, Integral -from typing import Iterable, Union import numpy as np -from AnyQt.QtCore import QModelIndex, Qt, QAbstractTableModel +from AnyQt.QtCore import QModelIndex, Qt -from Orange.data import Variable from Orange.data.domain import Domain - -from Orange.widgets import gui -from Orange.widgets.utils.itemmodels import DomainModel +from Orange.widgets.utils.itemmodels import DomainModel, PyTableModel MAX_ROWS = int(1e9) # limits how many rows model will display -def _argsort(data: np.ndarray, order: Qt.SortOrder): - # same as ``_argsortData`` in AbstractSortModel, might combine? - if data.ndim == 1: - indices = np.argsort(data, kind="mergesort") - else: - indices = np.lexsort(data.T[::-1]) - if order == Qt.DescendingOrder: - indices = indices[::-1] - return indices - - -class ArrayTableModel(QAbstractTableModel): +class ArrayTableModel(PyTableModel): """ - A proxy table model that stores and sorts its data with `numpy`, - thus providing higher speeds and better scaling. + A model for displaying 2-dimensional numpy arrays in ``QTableView`` objects. + + This model extends ``PyTableModel`` to gain access to the following methods: + ``_roleData``, ``flags``, ``setData``, ``data``, ``setHorizontalHeaderLabels``, + ``setVerticalHeaderLabels``, and ``headerData``. + Other, unlisted methods aren't guaranteed to work and should be used with care. - TODO: Could extend ``AbstractSortTableModel`` or ``PyTableModel``? + Also requires access to private members of ``AbstractSortTableModel`` directly; + ``__sortInd`` is used for sorting and potential filtering, and ``__init__`` because the + parent implementation immediately wraps a list, which this model does not use. """ def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + super(PyTableModel, self).__init__(*args, **kwargs) - self.__sortInd = ... # type: np.ndarray - self.__sortColumn = -1 - self.__sortOrder = Qt.AscendingOrder + self._headers = {} + self._roleData = {} + self._editable = kwargs.get("editable") self._data = None # type: np.ndarray self._columns = 0 self._rows = 0 self._max_display_rows = self._max_data_rows = MAX_ROWS - self._headers = {} - - def columnData(self, column: Union[int, slice], apply_sort=False): - if apply_sort: - return self._data[:self._rows, column][self.__sortInd] - return self._data[:self._rows, column] - - def sortColumn(self): - return self.__sortColumn - - def sortOrder(self): - return self.__sortOrder - - def mapToSourceRows(self, rows: Union[int, slice, list, np.ndarray]): - if isinstance(self.__sortInd, np.ndarray) \ - and (isinstance(rows, (Integral, type(Ellipsis))) - or len(rows)): - rows = self.__sortInd[rows] - return rows - - def resetSorting(self): - self.sort(-1) - - def sort(self, column: int, order: Qt.SortOrder = Qt.AscendingOrder): - if self._data is None: - return - - indices = self._sort(column, order) - self.__sortColumn = column - self.__sortOrder = order - - self.setSortIndices(indices) - def setSortIndices(self, indices: np.ndarray): - self.layoutAboutToBeChanged.emit([], QAbstractTableModel.VerticalSortHint) - self.__sortInd = indices - self.layoutChanged.emit([], QAbstractTableModel.VerticalSortHint) + def sortInd(self): + return self._AbstractSortTableModel__sortInd - def _sort(self, column: int, order: Qt.SortOrder): - if column < 0: - return ... - - data = self.columnData(column) - return _argsort(data, order) + def sortColumnData(self, column): + return self._data[:self._rows, column] def extendSortFrom(self, sorted_rows: int): - data = self.columnData(self.__sortColumn) - ind = np.arange(sorted_rows, self._rows) - order = 1 if self.__sortOrder == Qt.AscendingOrder else -1 + data = self.sortColumnData(self.sortColumn()) + new_ind = np.arange(sorted_rows, self._rows) + order = 1 if self.sortOrder() == Qt.AscendingOrder else -1 + sorter = self.sortInd()[::order] + new_sorter = np.argsort(data[sorted_rows:]) loc = np.searchsorted(data[:sorted_rows], - data[sorted_rows:self._rows], - sorter=self.__sortInd[::order]) - indices = np.insert(self.__sortInd[::order], loc, ind)[::order] + data[sorted_rows:][new_sorter], + sorter=sorter) + indices = np.insert(sorter, loc, new_ind[new_sorter])[::order] self.setSortIndices(indices) - def rowCount(self, parent=QModelIndex(), *args, **kwargs): + def rowCount(self, parent=QModelIndex()): return 0 if parent.isValid() else min(self._rows, self._max_display_rows) - def columnCount(self, parent=QModelIndex(), *args, **kwargs): + def columnCount(self, parent=QModelIndex()): return 0 if parent.isValid() else self._columns - def data(self, index: QModelIndex, role=Qt.DisplayRole): - if not index.isValid(): - return - - row, column = self.mapToSourceRows(index.row()), index.column() - - try: - value = self._data[row, column] - except IndexError: - return - match role: - case Qt.EditRole: - return value - case Qt.DisplayRole: - if isinstance(value, Number) and not \ - (np.isnan(value) or np.isinf(value) or - isinstance(value, Integral)): - absval = abs(value) - strlen = len(str(int(absval))) - value = '{:.{}{}}'.format(value, - 2 if absval < .001 else - 3 if strlen < 2 else - 1 if strlen < 5 else - 0 if strlen < 6 else - 3, - 'f' if (absval == 0 or - absval >= .001 and - strlen < 6) - else 'e') - return str(value) - case Qt.DecorationRole if isinstance(value, Variable): - return gui.attributeIconDict[value] - case Qt.ToolTipRole: - return str(value) - - def setHorizontalHeaderLabels(self, labels: Iterable[str]): - self._headers[Qt.Horizontal] = tuple(labels) - - def setVertcalHeaderLabels(self, labels: Iterable[str]): - self._headers[Qt.Vertical] = tuple(labels) - - def headerData(self, section: int, orientation: Qt.Orientation, role=Qt.DisplayRole): - headers = self._headers.get(orientation) - - if headers and section < len(headers): - if orientation == Qt.Vertical: - section = self.mapToSourceRows(section) - if role in {Qt.DisplayRole, Qt.ToolTipRole}: - return headers[section] - - return super().headerData(section, orientation, role) - def __len__(self): return self._rows @@ -167,8 +69,9 @@ def __getitem__(self, item): def initialize(self, data: list[list[float]]): self.beginResetModel() - self._data = np.array(data) + self._data = np.asarray(data) self._rows, self._columns = self._data.shape + self._roleData = self._RoleData() self.resetSorting() self.endResetModel() @@ -176,6 +79,7 @@ def clear(self): self.beginResetModel() self._data = None self._rows = self._columns = 0 + self._roleData.clear() self.resetSorting() self.endResetModel() @@ -185,7 +89,6 @@ def append(self, rows: list[list[float]]): n_rows = len(rows) if n_rows == 0: - print("nothing to add") return n_data = len(self._data) insert = self._rows < self._max_display_rows @@ -202,105 +105,49 @@ def append(self, rows: list[list[float]]): self._data[self._rows:self._rows + n_rows] = rows self._rows += n_rows - if self.__sortColumn >= 0: - old_rows = self._rows - n_rows - self.extendSortFrom(old_rows) - if insert: self.endInsertRows() + if self.sortColumn() >= 0: + old_rows = self._rows - n_rows + self.extendSortFrom(old_rows) + class RankModel(ArrayTableModel): """ - Extends ``ArrayTableModel`` with filtering and other specific - features for ``VizRankDialog`` type widgets, to display scores for - combinations of attributes. + Extends ``ArrayTableModel`` for ``VizRankDialog`` type widgets, + to display scores for combinations of attributes. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.__filterInd = ... # type: np.ndarray - self.__filterStr = "" - self.domain = None # type: Domain self.domain_model = DomainModel(DomainModel.ATTRIBUTES) - def set_domain(self, domain: Domain, **kwargs): - self.__dict__.update(kwargs) + def set_domain(self, domain: Domain): self.domain = domain self.domain_model.set_domain(domain) n_attrs = len(domain.attributes) self._max_data_rows = n_attrs * (n_attrs - 1) // 2 - def mapToSourceRows(self, rows): - if isinstance(self.__filterInd, np.ndarray) \ - and (isinstance(rows, (Integral, type(Ellipsis))) - or len(rows)): - rows = self.__filterInd[rows] - return super().mapToSourceRows(rows) - - def resetFiltering(self): - self.filter("") - - def filter(self, text: str): + def resetSorting(self): if self._data is None: - return - - if not text: - self.__filterInd = indices = ... - self.__filterStr = "" - self._max_display_rows = MAX_ROWS + self.sort(-1) else: - self.__filterStr = text - indices = self._filter(text) - - self.setFilterIndices(indices) - - def setFilterIndices(self, indices: np.ndarray): - self.layoutAboutToBeChanged.emit([]) - if isinstance(indices, np.ndarray): - self.__filterInd = indices - self._max_display_rows = len(indices) - self.layoutChanged.emit([]) - - def setSortIndices(self, indices: np.ndarray): - super().setSortIndices(indices) - - # sorting messes up the filter indices, so they - # must also be updated - self.layoutAboutToBeChanged.emit([]) - if isinstance(self.__filterInd, np.ndarray): - filter_indices = self._filter(self.__filterStr) - self.__filterInd = filter_indices - self._max_display_rows = len(filter_indices) - self.layoutChanged.emit([]) - - def _filter(self, text: str): - attr = [i for i, attr in enumerate(self.domain.attributes) - if str(text).lower() in attr.name.lower()] - - attr_data = self.columnData(slice(-2, None), apply_sort=True) - valid = np.isin(attr_data, attr).any(axis=1) - - return valid.nonzero()[0] - - def append(self, rows): - super().append(rows) - - if isinstance(self.__filterInd, np.ndarray): - self.resetFiltering() + self.sort(0, Qt.DescendingOrder) def data(self, index: QModelIndex, role=Qt.DisplayRole): if not index.isValid(): return - row, column = self.mapToSourceRows(index.row()), index.column() - try: - value = self._data[row, column] - except IndexError: - return + column = index.column() if column >= self.columnCount() - 2 and role != Qt.EditRole: - return self.domain_model.data(self.domain_model.index(int(value)), role) + try: + row = self.mapToSourceRows(index.row()) + value = self.domain_model.index(int(self._data[row, column])) + return self.domain_model.data(value, role) + except IndexError: + return return super().data(index, role) diff --git a/orangecontrib/prototypes/widgets/owinteractions_new.py b/orangecontrib/prototypes/widgets/owinteractions_new.py index 1d2c3e6d..253bda20 100644 --- a/orangecontrib/prototypes/widgets/owinteractions_new.py +++ b/orangecontrib/prototypes/widgets/owinteractions_new.py @@ -5,9 +5,9 @@ import numpy as np from AnyQt.QtGui import QColor, QPainter, QPen -from AnyQt.QtCore import QModelIndex, Qt, QLineF +from AnyQt.QtCore import QModelIndex, Qt, QLineF, QSortFilterProxyModel from AnyQt.QtWidgets import QTableView, QHeaderView, \ - QStyleOptionViewItem, QApplication, QStyle + QStyleOptionViewItem, QApplication, QStyle, QLineEdit from Orange.data import Table, Variable from Orange.preprocess import Discretize, Remove @@ -185,6 +185,13 @@ def draw_line(start, length): self.drawViewItemText(style, painter, opt, textrect) +class FilterProxy(QSortFilterProxyModel): + scorer = None + + def sort(self, *args, **kwargs): + self.sourceModel().sort(*args, **kwargs) + + class OWInteractions(OWWidget, ConcurrentWidgetMixin): name = "Interactions New" description = "Compute all pairwise attribute interactions." @@ -199,8 +206,6 @@ class Outputs: settingsHandler = DomainContextHandler() selection = ContextSetting([]) - filter_text: str - filter_text = ContextSetting("") feature: Variable feature = ContextSetting(None) heuristic_type: int @@ -244,15 +249,17 @@ def __init__(self): callback=self.on_feature_combo_changed, model=self.feature_model, searchable=True) - self.filter = gui.lineEdit(self.controlArea, self, "filter_text", - callback=self.on_filter_changed, - callbackOnType=True) + self.filter = QLineEdit() self.filter.setPlaceholderText("Filter ...") + self.filter.textChanged.connect(self.on_filter_changed) + self.controlArea.layout().addWidget(self.filter) self.model = RankModel() - self.model.setHorizontalHeaderLabels(( - "Interaction", "Information Gain", "Feature 1", "Feature 2" - )) + self.model.setHorizontalHeaderLabels(["Interaction", "Information Gain", + "Feature 1", "Feature 2"]) + self.proxy = FilterProxy(filterCaseSensitivity=Qt.CaseInsensitive) + self.proxy.setSourceModel(self.model) + self.proxy.setFilterKeyColumn(-1) self.rank_table = view = QTableView(selectionBehavior=QTableView.SelectRows, selectionMode=QTableView.SingleSelection, showGrid=False, @@ -260,7 +267,7 @@ def __init__(self): view.setSortingEnabled(True) view.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch) view.setItemDelegate(InteractionItemDelegate()) - view.setModel(self.model) + view.setModel(self.proxy) view.selectionModel().selectionChanged.connect(self.on_selection_changed) self.controlArea.layout().addWidget(view) @@ -292,7 +299,8 @@ def set_data(self, data): self.n_attrs = len(pp_data.domain.attributes) self.scorer = InteractionScorer(pp_data) self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_type) - self.model.set_domain(pp_data.domain, scorer=self.scorer) + self.model.set_domain(pp_data.domain) + self.proxy.scorer = self.scorer self.feature_model.set_domain(self.pp_data and self.pp_data.domain) self.openContext(self.pp_data) self.initialize() @@ -308,6 +316,7 @@ def initialize(self): self.model.clear() self.filter.setText("") self.button.setText("Start") + self.button.setEnabled(self.pp_data is not None) if self.pp_data is not None: self.toggle() @@ -325,7 +334,6 @@ def toggle(self): self.button.setText("Pause") self.button.repaint() self.progressBarInit() - self.filter.setText("") self.filter.setEnabled(False) self.start(run, self.compute_score, self.row_for_state, self.iterate_states, self.saved_state, @@ -345,8 +353,8 @@ def on_selection_changed(self, selected): self.selection = [self.model.data(ind) for ind in selected.indexes()[-2:]] self.commit() - def on_filter_changed(self): - self.model.filter(self.filter_text) + def on_filter_changed(self, text): + self.proxy.setFilterFixedString(text) def on_feature_combo_changed(self): self.feature_index = self.feature and self.pp_data.domain.index(self.feature) From 4355650dc6dc80a517d442591c7252f6cfa84295 Mon Sep 17 00:00:00 2001 From: noahnovsak Date: Thu, 6 Oct 2022 15:26:28 +0200 Subject: [PATCH 4/6] fix selection and more tests --- orangecontrib/prototypes/ranktablemodel.py | 19 ++- .../prototypes/widgets/owinteractions_new.py | 36 +++-- .../widgets/tests/test_owinteractions_new.py | 146 +++++++++++++++++- 3 files changed, 184 insertions(+), 17 deletions(-) diff --git a/orangecontrib/prototypes/ranktablemodel.py b/orangecontrib/prototypes/ranktablemodel.py index ffbb3984..3e910918 100644 --- a/orangecontrib/prototypes/ranktablemodel.py +++ b/orangecontrib/prototypes/ranktablemodel.py @@ -19,8 +19,9 @@ class ArrayTableModel(PyTableModel): Other, unlisted methods aren't guaranteed to work and should be used with care. Also requires access to private members of ``AbstractSortTableModel`` directly; - ``__sortInd`` is used for sorting and potential filtering, and ``__init__`` because the - parent implementation immediately wraps a list, which this model does not use. + ``__sortInd`` is needed to append new unsorted data, and ``__init__`` is used + because the parent implementation immediately wraps a list, which this model + does not have. """ def __init__(self, *args, **kwargs): super(PyTableModel, self).__init__(*args, **kwargs) @@ -31,8 +32,12 @@ def __init__(self, *args, **kwargs): self._data = None # type: np.ndarray self._columns = 0 - self._rows = 0 - self._max_display_rows = self._max_data_rows = MAX_ROWS + self._rows = 0 # current number of rows containing data + self._max_view_rows = MAX_ROWS # maximum number of rows the model/view will display + self._max_data_rows = MAX_ROWS # maximum allowed size for the `_data` array + # ``__len__`` returns _rows: amount of existing data in the model + # ``rowCount`` returns the lowest of `_rows` and `_max_view_rows`: + # how large the model/view thinks it is def sortInd(self): return self._AbstractSortTableModel__sortInd @@ -53,7 +58,7 @@ def extendSortFrom(self, sorted_rows: int): self.setSortIndices(indices) def rowCount(self, parent=QModelIndex()): - return 0 if parent.isValid() else min(self._rows, self._max_display_rows) + return 0 if parent.isValid() else min(self._rows, self._max_view_rows) def columnCount(self, parent=QModelIndex()): return 0 if parent.isValid() else self._columns @@ -91,10 +96,10 @@ def append(self, rows: list[list[float]]): if n_rows == 0: return n_data = len(self._data) - insert = self._rows < self._max_display_rows + insert = self._rows < self._max_view_rows if insert: - self.beginInsertRows(QModelIndex(), self._rows, min(self._max_display_rows, self._rows + n_rows) - 1) + self.beginInsertRows(QModelIndex(), self._rows, min(self._max_view_rows, self._rows + n_rows) - 1) if self._rows + n_rows >= n_data: n_data = min(max(n_data + n_rows, 2 * n_data), self._max_data_rows) diff --git a/orangecontrib/prototypes/widgets/owinteractions_new.py b/orangecontrib/prototypes/widgets/owinteractions_new.py index 253bda20..9729a7d3 100644 --- a/orangecontrib/prototypes/widgets/owinteractions_new.py +++ b/orangecontrib/prototypes/widgets/owinteractions_new.py @@ -278,7 +278,7 @@ def __init__(self): def set_data(self, data): self.closeContext() self.clear_messages() - self.selection = [] + self.selection = {} self.data = data self.pp_data = None self.n_attrs = 0 @@ -333,24 +333,40 @@ def toggle(self): if not self.keep_running: self.button.setText("Pause") self.button.repaint() - self.progressBarInit() self.filter.setEnabled(False) + self.progressBarInit() self.start(run, self.compute_score, self.row_for_state, self.iterate_states, self.saved_state, self.progress, self.state_count()) else: self.button.setText("Continue") self.button.repaint() - self.cancel() - self.progressBarFinished() self.filter.setEnabled(True) + self.cancel() + self._stopped() + + def _stopped(self): + self.progressBarFinished() + self._select_default() + + def _select_default(self): + n_rows = self.model.rowCount() + if not n_rows: + return + + if self.selection: + for i in range(n_rows): + names = {self.model.data(self.model.index(i, 2)), + self.model.data(self.model.index(i, 3))} + if names == self.selection: + self.rank_table.selectRow(i) + break - def _select_first_if_none(self): if not self.rank_table.selectedIndexes(): self.rank_table.selectRow(0) def on_selection_changed(self, selected): - self.selection = [self.model.data(ind) for ind in selected.indexes()[-2:]] + self.selection = {self.model.data(ind) for ind in selected.indexes()[-2:]} self.commit() def on_filter_changed(self, text): @@ -396,9 +412,9 @@ def _iterate_by_feature(self, initial_state): yield self.feature_index, j def state_count(self): - if self.feature is None: + if self.feature_index is None: return self.n_attrs * (self.n_attrs - 1) // 2 - return self.n_attrs + return self.n_attrs - 1 def on_partial_result(self, result): add_to_model, latest_state = result @@ -412,7 +428,9 @@ def on_done(self, result): self.button.setText("Finished") self.button.setEnabled(False) self.filter.setEnabled(True) - self._select_first_if_none() + self.keep_running = False + self.saved_state = None + self._stopped() def send_report(self): self.report_table("Interactions", self.rank_table) diff --git a/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py b/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py index 95c66db4..ef98a82c 100644 --- a/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py +++ b/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py @@ -1,12 +1,18 @@ +import unittest from unittest.mock import Mock import numpy as np +import numpy.testing as npt + +from AnyQt.QtCore import QItemSelection from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable from Orange.widgets.tests.base import WidgetTest +from Orange.widgets.tests.utils import simulate from Orange.widgets.widget import AttributeList -from orangecontrib.prototypes.widgets.owinteractions_new import OWInteractions +from orangecontrib.prototypes.widgets.owinteractions_new import OWInteractions, Heuristic +from orangecontrib.prototypes.interactions import InteractionScorer class TestOWInteractions(WidgetTest): @@ -126,6 +132,24 @@ def test_input_changed(self): self.process_events() self.widget.commit.assert_called_once() + def test_saved_selection(self): + """Check row selection""" + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + selection = QItemSelection() + selection.select(self.widget.model.index(2, 0), + self.widget.model.index(2, 3)) + self.widget.on_selection_changed(selection) + settings = self.widget.settingsHandler.pack_data(self.widget) + + w = self.create_widget(OWInteractions, stored_settings=settings) + self.send_signal(self.widget.Inputs.data, self.iris, widget=w) + self.wait_until_finished(w) + self.process_events() + sel_row = w.rank_table.selectionModel().selectedRows()[0].row() + self.assertEqual(sel_row, 2) + def test_feature_combo(self): """Check feature combobox""" feature_combo = self.widget.controls.feature @@ -135,3 +159,123 @@ def test_feature_combo(self): self.wait_until_stop_blocking() self.send_signal(self.widget.Inputs.data, self.zoo) self.assertEqual(len(feature_combo.model()), 17) + + def test_select_feature(self): + """Check feature selection""" + feature_combo = self.widget.controls.feature + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 6) + self.assertSetEqual( + {a.name for a in self.get_output(self.widget.Outputs.features)}, + {"sepal width", "sepal length"} + ) + + simulate.combobox_activate_index(feature_combo, 3) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 3) + self.assertSetEqual( + {a.name for a in self.get_output(self.widget.Outputs.features)}, + {"petal length", "sepal width"} + ) + + simulate.combobox_activate_index(feature_combo, 0) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 6) + self.assertSetEqual( + {a.name for a in self.get_output(self.widget.Outputs.features)}, + {"petal length", "sepal width"} + ) + + def test_send_report(self): + """Check report""" + self.send_signal(self.widget.Inputs.data, self.iris) + self.widget.report_button.click() + self.wait_until_stop_blocking() + self.send_signal(self.widget.Inputs.data, None) + self.widget.report_button.click() + + def test_compute_score(self): + self.widget.scorer = InteractionScorer(self.zoo) + npt.assert_almost_equal(self.widget.compute_score((1, 0)), + [-0.0771, 0.3003, 0.3307], 4) + + def test_row_for_state(self): + row = self.widget.row_for_state((-0.2, 0.2, 0.1), (1, 0)) + self.assertListEqual(row, [-0.2, 0.1, 1, 0]) + + def test_iterate_states(self): + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertListEqual(list(self.widget._iterate_all(None)), + [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) + self.assertListEqual(list(self.widget._iterate_all((1, 0))), + [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) + self.assertListEqual(list(self.widget._iterate_all((2, 1))), + [(2, 1), (3, 0), (3, 1), (3, 2)]) + self.widget.feature_index = 2 + self.assertListEqual(list(self.widget._iterate_by_feature(None)), + [(2, 0), (2, 1), (2, 3)]) + self.assertListEqual(list(self.widget._iterate_by_feature((2, 0))), + [(2, 0), (2, 1), (2, 3)]) + self.assertListEqual(list(self.widget._iterate_by_feature((2, 3))), + [(2, 3)]) + + def test_state_count(self): + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertEqual(self.widget.state_count(), 6) + self.widget.feature_index = 2 + self.assertEqual(self.widget.state_count(), 3) + + +class TestInteractionScorer(unittest.TestCase): + def test_compute_score(self): + """Check score calculation""" + x = np.array([[1, 1], [0, 1], [1, 1], [0, 0]]) + y = np.array([0, 1, 1, 1]) + domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) + data = Table(domain, x, y) + self.scorer = InteractionScorer(data) + npt.assert_almost_equal(self.scorer(0, 1), -0.1226, 4) + npt.assert_almost_equal(self.scorer.class_entropy, 0.8113, 4) + npt.assert_almost_equal(self.scorer.information_gain[0], 0.3113, 4) + npt.assert_almost_equal(self.scorer.information_gain[1], 0.1226, 4) + + def test_nans(self): + """Check score calculation with nans""" + x = np.array([[1, 1], [0, 1], [1, 1], [0, 0], [1, np.nan], [np.nan, 0], [np.nan, np.nan]]) + y = np.array([0, 1, 1, 1, 0, 0, 1]) + domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) + data = Table(domain, x, y) + self.scorer = InteractionScorer(data) + npt.assert_almost_equal(self.scorer(0, 1), 0.0167, 4) + npt.assert_almost_equal(self.scorer.class_entropy, 0.9852, 4) + npt.assert_almost_equal(self.scorer.information_gain[0], 0.4343, 4) + npt.assert_almost_equal(self.scorer.information_gain[1], 0.0343, 4) + + +class TestHeuristic(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.zoo = Table("zoo") + + def test_heuristic(self): + """Check attribute pairs returned by heuristic""" + scorer = InteractionScorer(self.zoo) + heuristic = Heuristic(scorer.information_gain, + type=Heuristic.INFO_GAIN) + self.assertListEqual(list(heuristic.get_states(None))[:9], + [(14, 6), (14, 10), (14, 15), (6, 10), + (14, 5), (6, 15), (14, 11), (6, 5), (10, 15)]) + + states = heuristic.get_states(None) + _ = next(states) + self.assertListEqual(list(heuristic.get_states(next(states)))[:8], + [(14, 10), (14, 15), (6, 10), (14, 5), + (6, 15), (14, 11), (6, 5), (10, 15)]) + + +if __name__ == "__main__": + unittest.main() From c4f9c21a9b0a79e5c25884d26f1ea417a45e1a53 Mon Sep 17 00:00:00 2001 From: noahnovsak Date: Fri, 14 Oct 2022 15:16:38 +0200 Subject: [PATCH 5/6] remove old widget --- .../prototypes/widgets/owinteractions.py | 733 ++++++++++-------- .../prototypes/widgets/owinteractions_new.py | 440 ----------- .../widgets/tests/test_owinteractions.py | 556 ++++++------- .../widgets/tests/test_owinteractions_new.py | 281 ------- 4 files changed, 679 insertions(+), 1331 deletions(-) delete mode 100644 orangecontrib/prototypes/widgets/owinteractions_new.py delete mode 100644 orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py diff --git a/orangecontrib/prototypes/widgets/owinteractions.py b/orangecontrib/prototypes/widgets/owinteractions.py index 3baa2927..1f26f603 100644 --- a/orangecontrib/prototypes/widgets/owinteractions.py +++ b/orangecontrib/prototypes/widgets/owinteractions.py @@ -1,331 +1,440 @@ -""" -Interactions widget -""" -from enum import IntEnum -from operator import attrgetter +import copy from itertools import chain - +from threading import Lock, Timer +from typing import Callable, Optional, Iterable import numpy as np -from AnyQt.QtCore import Qt, QSortFilterProxyModel -from AnyQt.QtCore import QLineF -from AnyQt.QtGui import QStandardItem, QPainter, QColor, QPen -from AnyQt.QtWidgets import QHeaderView -from AnyQt.QtCore import QModelIndex -from AnyQt.QtWidgets import QStyleOptionViewItem, QApplication, QStyle +from AnyQt.QtGui import QColor, QPainter, QPen +from AnyQt.QtCore import QModelIndex, Qt, QLineF, QSortFilterProxyModel +from AnyQt.QtWidgets import QTableView, QHeaderView, \ + QStyleOptionViewItem, QApplication, QStyle, QLineEdit -from Orange.data import Table, Domain, ContinuousVariable, StringVariable +from Orange.data import Table, Variable +from Orange.preprocess import Discretize, Remove from Orange.widgets import gui -from Orange.widgets.settings import Setting -from Orange.widgets.utils.itemmodels import DomainModel -from Orange.widgets.utils.signals import Input, Output -from Orange.widgets.utils.widgetpreview import WidgetPreview from Orange.widgets.widget import OWWidget, AttributeList, Msg -from Orange.widgets.visualize.utils import VizRankDialogAttrPair -from Orange.preprocess import Discretize, Remove -import Orange.widgets.data.owcorrelations +from Orange.widgets.utils.widgetpreview import WidgetPreview +from Orange.widgets.utils.signals import Input, Output +from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState +from Orange.widgets.utils.itemmodels import DomainModel +from Orange.widgets.settings import Setting, ContextSetting, DomainContextHandler +from orangecontrib.prototypes.ranktablemodel import RankModel from orangecontrib.prototypes.interactions import InteractionScorer -SIZE_LIMIT = 1000000 - - -class HeuristicType(IntEnum): - """ - Heuristic type enumerator. Possible choices: low Information gain first, random. - """ - INFOGAIN, RANDOM = 0, 1 - - @staticmethod - def items(): - """ - Text for heuristic types. Can be used in gui controls (eg. combobox). - """ - return ["InfoGain Heuristic", "Random Search"] +class ModelQueue: + """ + Another queueing object, similar to ``queue.Queue``. + The main difference is that ``get()`` returns all its + contents at the same time, instead of one by one. + """ + def __init__(self): + self.lock = Lock() + self.model = [] + self.state = None + + def put(self, row, state): + with self.lock: + self.model.append(row) + self.state = state + + def get(self): + with self.lock: + model, self.model = self.model, [] + state, self.state = self.state, None + return model, state + + +def run(compute_score: Callable, row_for_state: Callable, + iterate_states: Callable, saved_state: Optional[Iterable], + progress: int, state_count: int, task: TaskState): + """ + Replaces ``run_vizrank``, with some minor adjustments. + - ``ModelQueue`` replaces ``queue.Queue`` + - `row_for_state` parameter added + - `scores` parameter removed + """ + task.set_status("Getting combinations...") + task.set_progress_value(0.1) + states = iterate_states(saved_state) + + task.set_status("Getting scores...") + queue = ModelQueue() + can_set_partial_result = True + + def do_work(st, next_st): + try: + score = compute_score(st) + if score is not None: + queue.put(row_for_state(score, st), next_st) + except Exception: + pass + + def reset_flag(): + nonlocal can_set_partial_result + can_set_partial_result = True + + state = None + next_state = next(states) + try: + while True: + if task.is_interruption_requested(): + return queue.get() + task.set_progress_value(progress * 100 // state_count) + progress += 1 + state = copy.copy(next_state) + next_state = copy.copy(next(states)) + do_work(state, next_state) + # for simple scores (e.g. correlations widget) and many feature + # combinations, the 'partial_result_ready' signal (emitted by + # invoking 'task.set_partial_result') was emitted too frequently + # for a longer period of time and therefore causing the widget + # being unresponsive + if can_set_partial_result: + task.set_partial_result(queue.get()) + can_set_partial_result = False + Timer(0.05, reset_flag).start() + except StopIteration: + do_work(state, None) + task.set_partial_result(queue.get()) + return queue.get() class Heuristic: - def __init__(self, weights, heuristic_type=None): - self.n_attributes = len(weights) - self.attributes = np.arange(self.n_attributes) - if heuristic_type == HeuristicType.INFOGAIN: - self.attributes = self.attributes[np.argsort(weights)] - elif heuristic_type == HeuristicType.RANDOM: - np.random.shuffle(self.attributes) - - def generate_states(self): - # prioritize two mid ranked attributes over highest first - for s in range(1, self.n_attributes * (self.n_attributes - 1) // 2): - for i in range(max(s - self.n_attributes + 1, 0), (s + 1) // 2): - yield self.attributes[i], self.attributes[s - i] - - def get_states(self, initial_state): - states = self.generate_states() - if initial_state is not None: - while next(states) != initial_state: - pass - return chain([initial_state], states) - return states + RANDOM, INFO_GAIN = 0, 1 + type = {RANDOM: "Random Search", + INFO_GAIN: "Information Gain Heuristic"} + + def __init__(self, weights, type=None): + self.n_attributes = len(weights) + self.attributes = np.arange(self.n_attributes) + if type == self.RANDOM: + np.random.shuffle(self.attributes) + if type == self.INFO_GAIN: + self.attributes = self.attributes[np.argsort(weights)] + + def generate_states(self): + # prioritize two mid ranked attributes over highest first + for s in range(1, self.n_attributes * (self.n_attributes - 1) // 2): + for i in range(max(s - self.n_attributes + 1, 0), (s + 1) // 2): + yield self.attributes[i], self.attributes[s - i] + + def get_states(self, initial_state): + states = self.generate_states() + if initial_state is not None: + while next(states) != initial_state: + pass + return chain([initial_state], states) + return states class InteractionItemDelegate(gui.TableBarItem): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.r = QColor(255, 170, 127) - self.g = QColor(170, 242, 43) - self.b = QColor(70, 190, 250) - self.__line = QLineF() - self.__pen = QPen(self.b, 5, Qt.SolidLine, Qt.RoundCap) - - def paint( - self, painter: QPainter, option: QStyleOptionViewItem, - index: QModelIndex - ) -> None: - opt = QStyleOptionViewItem(option) - self.initStyleOption(opt, index) - widget = option.widget - style = QApplication.style() if widget is None else widget.style() - pen = self.__pen - line = self.__line - self.__style = style - text = opt.text - opt.text = "" - style.drawControl(QStyle.CE_ItemViewItem, opt, painter, widget) - textrect = style.subElementRect( - QStyle.SE_ItemViewItemText, opt, widget) - - # interaction is None for attribute items -> - # only draw bars for first column - interaction = self.cachedData(index, InteractionRank.IntRole) - if interaction is not None: - rect = option.rect - pw = self.penWidth - textoffset = pw + 2 - baseline = rect.bottom() - textoffset / 2 - origin = rect.left() + 3 + pw / 2 # + half pen width for the round line cap - width = rect.width() - 3 - pw - - def draw_line(start, length): - line.setLine(origin + start, baseline, origin + start + length, baseline) - painter.drawLine(line) - - # negative information gains stem from issues in interaction calculation - # may cause bars reaching out of intended area - l_bar, r_bar = self.cachedData(index, InteractionRank.GainRole) - l_bar, r_bar = width * max(l_bar, 0), width * max(r_bar, 0) - interaction *= width - - pen.setColor(self.b) - pen.setWidth(pw) - painter.save() - painter.setRenderHint(QPainter.Antialiasing) - painter.setPen(pen) - draw_line(0, l_bar) - draw_line(l_bar + interaction, r_bar) - pen.setColor(self.g if interaction >= 0 else self.r) - painter.setPen(pen) - draw_line(l_bar, interaction) - painter.restore() - textrect.adjust(0, 0, 0, -textoffset) - - opt.text = text - self.drawViewItemText(style, painter, opt, textrect) - - -class SortProxyModel(QSortFilterProxyModel): - def lessThan(self, left, right): - role = self.sortRole() - l_score = left.data(role) - r_score = right.data(role) - if l_score[-1] == "%": - l_score, r_score = float(l_score[:-1]), float(r_score[:-1]) - return l_score < r_score - - -class InteractionRank(Orange.widgets.data.owcorrelations.CorrelationRank): - IntRole = next(gui.OrangeUserRole) - GainRole = next(gui.OrangeUserRole) - - def __init__(self, *args): - VizRankDialogAttrPair.__init__(self, *args) - self.scorer = None - self.heuristic = None - self.use_heuristic = False - self.sel_feature_index = None - - self.model_proxy = SortProxyModel(self) - self.model_proxy.setSourceModel(self.rank_model) - self.rank_table.setModel(self.model_proxy) - self.rank_table.selectionModel().selectionChanged.connect(self.on_selection_changed) - self.rank_table.setItemDelegate(InteractionItemDelegate()) - self.rank_table.setSortingEnabled(True) - self.rank_table.sortByColumn(0, Qt.DescendingOrder) - self.rank_table.horizontalHeader().setStretchLastSection(False) - self.rank_table.horizontalHeader().show() - - def initialize(self): - VizRankDialogAttrPair.initialize(self) - data = self.master.disc_data - self.attrs = data and data.domain.attributes - self.model_proxy.setFilterKeyColumn(-1) - self.rank_table.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch) - self.rank_model.setHorizontalHeaderLabels(["Interaction", "Info Gain", "Feature 1", "Feature 2"]) - self.heuristic = None - self.use_heuristic = False - self.sel_feature_index = self.master.feature and data.domain.index(self.master.feature) - if data: - if self.scorer is None or self.scorer.data != data: - self.scorer = InteractionScorer(data) - self.use_heuristic = len(data) * len(self.attrs) ** 2 > SIZE_LIMIT - if self.use_heuristic and not self.sel_feature_index: - self.heuristic = Heuristic(self.scorer.information_gain, self.master.heuristic_type) - - def compute_score(self, state): - scores = (self.scorer(*state), - self.scorer.information_gain[state[0]], - self.scorer.information_gain[state[1]]) - return tuple(self.scorer.normalize(score) for score in scores) - - def row_for_state(self, score, state): - attrs = sorted((self.attrs[x] for x in state), key=attrgetter("name")) - attr_items = [] - for attr in attrs: - item = QStandardItem(attr.name) - item.setToolTip(attr.name) - attr_items.append(item) - score_items = [ - QStandardItem("{:+.1f}%".format(100 * score[0])), - QStandardItem("{:.1f}%".format(100 * sum(score))) - ] - score_items[0].setData(score[0], self.IntRole) - # arrange bars to match columns - gains = [x[1] for x in sorted(enumerate(score[1:]), key=lambda x: self.attrs[state[x[0]]].name)] - score_items[0].setData(gains, self.GainRole) - score_items[0].setToolTip("{}: {:+.1f}%\n{}: {:+.1f}%".format(attrs[0], 100*gains[0], attrs[1], 100*gains[1])) - for item in score_items + attr_items: - item.setData(attrs, self._AttrRole) - item.setData(Qt.AlignLeft + Qt.AlignCenter, Qt.TextAlignmentRole) - return score_items + attr_items - - def check_preconditions(self): - return self.master.disc_data is not None - - -class OWInteractions(Orange.widgets.data.owcorrelations.OWCorrelations): - # todo: make parent class for OWInteractions and OWCorrelations - name = "Interactions" - description = "Compute all pairwise attribute interactions." - category = None - icon = "icons/Interactions.svg" - - class Inputs: - data = Input("Data", Table) - - class Outputs: - features = Output("Features", AttributeList) - interactions = Output("Interactions", Table) - - # feature and selection set by parent - heuristic_type: int - heuristic_type = Setting(0) - - class Warning(OWWidget.Warning): - not_enough_vars = Msg("At least two features are needed.") - not_enough_inst = Msg("At least two instances are needed.") - no_class_var = Msg("Target feature missing") - - def __init__(self): - OWWidget.__init__(self) - self.data = None # type: Table - self.disc_data = None # type: Table - - # GUI - box = gui.vBox(self.controlArea) - self.heuristic_combo = gui.comboBox( - box, self, "heuristic_type", items=HeuristicType.items(), - orientation=Qt.Horizontal, callback=self._heuristic_combo_changed - ) - - self.feature_model = DomainModel( - order=DomainModel.ATTRIBUTES, separators=False, - placeholder="(All combinations)") - gui.comboBox( - box, self, "feature", callback=self._feature_combo_changed, - model=self.feature_model, searchable=True - ) - - self.vizrank, _ = InteractionRank.add_vizrank( - None, self, None, self._vizrank_selection_changed) - self.vizrank.button.setEnabled(False) - self.vizrank.threadStopped.connect(self._vizrank_stopped) - - box.layout().addWidget(self.vizrank.filter) - box.layout().addWidget(self.vizrank.rank_table) - box.layout().addWidget(self.vizrank.button) - - def _heuristic_combo_changed(self): - self.apply() - - @Inputs.data - def set_data(self, data): - self.closeContext() - self.clear_messages() - self.data = data - self.disc_data = None - self.selection = [] - if data is not None: - if len(data) < 2: - self.Warning.not_enough_inst() - elif data.Y.size == 0: - self.Warning.no_class_var() - else: - remover = Remove(Remove.RemoveConstant) - data = remover(data) - disc_data = Discretize()(data) - if remover.attr_results["removed"]: - self.Information.removed_cons_feat() - if len(disc_data.domain.attributes) < 2: - self.Warning.not_enough_vars() - else: - self.disc_data = disc_data - self.feature_model.set_domain(self.disc_data and self.disc_data.domain) - self.openContext(self.disc_data) - self.apply() - self.vizrank.button.setEnabled(self.disc_data is not None) - - def apply(self): - self.vizrank.initialize() - if self.disc_data is not None: - # this triggers self.commit() by changing vizrank selection - self.vizrank.toggle() - else: - self.commit() - - def commit(self): - if self.data is None or self.disc_data is None: - self.Outputs.features.send(None) - self.Outputs.interactions.send(None) - return - - attrs = [ContinuousVariable("Interaction")] - metas = [StringVariable("Feature 1"), StringVariable("Feature 2")] - domain = Domain(attrs, metas=metas) - model = self.vizrank.rank_model - x = np.array([ - [float(model.data(model.index(row, 0), InteractionRank.IntRole))] - for row in range(model.rowCount())]) - m = np.array( - [[a.name for a in model.data(model.index(row, 0), InteractionRank._AttrRole)] - for row in range(model.rowCount())], dtype=object) - int_table = Table(domain, x, metas=m) - int_table.name = "Interactions" - - # data has been imputed; send original attributes - self.Outputs.features.send(AttributeList( - [self.data.domain[var.name] for var in self.selection])) - self.Outputs.interactions.send(int_table) + def paint(self, painter: QPainter, option: QStyleOptionViewItem, + index: QModelIndex) -> None: + opt = QStyleOptionViewItem(option) + self.initStyleOption(opt, index) + widget = option.widget + style = QApplication.style() if widget is None else widget.style() + pen = QPen(QColor("#46befa"), 5, Qt.SolidLine, Qt.RoundCap) + line = QLineF() + self.__style = style + text = opt.text + opt.text = "" + style.drawControl(QStyle.CE_ItemViewItem, opt, painter, widget) + textrect = style.subElementRect( + QStyle.SE_ItemViewItemText, opt, widget) + + interaction = self.cachedData(index, Qt.EditRole) + # only draw bars for first column + if index.column() == 0 and interaction is not None: + rect = option.rect + pw = self.penWidth + textoffset = pw + 2 + baseline = rect.bottom() - textoffset / 2 + origin = rect.left() + 3 + pw / 2 # + half pen width for the round line cap + width = rect.width() - 3 - pw + + def draw_line(start, length): + line.setLine(origin + start, baseline, origin + start + length, baseline) + painter.drawLine(line) + + scorer = index.model().scorer + attr1 = self.cachedData(index.siblingAtColumn(2), Qt.EditRole) + attr2 = self.cachedData(index.siblingAtColumn(3), Qt.EditRole) + l_bar = scorer.normalize(scorer.information_gain[int(attr1)]) + r_bar = scorer.normalize(scorer.information_gain[int(attr2)]) + # negative information gains stem from issues in interaction + # calculation and may cause bars reaching out of intended area + l_bar, r_bar = width * max(l_bar, 0), width * max(r_bar, 0) + interaction *= width + + pen.setWidth(pw) + painter.save() + painter.setRenderHint(QPainter.Antialiasing) + painter.setPen(pen) + draw_line(0, l_bar) + draw_line(l_bar + interaction, r_bar) + pen.setColor(QColor("#aaf22b") if interaction >= 0 else QColor("#ffaa7f")) + painter.setPen(pen) + draw_line(l_bar, interaction) + painter.restore() + textrect.adjust(0, 0, 0, -textoffset) + + opt.text = text + self.drawViewItemText(style, painter, opt, textrect) + + +class FilterProxy(QSortFilterProxyModel): + scorer = None + + def sort(self, *args, **kwargs): + self.sourceModel().sort(*args, **kwargs) + + +class OWInteractions(OWWidget, ConcurrentWidgetMixin): + name = "Interactions" + description = "Compute all pairwise attribute interactions." + icon = "icons/Interactions.svg" + category = "Unsupervised" + + class Inputs: + data = Input("Data", Table) + + class Outputs: + features = Output("Features", AttributeList) + + settingsHandler = DomainContextHandler() + selection = ContextSetting([]) + feature: Variable + feature = ContextSetting(None) + heuristic_type: int + heuristic_type = Setting(0) + + want_main_area = False + want_control_area = True + + class Information(OWWidget.Information): + removed_cons_feat = Msg("Constant features have been removed.") + + class Warning(OWWidget.Warning): + not_enough_vars = Msg("At least two features are needed.") + not_enough_inst = Msg("At least two instances are needed.") + no_class_var = Msg("Target feature missing") + + def __init__(self): + OWWidget.__init__(self) + ConcurrentWidgetMixin.__init__(self) + + self.keep_running = True + self.saved_state = None + self.progress = 0 + + self.data = None # type: Table + self.pp_data = None # type: Table + self.n_attrs = 0 + + self.scorer = None + self.heuristic = None + self.feature_index = None + + gui.comboBox(self.controlArea, self, "heuristic_type", + items=Heuristic.type.values(), + callback=self.on_heuristic_combo_changed,) + + self.feature_model = DomainModel(order=DomainModel.ATTRIBUTES, + separators=False, + placeholder="(All combinations)") + gui.comboBox(self.controlArea, self, "feature", + callback=self.on_feature_combo_changed, + model=self.feature_model, searchable=True) + + self.filter = QLineEdit() + self.filter.setPlaceholderText("Filter ...") + self.filter.textChanged.connect(self.on_filter_changed) + self.controlArea.layout().addWidget(self.filter) + + self.model = RankModel() + self.model.setHorizontalHeaderLabels(["Interaction", "Information Gain", + "Feature 1", "Feature 2"]) + self.proxy = FilterProxy(filterCaseSensitivity=Qt.CaseInsensitive) + self.proxy.setSourceModel(self.model) + self.proxy.setFilterKeyColumn(-1) + self.rank_table = view = QTableView(selectionBehavior=QTableView.SelectRows, + selectionMode=QTableView.SingleSelection, + showGrid=False, + editTriggers=gui.TableView.NoEditTriggers) + view.setSortingEnabled(True) + view.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch) + view.setItemDelegate(InteractionItemDelegate()) + view.setModel(self.proxy) + view.selectionModel().selectionChanged.connect(self.on_selection_changed) + self.controlArea.layout().addWidget(view) + + self.button = gui.button(self.controlArea, self, "Start", callback=self.toggle) + self.button.setEnabled(False) + + @Inputs.data + def set_data(self, data): + self.closeContext() + self.clear_messages() + self.selection = {} + self.data = data + self.pp_data = None + self.n_attrs = 0 + if data is not None: + if len(data) < 2: + self.Warning.not_enough_inst() + elif data.Y.size == 0: + self.Warning.no_class_var() + else: + remover = Remove(Remove.RemoveConstant) + pp_data = Discretize()(remover(data)) + if remover.attr_results["removed"]: + self.Information.removed_cons_feat() + if len(pp_data.domain.attributes) < 2: + self.Warning.not_enough_vars() + else: + self.pp_data = pp_data + self.n_attrs = len(pp_data.domain.attributes) + self.scorer = InteractionScorer(pp_data) + self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_type) + self.model.set_domain(pp_data.domain) + self.proxy.scorer = self.scorer + self.feature_model.set_domain(self.pp_data and self.pp_data.domain) + self.openContext(self.pp_data) + self.initialize() + + def initialize(self): + if self.task is not None: + self.keep_running = False + self.cancel() + self.keep_running = True + self.saved_state = None + self.progress = 0 + self.progressBarFinished() + self.model.clear() + self.filter.setText("") + self.button.setText("Start") + self.button.setEnabled(self.pp_data is not None) + if self.pp_data is not None: + self.toggle() + + def commit(self): + if self.data is None: + self.Outputs.features.send(None) + return + + self.Outputs.features.send(AttributeList( + [self.data.domain[attr] for attr in self.selection])) + + def toggle(self): + self.keep_running = not self.keep_running + if not self.keep_running: + self.button.setText("Pause") + self.button.repaint() + self.filter.setEnabled(False) + self.progressBarInit() + self.start(run, self.compute_score, self.row_for_state, + self.iterate_states, self.saved_state, + self.progress, self.state_count()) + else: + self.button.setText("Continue") + self.button.repaint() + self.filter.setEnabled(True) + self.cancel() + self._stopped() + + def _stopped(self): + self.progressBarFinished() + self._select_default() + + def _select_default(self): + n_rows = self.model.rowCount() + if not n_rows: + return + + if self.selection: + for i in range(n_rows): + names = {self.model.data(self.model.index(i, 2)), + self.model.data(self.model.index(i, 3))} + if names == self.selection: + self.rank_table.selectRow(i) + break + + if not self.rank_table.selectedIndexes(): + self.rank_table.selectRow(0) + + def on_selection_changed(self, selected): + self.selection = {self.model.data(ind) for ind in selected.indexes()[-2:]} + self.commit() + + def on_filter_changed(self, text): + self.proxy.setFilterFixedString(text) + + def on_feature_combo_changed(self): + self.feature_index = self.feature and self.pp_data.domain.index(self.feature) + self.initialize() + + def on_heuristic_combo_changed(self): + if self.pp_data is not None: + self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_type) + self.initialize() + + def compute_score(self, state): + scores = (self.scorer(*state), + self.scorer.information_gain[state[0]], + self.scorer.information_gain[state[1]]) + return tuple(self.scorer.normalize(score) for score in scores) + + @staticmethod + def row_for_state(score, state): + return [score[0], sum(score)] + list(state) + + def iterate_states(self, initial_state): + if self.feature is not None: + return self._iterate_by_feature(initial_state) + if self.n_attrs > 3 and self.heuristic is not None: + return self.heuristic.get_states(initial_state) + return self._iterate_all(initial_state) + + def _iterate_all(self, initial_state): + i0, j0 = initial_state or (0, 0) + for i in range(i0, self.n_attrs): + for j in range(j0, i): + yield i, j + j0 = 0 + + def _iterate_by_feature(self, initial_state): + _, j0 = initial_state or (0, 0) + for j in range(j0, self.n_attrs): + if j != self.feature_index: + yield self.feature_index, j + + def state_count(self): + if self.feature_index is None: + return self.n_attrs * (self.n_attrs - 1) // 2 + return self.n_attrs - 1 + + def on_partial_result(self, result): + add_to_model, latest_state = result + if add_to_model: + self.saved_state = latest_state + self.model.append(add_to_model) + self.progress = len(self.model) + self.progressBarSet(self.progress * 100 // self.state_count()) + + def on_done(self, result): + self.button.setText("Finished") + self.button.setEnabled(False) + self.filter.setEnabled(True) + self.keep_running = False + self.saved_state = None + self._stopped() + + def send_report(self): + self.report_table("Interactions", self.rank_table) if __name__ == "__main__": # pragma: no cover - WidgetPreview(OWInteractions).run(Table("iris")) + WidgetPreview(OWInteractions).run(Table("iris")) diff --git a/orangecontrib/prototypes/widgets/owinteractions_new.py b/orangecontrib/prototypes/widgets/owinteractions_new.py deleted file mode 100644 index 9729a7d3..00000000 --- a/orangecontrib/prototypes/widgets/owinteractions_new.py +++ /dev/null @@ -1,440 +0,0 @@ -import copy -from itertools import chain -from threading import Lock, Timer -from typing import Callable, Optional, Iterable -import numpy as np - -from AnyQt.QtGui import QColor, QPainter, QPen -from AnyQt.QtCore import QModelIndex, Qt, QLineF, QSortFilterProxyModel -from AnyQt.QtWidgets import QTableView, QHeaderView, \ - QStyleOptionViewItem, QApplication, QStyle, QLineEdit - -from Orange.data import Table, Variable -from Orange.preprocess import Discretize, Remove -from Orange.widgets import gui -from Orange.widgets.widget import OWWidget, AttributeList, Msg -from Orange.widgets.utils.widgetpreview import WidgetPreview -from Orange.widgets.utils.signals import Input, Output -from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState -from Orange.widgets.utils.itemmodels import DomainModel -from Orange.widgets.settings import Setting, ContextSetting, DomainContextHandler - -from orangecontrib.prototypes.ranktablemodel import RankModel -from orangecontrib.prototypes.interactions import InteractionScorer - - -class ModelQueue: - """ - Another queueing object, similar to ``queue.Queue``. - The main difference is that ``get()`` returns all its - contents at the same time, instead of one by one. - """ - def __init__(self): - self.lock = Lock() - self.model = [] - self.state = None - - def put(self, row, state): - with self.lock: - self.model.append(row) - self.state = state - - def get(self): - with self.lock: - model, self.model = self.model, [] - state, self.state = self.state, None - return model, state - - -def run(compute_score: Callable, row_for_state: Callable, - iterate_states: Callable, saved_state: Optional[Iterable], - progress: int, state_count: int, task: TaskState): - """ - Replaces ``run_vizrank``, with some minor adjustments. - - ``ModelQueue`` replaces ``queue.Queue`` - - `row_for_state` parameter added - - `scores` parameter removed - """ - task.set_status("Getting combinations...") - task.set_progress_value(0.1) - states = iterate_states(saved_state) - - task.set_status("Getting scores...") - queue = ModelQueue() - can_set_partial_result = True - - def do_work(st, next_st): - try: - score = compute_score(st) - if score is not None: - queue.put(row_for_state(score, st), next_st) - except Exception: - pass - - def reset_flag(): - nonlocal can_set_partial_result - can_set_partial_result = True - - state = None - next_state = next(states) - try: - while True: - if task.is_interruption_requested(): - return queue.get() - task.set_progress_value(progress * 100 // state_count) - progress += 1 - state = copy.copy(next_state) - next_state = copy.copy(next(states)) - do_work(state, next_state) - # for simple scores (e.g. correlations widget) and many feature - # combinations, the 'partial_result_ready' signal (emitted by - # invoking 'task.set_partial_result') was emitted too frequently - # for a longer period of time and therefore causing the widget - # being unresponsive - if can_set_partial_result: - task.set_partial_result(queue.get()) - can_set_partial_result = False - Timer(0.05, reset_flag).start() - except StopIteration: - do_work(state, None) - task.set_partial_result(queue.get()) - return queue.get() - - -class Heuristic: - RANDOM, INFO_GAIN = 0, 1 - type = {RANDOM: "Random Search", - INFO_GAIN: "Information Gain Heuristic"} - - def __init__(self, weights, type=None): - self.n_attributes = len(weights) - self.attributes = np.arange(self.n_attributes) - if type == self.RANDOM: - np.random.shuffle(self.attributes) - if type == self.INFO_GAIN: - self.attributes = self.attributes[np.argsort(weights)] - - def generate_states(self): - # prioritize two mid ranked attributes over highest first - for s in range(1, self.n_attributes * (self.n_attributes - 1) // 2): - for i in range(max(s - self.n_attributes + 1, 0), (s + 1) // 2): - yield self.attributes[i], self.attributes[s - i] - - def get_states(self, initial_state): - states = self.generate_states() - if initial_state is not None: - while next(states) != initial_state: - pass - return chain([initial_state], states) - return states - - -class InteractionItemDelegate(gui.TableBarItem): - def paint(self, painter: QPainter, option: QStyleOptionViewItem, - index: QModelIndex) -> None: - opt = QStyleOptionViewItem(option) - self.initStyleOption(opt, index) - widget = option.widget - style = QApplication.style() if widget is None else widget.style() - pen = QPen(QColor("#46befa"), 5, Qt.SolidLine, Qt.RoundCap) - line = QLineF() - self.__style = style - text = opt.text - opt.text = "" - style.drawControl(QStyle.CE_ItemViewItem, opt, painter, widget) - textrect = style.subElementRect( - QStyle.SE_ItemViewItemText, opt, widget) - - interaction = self.cachedData(index, Qt.EditRole) - # only draw bars for first column - if index.column() == 0 and interaction is not None: - rect = option.rect - pw = self.penWidth - textoffset = pw + 2 - baseline = rect.bottom() - textoffset / 2 - origin = rect.left() + 3 + pw / 2 # + half pen width for the round line cap - width = rect.width() - 3 - pw - - def draw_line(start, length): - line.setLine(origin + start, baseline, origin + start + length, baseline) - painter.drawLine(line) - - scorer = index.model().scorer - attr1 = self.cachedData(index.siblingAtColumn(2), Qt.EditRole) - attr2 = self.cachedData(index.siblingAtColumn(3), Qt.EditRole) - l_bar = scorer.normalize(scorer.information_gain[int(attr1)]) - r_bar = scorer.normalize(scorer.information_gain[int(attr2)]) - # negative information gains stem from issues in interaction - # calculation and may cause bars reaching out of intended area - l_bar, r_bar = width * max(l_bar, 0), width * max(r_bar, 0) - interaction *= width - - pen.setWidth(pw) - painter.save() - painter.setRenderHint(QPainter.Antialiasing) - painter.setPen(pen) - draw_line(0, l_bar) - draw_line(l_bar + interaction, r_bar) - pen.setColor(QColor("#aaf22b") if interaction >= 0 else QColor("#ffaa7f")) - painter.setPen(pen) - draw_line(l_bar, interaction) - painter.restore() - textrect.adjust(0, 0, 0, -textoffset) - - opt.text = text - self.drawViewItemText(style, painter, opt, textrect) - - -class FilterProxy(QSortFilterProxyModel): - scorer = None - - def sort(self, *args, **kwargs): - self.sourceModel().sort(*args, **kwargs) - - -class OWInteractions(OWWidget, ConcurrentWidgetMixin): - name = "Interactions New" - description = "Compute all pairwise attribute interactions." - icon = "icons/Interactions.svg" - category = "Unsupervised" - - class Inputs: - data = Input("Data", Table) - - class Outputs: - features = Output("Features", AttributeList) - - settingsHandler = DomainContextHandler() - selection = ContextSetting([]) - feature: Variable - feature = ContextSetting(None) - heuristic_type: int - heuristic_type = Setting(0) - - want_main_area = False - want_control_area = True - - class Information(OWWidget.Information): - removed_cons_feat = Msg("Constant features have been removed.") - - class Warning(OWWidget.Warning): - not_enough_vars = Msg("At least two features are needed.") - not_enough_inst = Msg("At least two instances are needed.") - no_class_var = Msg("Target feature missing") - - def __init__(self): - OWWidget.__init__(self) - ConcurrentWidgetMixin.__init__(self) - - self.keep_running = True - self.saved_state = None - self.progress = 0 - - self.data = None # type: Table - self.pp_data = None # type: Table - self.n_attrs = 0 - - self.scorer = None - self.heuristic = None - self.feature_index = None - - gui.comboBox(self.controlArea, self, "heuristic_type", - items=Heuristic.type.values(), - callback=self.on_heuristic_combo_changed,) - - self.feature_model = DomainModel(order=DomainModel.ATTRIBUTES, - separators=False, - placeholder="(All combinations)") - gui.comboBox(self.controlArea, self, "feature", - callback=self.on_feature_combo_changed, - model=self.feature_model, searchable=True) - - self.filter = QLineEdit() - self.filter.setPlaceholderText("Filter ...") - self.filter.textChanged.connect(self.on_filter_changed) - self.controlArea.layout().addWidget(self.filter) - - self.model = RankModel() - self.model.setHorizontalHeaderLabels(["Interaction", "Information Gain", - "Feature 1", "Feature 2"]) - self.proxy = FilterProxy(filterCaseSensitivity=Qt.CaseInsensitive) - self.proxy.setSourceModel(self.model) - self.proxy.setFilterKeyColumn(-1) - self.rank_table = view = QTableView(selectionBehavior=QTableView.SelectRows, - selectionMode=QTableView.SingleSelection, - showGrid=False, - editTriggers=gui.TableView.NoEditTriggers) - view.setSortingEnabled(True) - view.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch) - view.setItemDelegate(InteractionItemDelegate()) - view.setModel(self.proxy) - view.selectionModel().selectionChanged.connect(self.on_selection_changed) - self.controlArea.layout().addWidget(view) - - self.button = gui.button(self.controlArea, self, "Start", callback=self.toggle) - self.button.setEnabled(False) - - @Inputs.data - def set_data(self, data): - self.closeContext() - self.clear_messages() - self.selection = {} - self.data = data - self.pp_data = None - self.n_attrs = 0 - if data is not None: - if len(data) < 2: - self.Warning.not_enough_inst() - elif data.Y.size == 0: - self.Warning.no_class_var() - else: - remover = Remove(Remove.RemoveConstant) - pp_data = Discretize()(remover(data)) - if remover.attr_results["removed"]: - self.Information.removed_cons_feat() - if len(pp_data.domain.attributes) < 2: - self.Warning.not_enough_vars() - else: - self.pp_data = pp_data - self.n_attrs = len(pp_data.domain.attributes) - self.scorer = InteractionScorer(pp_data) - self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_type) - self.model.set_domain(pp_data.domain) - self.proxy.scorer = self.scorer - self.feature_model.set_domain(self.pp_data and self.pp_data.domain) - self.openContext(self.pp_data) - self.initialize() - - def initialize(self): - if self.task is not None: - self.keep_running = False - self.cancel() - self.keep_running = True - self.saved_state = None - self.progress = 0 - self.progressBarFinished() - self.model.clear() - self.filter.setText("") - self.button.setText("Start") - self.button.setEnabled(self.pp_data is not None) - if self.pp_data is not None: - self.toggle() - - def commit(self): - if self.data is None: - self.Outputs.features.send(None) - return - - self.Outputs.features.send(AttributeList( - [self.data.domain[attr] for attr in self.selection])) - - def toggle(self): - self.keep_running = not self.keep_running - if not self.keep_running: - self.button.setText("Pause") - self.button.repaint() - self.filter.setEnabled(False) - self.progressBarInit() - self.start(run, self.compute_score, self.row_for_state, - self.iterate_states, self.saved_state, - self.progress, self.state_count()) - else: - self.button.setText("Continue") - self.button.repaint() - self.filter.setEnabled(True) - self.cancel() - self._stopped() - - def _stopped(self): - self.progressBarFinished() - self._select_default() - - def _select_default(self): - n_rows = self.model.rowCount() - if not n_rows: - return - - if self.selection: - for i in range(n_rows): - names = {self.model.data(self.model.index(i, 2)), - self.model.data(self.model.index(i, 3))} - if names == self.selection: - self.rank_table.selectRow(i) - break - - if not self.rank_table.selectedIndexes(): - self.rank_table.selectRow(0) - - def on_selection_changed(self, selected): - self.selection = {self.model.data(ind) for ind in selected.indexes()[-2:]} - self.commit() - - def on_filter_changed(self, text): - self.proxy.setFilterFixedString(text) - - def on_feature_combo_changed(self): - self.feature_index = self.feature and self.pp_data.domain.index(self.feature) - self.initialize() - - def on_heuristic_combo_changed(self): - if self.pp_data is not None: - self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_type) - self.initialize() - - def compute_score(self, state): - scores = (self.scorer(*state), - self.scorer.information_gain[state[0]], - self.scorer.information_gain[state[1]]) - return tuple(self.scorer.normalize(score) for score in scores) - - @staticmethod - def row_for_state(score, state): - return [score[0], sum(score)] + list(state) - - def iterate_states(self, initial_state): - if self.feature is not None: - return self._iterate_by_feature(initial_state) - if self.n_attrs > 3 and self.heuristic is not None: - return self.heuristic.get_states(initial_state) - return self._iterate_all(initial_state) - - def _iterate_all(self, initial_state): - i0, j0 = initial_state or (0, 0) - for i in range(i0, self.n_attrs): - for j in range(j0, i): - yield i, j - j0 = 0 - - def _iterate_by_feature(self, initial_state): - _, j0 = initial_state or (0, 0) - for j in range(j0, self.n_attrs): - if j != self.feature_index: - yield self.feature_index, j - - def state_count(self): - if self.feature_index is None: - return self.n_attrs * (self.n_attrs - 1) // 2 - return self.n_attrs - 1 - - def on_partial_result(self, result): - add_to_model, latest_state = result - if add_to_model: - self.saved_state = latest_state - self.model.append(add_to_model) - self.progress = len(self.model) - self.progressBarSet(self.progress * 100 // self.state_count()) - - def on_done(self, result): - self.button.setText("Finished") - self.button.setEnabled(False) - self.filter.setEnabled(True) - self.keep_running = False - self.saved_state = None - self._stopped() - - def send_report(self): - self.report_table("Interactions", self.rank_table) - - -if __name__ == "__main__": # pragma: no cover - WidgetPreview(OWInteractions).run(Table("iris")) diff --git a/orangecontrib/prototypes/widgets/tests/test_owinteractions.py b/orangecontrib/prototypes/widgets/tests/test_owinteractions.py index c854e601..b5ebd8fa 100644 --- a/orangecontrib/prototypes/widgets/tests/test_owinteractions.py +++ b/orangecontrib/prototypes/widgets/tests/test_owinteractions.py @@ -1,321 +1,281 @@ import unittest -from unittest.mock import patch, Mock +from unittest.mock import Mock import numpy as np import numpy.testing as npt -from AnyQt.QtCore import Qt +from AnyQt.QtCore import QItemSelection from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable from Orange.widgets.tests.base import WidgetTest from Orange.widgets.tests.utils import simulate -from Orange.widgets.visualize.owscatterplot import OWScatterPlot from Orange.widgets.widget import AttributeList -from orangecontrib.prototypes.widgets.owinteractions import \ - OWInteractions, Heuristic, HeuristicType, InteractionRank + +from orangecontrib.prototypes.widgets.owinteractions import OWInteractions, Heuristic from orangecontrib.prototypes.interactions import InteractionScorer class TestOWInteractions(WidgetTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.data = Table("iris") - cls.disc_data = Table("zoo") - - def setUp(self): - self.widget = self.create_widget(OWInteractions) - - def test_input_data(self): - """Check interaction table""" - self.send_signal(self.widget.Inputs.data, None) - self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 4) - self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 0) - self.send_signal(self.widget.Inputs.data, self.data) - self.wait_until_finished() - n_attrs = len(self.data.domain.attributes) - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 4) - self.assertEqual(self.widget.vizrank.rank_model.rowCount(), n_attrs*(n_attrs-1)/2) - - def test_input_data_one_feature(self): - """Check interaction table for dataset with one attribute""" - self.send_signal(self.widget.Inputs.data, self.data[:, [0, 4]]) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 4) - self.assertTrue(self.widget.Warning.not_enough_vars.is_shown()) - self.send_signal(self.widget.Inputs.data, None) - self.assertFalse(self.widget.Warning.not_enough_vars.is_shown()) - - def test_data_no_class(self): - """Check interaction table for dataset without class variable""" - self.send_signal(self.widget.Inputs.data, self.data[:, :-1]) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 4) - self.assertTrue(self.widget.Warning.no_class_var.is_shown()) - - def test_input_data_one_instance(self): - """Check interaction table for dataset with one instance""" - self.send_signal(self.widget.Inputs.data, self.data[:1]) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 4) - self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) - self.assertTrue(self.widget.Warning.not_enough_inst.is_shown()) - self.send_signal(self.widget.Inputs.data, None) - self.assertFalse(self.widget.Warning.not_enough_inst.is_shown()) - - def test_input_data_with_constant_features(self): - """Check interaction table for dataset with constant columns""" - np.random.seed(0) - x = np.random.randint(3, size=(4, 3)).astype(float) - x[:, 2] = 1 - y = np.random.randint(3, size=4).astype(float) - - domain_disc = Domain([DiscreteVariable(str(i), ["a", "b", "c"]) for i in range(3)], DiscreteVariable("cls")) - domain_cont = Domain([ContinuousVariable(str(i)) for i in range(3)], DiscreteVariable("cls")) - - self.send_signal(self.widget.Inputs.data, Table(domain_disc, x, y)) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 3) - self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) - - self.send_signal(self.widget.Inputs.data, Table(domain_cont, x, y)) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 1) - self.assertTrue(self.widget.Information.removed_cons_feat.is_shown()) - - x = np.ones((4, 3), dtype=float) - self.send_signal(self.widget.Inputs.data, Table(domain_cont, x, y)) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.columnCount(), 4) - self.assertTrue(self.widget.Warning.not_enough_vars.is_shown()) - self.assertTrue(self.widget.Information.removed_cons_feat.is_shown()) - - self.send_signal(self.widget.Inputs.data, None) - self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) - - def test_output_features(self): - """Check features on output""" - self.send_signal(self.widget.Inputs.data, self.data) - self.wait_until_finished() - self.process_events() - features = self.get_output(self.widget.Outputs.features) - self.assertIsInstance(features, AttributeList) - self.assertEqual(len(features), 2) - - def test_output_interactions(self): - """Check interaction table on output""" - self.send_signal(self.widget.Inputs.data, self.data) - self.wait_until_finished() - n_attrs = len(self.data.domain.attributes) - self.process_events() - interactions = self.get_output(self.widget.Outputs.interactions) - self.assertIsInstance(interactions, Table) - self.assertEqual(len(interactions), n_attrs*(n_attrs-1)/2) - self.assertEqual(len(interactions.domain.metas), 2) - self.assertListEqual(["Interaction"], [m.name for m in interactions.domain.attributes]) - - def test_input_changed(self): - """Check whether changing input emits commit""" - self.widget.commit = Mock() - self.send_signal(self.widget.Inputs.data, self.data) - self.wait_until_finished() - self.process_events() - self.widget.commit.assert_called_once() - - np.random.seed(0) - x = np.random.randint(3, size=(4, 3)).astype(float) - y = np.random.randint(3, size=4).astype(float) - domain = Domain([DiscreteVariable(str(i), ["a", "b", "c"]) for i in range(3)], DiscreteVariable("cls")) - - self.widget.commit.reset_mock() - self.send_signal(self.widget.Inputs.data, Table(domain, x, y)) - self.wait_until_finished() - self.process_events() - self.widget.commit.assert_called_once() - - def test_saved_selection(self): - """Select row from settings""" - self.send_signal(self.widget.Inputs.data, self.data) - self.wait_until_finished() - self.process_events() - attrs = self.widget.disc_data.domain.attributes - self.widget._vizrank_selection_changed(attrs[1], attrs[3]) - settings = self.widget.settingsHandler.pack_data(self.widget) - - w = self.create_widget(OWInteractions, stored_settings=settings) - self.send_signal(self.widget.Inputs.data, self.data, widget=w) - self.wait_until_finished(w) - self.process_events() - sel_row = w.vizrank.rank_table.selectionModel().selectedRows()[0].row() - self.assertEqual(sel_row, 1) - - def test_scatterplot_input_features(self): - """Check if attributes have been set after sent to scatterplot""" - self.send_signal(self.widget.Inputs.data, self.data) - spw = self.create_widget(OWScatterPlot) - attrs = self.widget.disc_data.domain.attributes - self.widget._vizrank_selection_changed(attrs[2], attrs[3]) - features = self.get_output(self.widget.Outputs.features) - self.send_signal(self.widget.Inputs.data, self.data, widget=spw) - self.send_signal(spw.Inputs.features, features, widget=spw) - self.assertIs(spw.attr_x, self.data.domain[2]) - self.assertIs(spw.attr_y, self.data.domain[3]) - - @patch("orangecontrib.prototypes.widgets.owinteractions.SIZE_LIMIT", 2000) - def test_heuristic_type(self): - h_type = self.widget.controls.heuristic_type - self.send_signal(self.widget.Inputs.data, self.disc_data) - self.wait_until_finished() - self.process_events() - infogain = list(self.widget.vizrank.heuristic.get_states(None)) - - simulate.combobox_activate_item(h_type, "Random Search") - self.wait_until_finished() - self.process_events() - random = list(self.widget.vizrank.heuristic.get_states(None)) - - self.assertFalse(infogain == random, msg="Double check results, there is a 1 in 15! chance heuristics are equal.") - - def test_feature_combo(self): - """Check content of feature selection combobox""" - feature_combo = self.widget.controls.feature - self.send_signal(self.widget.Inputs.data, self.data) - self.assertEqual(len(feature_combo.model()), 5) - - self.wait_until_stop_blocking() - self.send_signal(self.widget.Inputs.data, self.disc_data) - self.assertEqual(len(feature_combo.model()), 17) - - def test_select_feature(self): - """Check feature selection""" - feature_combo = self.widget.controls.feature - self.send_signal(self.widget.Inputs.data, self.data) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 6) - self.assertListEqual( - [a.name for a in self.get_output(self.widget.Outputs.features)], - ["sepal length", "sepal width"] - ) - - simulate.combobox_activate_index(feature_combo, 3) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 3) - self.assertListEqual( - [a.name for a in self.get_output(self.widget.Outputs.features)], - ["petal length", "sepal width"] - ) - - simulate.combobox_activate_index(feature_combo, 0) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 6) - self.assertListEqual( - [a.name for a in self.get_output(self.widget.Outputs.features)], - ["petal length", "sepal width"] - ) - - @patch("orangecontrib.prototypes.widgets.owinteractions.SIZE_LIMIT", 2000) - def test_vizrank_use_heuristic(self): - """Check heuristic use""" - self.send_signal(self.widget.Inputs.data, self.data) - self.wait_until_finished() - self.process_events() - self.assertTrue(self.widget.vizrank.use_heuristic) - self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 6) - - @patch("orangecontrib.prototypes.widgets.owinteractions.SIZE_LIMIT", 2000) - def test_select_feature_against_heuristic(self): - """Check heuristic use when feature selected""" - feature_combo = self.widget.controls.feature - self.send_signal(self.widget.Inputs.data, self.data) - simulate.combobox_activate_index(feature_combo, 2) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.vizrank.rank_model.rowCount(), 3) - self.assertEqual(self.widget.vizrank.heuristic, None) - - -class TestInteractionRank(WidgetTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - x = np.array([[1, 1], [0, 1], [1, 1], [0, 0]]) - y = np.array([0, 1, 1, 1]) - domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) - cls.data = Table(domain, x, y) - cls.attrs = cls.data.domain.attributes - - def setUp(self): - self.vizrank = InteractionRank(None) - self.vizrank.attrs = self.attrs - - def test_row_for_state(self): - """Check row calculation""" - row = self.vizrank.row_for_state((0.1511, 0.3837, 0.1511), (0, 1)) - self.assertEqual(row[0].data(Qt.DisplayRole), "+15.1%") - self.assertEqual(row[0].data(InteractionRank.IntRole), 0.1511) - self.assertListEqual(row[0].data(InteractionRank.GainRole), [0.3837, 0.1511]) - self.assertEqual(row[1].data(Qt.DisplayRole), "68.6%") - self.assertEqual(row[2].data(Qt.DisplayRole), self.attrs[0].name) - self.assertEqual(row[3].data(Qt.DisplayRole), self.attrs[1].name) + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.iris = Table("iris") # continuous data + cls.zoo = Table("zoo") # discrete data + + def setUp(self): + self.widget = self.create_widget(OWInteractions) + + def test_input_data(self): + """Check table on input data""" + self.send_signal(self.widget.Inputs.data, None) + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertEqual(self.widget.model.rowCount(), 0) + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 4) + self.assertEqual(self.widget.model.rowCount(), 6) + + def test_input_data_one_feature(self): + """Check table on input data with single attribute""" + self.send_signal(self.widget.Inputs.data, self.iris[:, [0, 4]]) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertTrue(self.widget.Warning.not_enough_vars.is_shown()) + self.send_signal(self.widget.Inputs.data, None) + self.assertFalse(self.widget.Warning.not_enough_vars.is_shown()) + + def test_input_data_no_target(self): + """Check table on input data without target""" + self.send_signal(self.widget.Inputs.data, self.iris[:, :-1]) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertTrue(self.widget.Warning.no_class_var.is_shown()) + + def test_input_data_one_instance(self): + """Check table on input data with single instance""" + self.send_signal(self.widget.Inputs.data, self.iris[:1]) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) + self.assertTrue(self.widget.Warning.not_enough_inst.is_shown()) + self.send_signal(self.widget.Inputs.data, None) + self.assertFalse(self.widget.Warning.not_enough_inst.is_shown()) + + def test_input_data_constant_features(self): + """Check table on input data with constant columns""" + x = np.array([[0, 2, 1], + [0, 2, 0], + [0, 0, 1], + [0, 1, 2]]) + y = np.array([1, 2, 1, 0]) + labels = ["a", "b", "c"] + domain_disc = Domain([DiscreteVariable(str(i), labels) for i in range(3)], + DiscreteVariable("cls", labels)) + domain_cont = Domain([ContinuousVariable(str(i)) for i in range(3)], + DiscreteVariable("cls", labels)) + + self.send_signal(self.widget.Inputs.data, Table(domain_disc, x, y)) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 3) + self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) + + self.send_signal(self.widget.Inputs.data, Table(domain_cont, x, y)) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 1) + self.assertTrue(self.widget.Information.removed_cons_feat.is_shown()) + + x = np.ones((4, 3), dtype=float) + self.send_signal(self.widget.Inputs.data, Table(domain_cont, x, y)) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.columnCount(), 0) + self.assertTrue(self.widget.Warning.not_enough_vars.is_shown()) + self.assertTrue(self.widget.Information.removed_cons_feat.is_shown()) + + self.send_signal(self.widget.Inputs.data, None) + self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) + + def test_output_features(self): + """Check output""" + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + features = self.get_output(self.widget.Outputs.features) + self.assertIsInstance(features, AttributeList) + self.assertEqual(len(features), 2) + + def test_input_changed(self): + """Check commit on input""" + self.widget.commit = Mock() + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + self.widget.commit.assert_called_once() + + x = np.array([[0, 1, 0], + [1, 1, 2], + [0, 2, 0], + [0, 0, 2]]) + y = np.array([1, 2, 2, 0]) + domain = Domain([DiscreteVariable(str(i), ["a", "b", "c"]) for i in range(3)], + DiscreteVariable("cls")) + + self.widget.commit.reset_mock() + self.send_signal(self.widget.Inputs.data, Table(domain, x, y)) + self.wait_until_finished() + self.process_events() + self.widget.commit.assert_called_once() + + def test_saved_selection(self): + """Check row selection""" + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + selection = QItemSelection() + selection.select(self.widget.model.index(2, 0), + self.widget.model.index(2, 3)) + self.widget.on_selection_changed(selection) + settings = self.widget.settingsHandler.pack_data(self.widget) + + w = self.create_widget(OWInteractions, stored_settings=settings) + self.send_signal(self.widget.Inputs.data, self.iris, widget=w) + self.wait_until_finished(w) + self.process_events() + sel_row = w.rank_table.selectionModel().selectedRows()[0].row() + self.assertEqual(sel_row, 2) + + def test_feature_combo(self): + """Check feature combobox""" + feature_combo = self.widget.controls.feature + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertEqual(len(feature_combo.model()), 5) + + self.wait_until_stop_blocking() + self.send_signal(self.widget.Inputs.data, self.zoo) + self.assertEqual(len(feature_combo.model()), 17) + + def test_select_feature(self): + """Check feature selection""" + feature_combo = self.widget.controls.feature + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 6) + self.assertSetEqual( + {a.name for a in self.get_output(self.widget.Outputs.features)}, + {"sepal width", "sepal length"} + ) + + simulate.combobox_activate_index(feature_combo, 3) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 3) + self.assertSetEqual( + {a.name for a in self.get_output(self.widget.Outputs.features)}, + {"petal length", "sepal width"} + ) + + simulate.combobox_activate_index(feature_combo, 0) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 6) + self.assertSetEqual( + {a.name for a in self.get_output(self.widget.Outputs.features)}, + {"petal length", "sepal width"} + ) + + def test_send_report(self): + """Check report""" + self.send_signal(self.widget.Inputs.data, self.iris) + self.widget.report_button.click() + self.wait_until_stop_blocking() + self.send_signal(self.widget.Inputs.data, None) + self.widget.report_button.click() + + def test_compute_score(self): + self.widget.scorer = InteractionScorer(self.zoo) + npt.assert_almost_equal(self.widget.compute_score((1, 0)), + [-0.0771, 0.3003, 0.3307], 4) + + def test_row_for_state(self): + row = self.widget.row_for_state((-0.2, 0.2, 0.1), (1, 0)) + self.assertListEqual(row, [-0.2, 0.1, 1, 0]) + + def test_iterate_states(self): + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertListEqual(list(self.widget._iterate_all(None)), + [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) + self.assertListEqual(list(self.widget._iterate_all((1, 0))), + [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) + self.assertListEqual(list(self.widget._iterate_all((2, 1))), + [(2, 1), (3, 0), (3, 1), (3, 2)]) + self.widget.feature_index = 2 + self.assertListEqual(list(self.widget._iterate_by_feature(None)), + [(2, 0), (2, 1), (2, 3)]) + self.assertListEqual(list(self.widget._iterate_by_feature((2, 0))), + [(2, 0), (2, 1), (2, 3)]) + self.assertListEqual(list(self.widget._iterate_by_feature((2, 3))), + [(2, 3)]) + + def test_state_count(self): + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertEqual(self.widget.state_count(), 6) + self.widget.feature_index = 2 + self.assertEqual(self.widget.state_count(), 3) class TestInteractionScorer(unittest.TestCase): - def test_compute_score(self): - """Check score calculation""" - x = np.array([[1, 1], [0, 1], [1, 1], [0, 0]]) - y = np.array([0, 1, 1, 1]) - domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) - data = Table(domain, x, y) - self.scorer = InteractionScorer(data) - npt.assert_almost_equal(self.scorer(0, 1), -0.1226, 4) - npt.assert_almost_equal(self.scorer.class_entropy, 0.8113, 4) - npt.assert_almost_equal(self.scorer.information_gain[0], 0.3113, 4) - npt.assert_almost_equal(self.scorer.information_gain[1], 0.1226, 4) - - def test_nans(self): - """Check score calculation with sparse data""" - x = np.array([[1, 1], [0, 1], [1, 1], [0, 0], [1, np.nan], [np.nan, 0], [np.nan, np.nan]]) - y = np.array([0, 1, 1, 1, 0, 0, 1]) - domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) - data = Table(domain, x, y) - self.scorer = InteractionScorer(data) - npt.assert_almost_equal(self.scorer(0, 1), 0.0167, 4) - npt.assert_almost_equal(self.scorer.class_entropy, 0.9852, 4) - npt.assert_almost_equal(self.scorer.information_gain[0], 0.4343, 4) - npt.assert_almost_equal(self.scorer.information_gain[1], 0.0343, 4) + def test_compute_score(self): + """Check score calculation""" + x = np.array([[1, 1], [0, 1], [1, 1], [0, 0]]) + y = np.array([0, 1, 1, 1]) + domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) + data = Table(domain, x, y) + self.scorer = InteractionScorer(data) + npt.assert_almost_equal(self.scorer(0, 1), -0.1226, 4) + npt.assert_almost_equal(self.scorer.class_entropy, 0.8113, 4) + npt.assert_almost_equal(self.scorer.information_gain[0], 0.3113, 4) + npt.assert_almost_equal(self.scorer.information_gain[1], 0.1226, 4) + + def test_nans(self): + """Check score calculation with nans""" + x = np.array([[1, 1], [0, 1], [1, 1], [0, 0], [1, np.nan], [np.nan, 0], [np.nan, np.nan]]) + y = np.array([0, 1, 1, 1, 0, 0, 1]) + domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) + data = Table(domain, x, y) + self.scorer = InteractionScorer(data) + npt.assert_almost_equal(self.scorer(0, 1), 0.0167, 4) + npt.assert_almost_equal(self.scorer.class_entropy, 0.9852, 4) + npt.assert_almost_equal(self.scorer.information_gain[0], 0.4343, 4) + npt.assert_almost_equal(self.scorer.information_gain[1], 0.0343, 4) class TestHeuristic(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.zoo = Table("zoo") - - def test_heuristic(self): - """Check attribute pairs returned by heuristic""" - scorer = InteractionScorer(self.zoo) - heuristic = Heuristic(scorer.information_gain, heuristic_type=HeuristicType.INFOGAIN) - self.assertListEqual( - list(heuristic.get_states(None))[:9], - [(14, 6), (14, 10), (14, 15), (6, 10), (14, 5), (6, 15), (14, 11), (6, 5), (10, 15)] - ) - - states = heuristic.get_states(None) - _ = next(states) - self.assertListEqual( - list(heuristic.get_states(next(states)))[:8], - [(14, 10), (14, 15), (6, 10), (14, 5), (6, 15), (14, 11), (6, 5), (10, 15)] - ) + @classmethod + def setUpClass(cls): + cls.zoo = Table("zoo") + + def test_heuristic(self): + """Check attribute pairs returned by heuristic""" + scorer = InteractionScorer(self.zoo) + heuristic = Heuristic(scorer.information_gain, + type=Heuristic.INFO_GAIN) + self.assertListEqual(list(heuristic.get_states(None))[:9], + [(14, 6), (14, 10), (14, 15), (6, 10), + (14, 5), (6, 15), (14, 11), (6, 5), (10, 15)]) + + states = heuristic.get_states(None) + _ = next(states) + self.assertListEqual(list(heuristic.get_states(next(states)))[:8], + [(14, 10), (14, 15), (6, 10), (14, 5), + (6, 15), (14, 11), (6, 5), (10, 15)]) if __name__ == "__main__": - unittest.main() + unittest.main() diff --git a/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py b/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py deleted file mode 100644 index ef98a82c..00000000 --- a/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py +++ /dev/null @@ -1,281 +0,0 @@ -import unittest -from unittest.mock import Mock - -import numpy as np -import numpy.testing as npt - -from AnyQt.QtCore import QItemSelection - -from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable -from Orange.widgets.tests.base import WidgetTest -from Orange.widgets.tests.utils import simulate -from Orange.widgets.widget import AttributeList - -from orangecontrib.prototypes.widgets.owinteractions_new import OWInteractions, Heuristic -from orangecontrib.prototypes.interactions import InteractionScorer - - -class TestOWInteractions(WidgetTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.iris = Table("iris") # continuous data - cls.zoo = Table("zoo") # discrete data - - def setUp(self): - self.widget = self.create_widget(OWInteractions) - - def test_input_data(self): - """Check table on input data""" - self.send_signal(self.widget.Inputs.data, None) - self.assertEqual(self.widget.model.columnCount(), 0) - self.assertEqual(self.widget.model.rowCount(), 0) - self.send_signal(self.widget.Inputs.data, self.iris) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.columnCount(), 4) - self.assertEqual(self.widget.model.rowCount(), 6) - - def test_input_data_one_feature(self): - """Check table on input data with single attribute""" - self.send_signal(self.widget.Inputs.data, self.iris[:, [0, 4]]) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.columnCount(), 0) - self.assertTrue(self.widget.Warning.not_enough_vars.is_shown()) - self.send_signal(self.widget.Inputs.data, None) - self.assertFalse(self.widget.Warning.not_enough_vars.is_shown()) - - def test_input_data_no_target(self): - """Check table on input data without target""" - self.send_signal(self.widget.Inputs.data, self.iris[:, :-1]) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.columnCount(), 0) - self.assertTrue(self.widget.Warning.no_class_var.is_shown()) - - def test_input_data_one_instance(self): - """Check table on input data with single instance""" - self.send_signal(self.widget.Inputs.data, self.iris[:1]) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.columnCount(), 0) - self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) - self.assertTrue(self.widget.Warning.not_enough_inst.is_shown()) - self.send_signal(self.widget.Inputs.data, None) - self.assertFalse(self.widget.Warning.not_enough_inst.is_shown()) - - def test_input_data_constant_features(self): - """Check table on input data with constant columns""" - x = np.array([[0, 2, 1], - [0, 2, 0], - [0, 0, 1], - [0, 1, 2]]) - y = np.array([1, 2, 1, 0]) - labels = ["a", "b", "c"] - domain_disc = Domain([DiscreteVariable(str(i), labels) for i in range(3)], - DiscreteVariable("cls", labels)) - domain_cont = Domain([ContinuousVariable(str(i)) for i in range(3)], - DiscreteVariable("cls", labels)) - - self.send_signal(self.widget.Inputs.data, Table(domain_disc, x, y)) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.rowCount(), 3) - self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) - - self.send_signal(self.widget.Inputs.data, Table(domain_cont, x, y)) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.rowCount(), 1) - self.assertTrue(self.widget.Information.removed_cons_feat.is_shown()) - - x = np.ones((4, 3), dtype=float) - self.send_signal(self.widget.Inputs.data, Table(domain_cont, x, y)) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.columnCount(), 0) - self.assertTrue(self.widget.Warning.not_enough_vars.is_shown()) - self.assertTrue(self.widget.Information.removed_cons_feat.is_shown()) - - self.send_signal(self.widget.Inputs.data, None) - self.assertFalse(self.widget.Information.removed_cons_feat.is_shown()) - - def test_output_features(self): - """Check output""" - self.send_signal(self.widget.Inputs.data, self.iris) - self.wait_until_finished() - self.process_events() - features = self.get_output(self.widget.Outputs.features) - self.assertIsInstance(features, AttributeList) - self.assertEqual(len(features), 2) - - def test_input_changed(self): - """Check commit on input""" - self.widget.commit = Mock() - self.send_signal(self.widget.Inputs.data, self.iris) - self.wait_until_finished() - self.process_events() - self.widget.commit.assert_called_once() - - x = np.array([[0, 1, 0], - [1, 1, 2], - [0, 2, 0], - [0, 0, 2]]) - y = np.array([1, 2, 2, 0]) - domain = Domain([DiscreteVariable(str(i), ["a", "b", "c"]) for i in range(3)], - DiscreteVariable("cls")) - - self.widget.commit.reset_mock() - self.send_signal(self.widget.Inputs.data, Table(domain, x, y)) - self.wait_until_finished() - self.process_events() - self.widget.commit.assert_called_once() - - def test_saved_selection(self): - """Check row selection""" - self.send_signal(self.widget.Inputs.data, self.iris) - self.wait_until_finished() - self.process_events() - selection = QItemSelection() - selection.select(self.widget.model.index(2, 0), - self.widget.model.index(2, 3)) - self.widget.on_selection_changed(selection) - settings = self.widget.settingsHandler.pack_data(self.widget) - - w = self.create_widget(OWInteractions, stored_settings=settings) - self.send_signal(self.widget.Inputs.data, self.iris, widget=w) - self.wait_until_finished(w) - self.process_events() - sel_row = w.rank_table.selectionModel().selectedRows()[0].row() - self.assertEqual(sel_row, 2) - - def test_feature_combo(self): - """Check feature combobox""" - feature_combo = self.widget.controls.feature - self.send_signal(self.widget.Inputs.data, self.iris) - self.assertEqual(len(feature_combo.model()), 5) - - self.wait_until_stop_blocking() - self.send_signal(self.widget.Inputs.data, self.zoo) - self.assertEqual(len(feature_combo.model()), 17) - - def test_select_feature(self): - """Check feature selection""" - feature_combo = self.widget.controls.feature - self.send_signal(self.widget.Inputs.data, self.iris) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.rowCount(), 6) - self.assertSetEqual( - {a.name for a in self.get_output(self.widget.Outputs.features)}, - {"sepal width", "sepal length"} - ) - - simulate.combobox_activate_index(feature_combo, 3) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.rowCount(), 3) - self.assertSetEqual( - {a.name for a in self.get_output(self.widget.Outputs.features)}, - {"petal length", "sepal width"} - ) - - simulate.combobox_activate_index(feature_combo, 0) - self.wait_until_finished() - self.process_events() - self.assertEqual(self.widget.model.rowCount(), 6) - self.assertSetEqual( - {a.name for a in self.get_output(self.widget.Outputs.features)}, - {"petal length", "sepal width"} - ) - - def test_send_report(self): - """Check report""" - self.send_signal(self.widget.Inputs.data, self.iris) - self.widget.report_button.click() - self.wait_until_stop_blocking() - self.send_signal(self.widget.Inputs.data, None) - self.widget.report_button.click() - - def test_compute_score(self): - self.widget.scorer = InteractionScorer(self.zoo) - npt.assert_almost_equal(self.widget.compute_score((1, 0)), - [-0.0771, 0.3003, 0.3307], 4) - - def test_row_for_state(self): - row = self.widget.row_for_state((-0.2, 0.2, 0.1), (1, 0)) - self.assertListEqual(row, [-0.2, 0.1, 1, 0]) - - def test_iterate_states(self): - self.send_signal(self.widget.Inputs.data, self.iris) - self.assertListEqual(list(self.widget._iterate_all(None)), - [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) - self.assertListEqual(list(self.widget._iterate_all((1, 0))), - [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) - self.assertListEqual(list(self.widget._iterate_all((2, 1))), - [(2, 1), (3, 0), (3, 1), (3, 2)]) - self.widget.feature_index = 2 - self.assertListEqual(list(self.widget._iterate_by_feature(None)), - [(2, 0), (2, 1), (2, 3)]) - self.assertListEqual(list(self.widget._iterate_by_feature((2, 0))), - [(2, 0), (2, 1), (2, 3)]) - self.assertListEqual(list(self.widget._iterate_by_feature((2, 3))), - [(2, 3)]) - - def test_state_count(self): - self.send_signal(self.widget.Inputs.data, self.iris) - self.assertEqual(self.widget.state_count(), 6) - self.widget.feature_index = 2 - self.assertEqual(self.widget.state_count(), 3) - - -class TestInteractionScorer(unittest.TestCase): - def test_compute_score(self): - """Check score calculation""" - x = np.array([[1, 1], [0, 1], [1, 1], [0, 0]]) - y = np.array([0, 1, 1, 1]) - domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) - data = Table(domain, x, y) - self.scorer = InteractionScorer(data) - npt.assert_almost_equal(self.scorer(0, 1), -0.1226, 4) - npt.assert_almost_equal(self.scorer.class_entropy, 0.8113, 4) - npt.assert_almost_equal(self.scorer.information_gain[0], 0.3113, 4) - npt.assert_almost_equal(self.scorer.information_gain[1], 0.1226, 4) - - def test_nans(self): - """Check score calculation with nans""" - x = np.array([[1, 1], [0, 1], [1, 1], [0, 0], [1, np.nan], [np.nan, 0], [np.nan, np.nan]]) - y = np.array([0, 1, 1, 1, 0, 0, 1]) - domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) - data = Table(domain, x, y) - self.scorer = InteractionScorer(data) - npt.assert_almost_equal(self.scorer(0, 1), 0.0167, 4) - npt.assert_almost_equal(self.scorer.class_entropy, 0.9852, 4) - npt.assert_almost_equal(self.scorer.information_gain[0], 0.4343, 4) - npt.assert_almost_equal(self.scorer.information_gain[1], 0.0343, 4) - - -class TestHeuristic(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.zoo = Table("zoo") - - def test_heuristic(self): - """Check attribute pairs returned by heuristic""" - scorer = InteractionScorer(self.zoo) - heuristic = Heuristic(scorer.information_gain, - type=Heuristic.INFO_GAIN) - self.assertListEqual(list(heuristic.get_states(None))[:9], - [(14, 6), (14, 10), (14, 15), (6, 10), - (14, 5), (6, 15), (14, 11), (6, 5), (10, 15)]) - - states = heuristic.get_states(None) - _ = next(states) - self.assertListEqual(list(heuristic.get_states(next(states)))[:8], - [(14, 10), (14, 15), (6, 10), (14, 5), - (6, 15), (14, 11), (6, 5), (10, 15)]) - - -if __name__ == "__main__": - unittest.main() From 3e8d853fe749beaf6f9243f8d60d200fc54d1530 Mon Sep 17 00:00:00 2001 From: noahnovsak Date: Wed, 29 Mar 2023 10:41:08 +0200 Subject: [PATCH 6/6] cleanup --- orangecontrib/prototypes/interactions.py | 18 +-- orangecontrib/prototypes/ranktablemodel.py | 22 ++-- .../prototypes/widgets/owinteractions.py | 107 +++++++++--------- .../widgets/tests/test_owinteractions.py | 2 +- 4 files changed, 73 insertions(+), 76 deletions(-) diff --git a/orangecontrib/prototypes/interactions.py b/orangecontrib/prototypes/interactions.py index d1094534..830c0983 100644 --- a/orangecontrib/prototypes/interactions.py +++ b/orangecontrib/prototypes/interactions.py @@ -1,7 +1,7 @@ import numpy as np -def get_row_ids(ar): +def hash_rows(ar): row_ids = ar[:, 0].copy() # Assuming the data has been discretized into fewer # than 10000 bins and that `ar` has up to 3 columns, @@ -30,7 +30,7 @@ def distribution(ar): # implementation doesn't release the GIL. The simplest # solution seems to be generating unique numbers/ids # based on the contents of each row. - ar = get_row_ids(ar) + ar = hash_rows(ar) _, counts = np.unique(ar, return_counts=True) return counts / ar.shape[0] @@ -47,9 +47,9 @@ def __init__(self, data): self.class_entropy = 0 self.information_gain = np.zeros(data.X.shape[1]) - self.precompute() + self.preprocess() - def precompute(self): + def preprocess(self): """ Precompute information gain of each attribute to speed up computation and to create heuristic. @@ -68,12 +68,12 @@ def precompute(self): - entropy(np.column_stack((self.data.X[:, attr], self.data.Y))) def __call__(self, attr1, attr2): - attrs = np.column_stack((self.data.X[:, attr1], self.data.X[:, attr2])) + attrs = self.data.X[:, (attr1, attr2)] return self.class_entropy \ - - self.information_gain[attr1] \ - - self.information_gain[attr2] \ - + entropy(attrs) \ - - entropy(np.column_stack((attrs, self.data.Y))) + - self.information_gain[attr1] \ + - self.information_gain[attr2] \ + + entropy(attrs) \ + - entropy(np.column_stack((attrs, self.data.Y))) def normalize(self, score): return score / self.class_entropy diff --git a/orangecontrib/prototypes/ranktablemodel.py b/orangecontrib/prototypes/ranktablemodel.py index 3e910918..0f7e3c2a 100644 --- a/orangecontrib/prototypes/ranktablemodel.py +++ b/orangecontrib/prototypes/ranktablemodel.py @@ -39,7 +39,8 @@ def __init__(self, *args, **kwargs): # ``rowCount`` returns the lowest of `_rows` and `_max_view_rows`: # how large the model/view thinks it is - def sortInd(self): + @property + def __sortInd(self): return self._AbstractSortTableModel__sortInd def sortColumnData(self, column): @@ -49,7 +50,7 @@ def extendSortFrom(self, sorted_rows: int): data = self.sortColumnData(self.sortColumn()) new_ind = np.arange(sorted_rows, self._rows) order = 1 if self.sortOrder() == Qt.AscendingOrder else -1 - sorter = self.sortInd()[::order] + sorter = self.__sortInd[::order] new_sorter = np.argsort(data[sorted_rows:]) loc = np.searchsorted(data[:sorted_rows], data[sorted_rows:][new_sorter], @@ -88,18 +89,21 @@ def clear(self): self.resetSorting() self.endResetModel() - def append(self, rows: list[list[float]]): + def extend(self, rows: list[list[float]]): if not isinstance(self._data, np.ndarray): - return self.initialize(rows) + self.initialize(rows) + return n_rows = len(rows) if n_rows == 0: return + n_data = len(self._data) insert = self._rows < self._max_view_rows if insert: - self.beginInsertRows(QModelIndex(), self._rows, min(self._max_view_rows, self._rows + n_rows) - 1) + self.beginInsertRows(QModelIndex(), self._rows, + min(self._max_view_rows, self._rows + n_rows) - 1) if self._rows + n_rows >= n_data: n_data = min(max(n_data + n_rows, 2 * n_data), self._max_data_rows) @@ -126,11 +130,9 @@ class RankModel(ArrayTableModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.domain = None # type: Domain self.domain_model = DomainModel(DomainModel.ATTRIBUTES) def set_domain(self, domain: Domain): - self.domain = domain self.domain_model.set_domain(domain) n_attrs = len(domain.attributes) self._max_data_rows = n_attrs * (n_attrs - 1) // 2 @@ -143,16 +145,16 @@ def resetSorting(self): def data(self, index: QModelIndex, role=Qt.DisplayRole): if not index.isValid(): - return + return None column = index.column() - if column >= self.columnCount() - 2 and role != Qt.EditRole: + # use domain model for all data (except editrole) in last two columns try: row = self.mapToSourceRows(index.row()) value = self.domain_model.index(int(self._data[row, column])) return self.domain_model.data(value, role) except IndexError: - return + return None return super().data(index, role) diff --git a/orangecontrib/prototypes/widgets/owinteractions.py b/orangecontrib/prototypes/widgets/owinteractions.py index 1f26f603..7bd76feb 100644 --- a/orangecontrib/prototypes/widgets/owinteractions.py +++ b/orangecontrib/prototypes/widgets/owinteractions.py @@ -9,7 +9,10 @@ from AnyQt.QtWidgets import QTableView, QHeaderView, \ QStyleOptionViewItem, QApplication, QStyle, QLineEdit -from Orange.data import Table, Variable +from orangecontrib.prototypes.ranktablemodel import RankModel +from orangecontrib.prototypes.interactions import InteractionScorer + +from Orange.data import Table, Domain, Variable from Orange.preprocess import Discretize, Remove from Orange.widgets import gui from Orange.widgets.widget import OWWidget, AttributeList, Msg @@ -19,31 +22,23 @@ from Orange.widgets.utils.itemmodels import DomainModel from Orange.widgets.settings import Setting, ContextSetting, DomainContextHandler -from orangecontrib.prototypes.ranktablemodel import RankModel -from orangecontrib.prototypes.interactions import InteractionScorer - class ModelQueue: - """ - Another queueing object, similar to ``queue.Queue``. - The main difference is that ``get()`` returns all its - contents at the same time, instead of one by one. - """ def __init__(self): - self.lock = Lock() - self.model = [] - self.state = None + self.mutex = Lock() + self.queue = [] + self.latest_state = None def put(self, row, state): - with self.lock: - self.model.append(row) - self.state = state + with self.mutex: + self.queue.append(row) + self.latest_state = state def get(self): - with self.lock: - model, self.model = self.model, [] - state, self.state = self.state, None - return model, state + with self.mutex: + queue, self.queue = self.queue, [] + state, self.latest_state = self.latest_state, None + return queue, state def run(compute_score: Callable, row_for_state: Callable, @@ -52,8 +47,8 @@ def run(compute_score: Callable, row_for_state: Callable, """ Replaces ``run_vizrank``, with some minor adjustments. - ``ModelQueue`` replaces ``queue.Queue`` - - `row_for_state` parameter added - - `scores` parameter removed + - `row_for_state` can be called here, assuming we are not adding `Qt` objects to the model + - `scores` removed """ task.set_status("Getting combinations...") task.set_progress_value(0.1) @@ -103,16 +98,16 @@ def reset_flag(): class Heuristic: RANDOM, INFO_GAIN = 0, 1 - type = {RANDOM: "Random Search", - INFO_GAIN: "Information Gain Heuristic"} + mode = {RANDOM: "Random Search", + INFO_GAIN: "Low Information Gain First"} - def __init__(self, weights, type=None): + def __init__(self, weights, mode=RANDOM): self.n_attributes = len(weights) self.attributes = np.arange(self.n_attributes) - if type == self.RANDOM: - np.random.shuffle(self.attributes) - if type == self.INFO_GAIN: + if mode == Heuristic.INFO_GAIN: self.attributes = self.attributes[np.argsort(weights)] + else: + np.random.shuffle(self.attributes) def generate_states(self): # prioritize two mid ranked attributes over highest first @@ -196,7 +191,6 @@ class OWInteractions(OWWidget, ConcurrentWidgetMixin): name = "Interactions" description = "Compute all pairwise attribute interactions." icon = "icons/Interactions.svg" - category = "Unsupervised" class Inputs: data = Input("Data", Table) @@ -208,8 +202,8 @@ class Outputs: selection = ContextSetting([]) feature: Variable feature = ContextSetting(None) - heuristic_type: int - heuristic_type = Setting(0) + heuristic_mode: int + heuristic_mode = Setting(0) want_main_area = False want_control_area = True @@ -230,16 +224,16 @@ def __init__(self): self.saved_state = None self.progress = 0 - self.data = None # type: Table - self.pp_data = None # type: Table + self.original_domain: Domain = ... + self.data: Table = ... self.n_attrs = 0 self.scorer = None self.heuristic = None self.feature_index = None - gui.comboBox(self.controlArea, self, "heuristic_type", - items=Heuristic.type.values(), + gui.comboBox(self.controlArea, self, "heuristic_mode", + items=Heuristic.mode.values(), callback=self.on_heuristic_combo_changed,) self.feature_model = DomainModel(order=DomainModel.ATTRIBUTES, @@ -279,8 +273,8 @@ def set_data(self, data): self.closeContext() self.clear_messages() self.selection = {} - self.data = data - self.pp_data = None + self.original_domain = data and data.domain + self.data = None self.n_attrs = 0 if data is not None: if len(data) < 2: @@ -289,20 +283,20 @@ def set_data(self, data): self.Warning.no_class_var() else: remover = Remove(Remove.RemoveConstant) - pp_data = Discretize()(remover(data)) + data = Discretize()(remover(data)) if remover.attr_results["removed"]: self.Information.removed_cons_feat() - if len(pp_data.domain.attributes) < 2: + if len(data.domain.attributes) < 2: self.Warning.not_enough_vars() else: - self.pp_data = pp_data - self.n_attrs = len(pp_data.domain.attributes) - self.scorer = InteractionScorer(pp_data) - self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_type) - self.model.set_domain(pp_data.domain) + self.data = data + self.n_attrs = len(data.domain.attributes) + self.scorer = InteractionScorer(data) + self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_mode) + self.model.set_domain(data.domain) self.proxy.scorer = self.scorer - self.feature_model.set_domain(self.pp_data and self.pp_data.domain) - self.openContext(self.pp_data) + self.feature_model.set_domain(self.data and self.data.domain) + self.openContext(self.data) self.initialize() def initialize(self): @@ -316,17 +310,17 @@ def initialize(self): self.model.clear() self.filter.setText("") self.button.setText("Start") - self.button.setEnabled(self.pp_data is not None) - if self.pp_data is not None: + self.button.setEnabled(self.data is not None) + if self.data is not None: self.toggle() def commit(self): - if self.data is None: + if self.original_domain is None: self.Outputs.features.send(None) return self.Outputs.features.send(AttributeList( - [self.data.domain[attr] for attr in self.selection])) + [self.original_domain[attr] for attr in self.selection])) def toggle(self): self.keep_running = not self.keep_running @@ -373,18 +367,19 @@ def on_filter_changed(self, text): self.proxy.setFilterFixedString(text) def on_feature_combo_changed(self): - self.feature_index = self.feature and self.pp_data.domain.index(self.feature) + self.feature_index = self.feature and self.data.domain.index(self.feature) self.initialize() def on_heuristic_combo_changed(self): - if self.pp_data is not None: - self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_type) + if self.data is not None: + self.heuristic = Heuristic(self.scorer.information_gain, self.heuristic_mode) self.initialize() def compute_score(self, state): - scores = (self.scorer(*state), - self.scorer.information_gain[state[0]], - self.scorer.information_gain[state[1]]) + attr1, attr2 = state + scores = (self.scorer(attr1, attr2), + self.scorer.information_gain[attr1], + self.scorer.information_gain[attr2]) return tuple(self.scorer.normalize(score) for score in scores) @staticmethod @@ -420,7 +415,7 @@ def on_partial_result(self, result): add_to_model, latest_state = result if add_to_model: self.saved_state = latest_state - self.model.append(add_to_model) + self.model.extend(add_to_model) self.progress = len(self.model) self.progressBarSet(self.progress * 100 // self.state_count()) diff --git a/orangecontrib/prototypes/widgets/tests/test_owinteractions.py b/orangecontrib/prototypes/widgets/tests/test_owinteractions.py index b5ebd8fa..c99d9811 100644 --- a/orangecontrib/prototypes/widgets/tests/test_owinteractions.py +++ b/orangecontrib/prototypes/widgets/tests/test_owinteractions.py @@ -265,7 +265,7 @@ def test_heuristic(self): """Check attribute pairs returned by heuristic""" scorer = InteractionScorer(self.zoo) heuristic = Heuristic(scorer.information_gain, - type=Heuristic.INFO_GAIN) + Heuristic.INFO_GAIN) self.assertListEqual(list(heuristic.get_states(None))[:9], [(14, 6), (14, 10), (14, 15), (6, 10), (14, 5), (6, 15), (14, 11), (6, 5), (10, 15)])