diff --git a/orangecontrib/prototypes/ranktablemodel.py b/orangecontrib/prototypes/ranktablemodel.py index ffbb3984..3e910918 100644 --- a/orangecontrib/prototypes/ranktablemodel.py +++ b/orangecontrib/prototypes/ranktablemodel.py @@ -19,8 +19,9 @@ class ArrayTableModel(PyTableModel): Other, unlisted methods aren't guaranteed to work and should be used with care. Also requires access to private members of ``AbstractSortTableModel`` directly; - ``__sortInd`` is used for sorting and potential filtering, and ``__init__`` because the - parent implementation immediately wraps a list, which this model does not use. + ``__sortInd`` is needed to append new unsorted data, and ``__init__`` is used + because the parent implementation immediately wraps a list, which this model + does not have. """ def __init__(self, *args, **kwargs): super(PyTableModel, self).__init__(*args, **kwargs) @@ -31,8 +32,12 @@ def __init__(self, *args, **kwargs): self._data = None # type: np.ndarray self._columns = 0 - self._rows = 0 - self._max_display_rows = self._max_data_rows = MAX_ROWS + self._rows = 0 # current number of rows containing data + self._max_view_rows = MAX_ROWS # maximum number of rows the model/view will display + self._max_data_rows = MAX_ROWS # maximum allowed size for the `_data` array + # ``__len__`` returns _rows: amount of existing data in the model + # ``rowCount`` returns the lowest of `_rows` and `_max_view_rows`: + # how large the model/view thinks it is def sortInd(self): return self._AbstractSortTableModel__sortInd @@ -53,7 +58,7 @@ def extendSortFrom(self, sorted_rows: int): self.setSortIndices(indices) def rowCount(self, parent=QModelIndex()): - return 0 if parent.isValid() else min(self._rows, self._max_display_rows) + return 0 if parent.isValid() else min(self._rows, self._max_view_rows) def columnCount(self, parent=QModelIndex()): return 0 if parent.isValid() else self._columns @@ -91,10 +96,10 @@ def append(self, rows: list[list[float]]): if n_rows == 0: return n_data = len(self._data) - insert = self._rows < self._max_display_rows + insert = self._rows < self._max_view_rows if insert: - self.beginInsertRows(QModelIndex(), self._rows, min(self._max_display_rows, self._rows + n_rows) - 1) + self.beginInsertRows(QModelIndex(), self._rows, min(self._max_view_rows, self._rows + n_rows) - 1) if self._rows + n_rows >= n_data: n_data = min(max(n_data + n_rows, 2 * n_data), self._max_data_rows) diff --git a/orangecontrib/prototypes/widgets/owinteractions_new.py b/orangecontrib/prototypes/widgets/owinteractions_new.py index 253bda20..9729a7d3 100644 --- a/orangecontrib/prototypes/widgets/owinteractions_new.py +++ b/orangecontrib/prototypes/widgets/owinteractions_new.py @@ -278,7 +278,7 @@ def __init__(self): def set_data(self, data): self.closeContext() self.clear_messages() - self.selection = [] + self.selection = {} self.data = data self.pp_data = None self.n_attrs = 0 @@ -333,24 +333,40 @@ def toggle(self): if not self.keep_running: self.button.setText("Pause") self.button.repaint() - self.progressBarInit() self.filter.setEnabled(False) + self.progressBarInit() self.start(run, self.compute_score, self.row_for_state, self.iterate_states, self.saved_state, self.progress, self.state_count()) else: self.button.setText("Continue") self.button.repaint() - self.cancel() - self.progressBarFinished() self.filter.setEnabled(True) + self.cancel() + self._stopped() + + def _stopped(self): + self.progressBarFinished() + self._select_default() + + def _select_default(self): + n_rows = self.model.rowCount() + if not n_rows: + return + + if self.selection: + for i in range(n_rows): + names = {self.model.data(self.model.index(i, 2)), + self.model.data(self.model.index(i, 3))} + if names == self.selection: + self.rank_table.selectRow(i) + break - def _select_first_if_none(self): if not self.rank_table.selectedIndexes(): self.rank_table.selectRow(0) def on_selection_changed(self, selected): - self.selection = [self.model.data(ind) for ind in selected.indexes()[-2:]] + self.selection = {self.model.data(ind) for ind in selected.indexes()[-2:]} self.commit() def on_filter_changed(self, text): @@ -396,9 +412,9 @@ def _iterate_by_feature(self, initial_state): yield self.feature_index, j def state_count(self): - if self.feature is None: + if self.feature_index is None: return self.n_attrs * (self.n_attrs - 1) // 2 - return self.n_attrs + return self.n_attrs - 1 def on_partial_result(self, result): add_to_model, latest_state = result @@ -412,7 +428,9 @@ def on_done(self, result): self.button.setText("Finished") self.button.setEnabled(False) self.filter.setEnabled(True) - self._select_first_if_none() + self.keep_running = False + self.saved_state = None + self._stopped() def send_report(self): self.report_table("Interactions", self.rank_table) diff --git a/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py b/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py index 95c66db4..ef98a82c 100644 --- a/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py +++ b/orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py @@ -1,12 +1,18 @@ +import unittest from unittest.mock import Mock import numpy as np +import numpy.testing as npt + +from AnyQt.QtCore import QItemSelection from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable from Orange.widgets.tests.base import WidgetTest +from Orange.widgets.tests.utils import simulate from Orange.widgets.widget import AttributeList -from orangecontrib.prototypes.widgets.owinteractions_new import OWInteractions +from orangecontrib.prototypes.widgets.owinteractions_new import OWInteractions, Heuristic +from orangecontrib.prototypes.interactions import InteractionScorer class TestOWInteractions(WidgetTest): @@ -126,6 +132,24 @@ def test_input_changed(self): self.process_events() self.widget.commit.assert_called_once() + def test_saved_selection(self): + """Check row selection""" + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + selection = QItemSelection() + selection.select(self.widget.model.index(2, 0), + self.widget.model.index(2, 3)) + self.widget.on_selection_changed(selection) + settings = self.widget.settingsHandler.pack_data(self.widget) + + w = self.create_widget(OWInteractions, stored_settings=settings) + self.send_signal(self.widget.Inputs.data, self.iris, widget=w) + self.wait_until_finished(w) + self.process_events() + sel_row = w.rank_table.selectionModel().selectedRows()[0].row() + self.assertEqual(sel_row, 2) + def test_feature_combo(self): """Check feature combobox""" feature_combo = self.widget.controls.feature @@ -135,3 +159,123 @@ def test_feature_combo(self): self.wait_until_stop_blocking() self.send_signal(self.widget.Inputs.data, self.zoo) self.assertEqual(len(feature_combo.model()), 17) + + def test_select_feature(self): + """Check feature selection""" + feature_combo = self.widget.controls.feature + self.send_signal(self.widget.Inputs.data, self.iris) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 6) + self.assertSetEqual( + {a.name for a in self.get_output(self.widget.Outputs.features)}, + {"sepal width", "sepal length"} + ) + + simulate.combobox_activate_index(feature_combo, 3) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 3) + self.assertSetEqual( + {a.name for a in self.get_output(self.widget.Outputs.features)}, + {"petal length", "sepal width"} + ) + + simulate.combobox_activate_index(feature_combo, 0) + self.wait_until_finished() + self.process_events() + self.assertEqual(self.widget.model.rowCount(), 6) + self.assertSetEqual( + {a.name for a in self.get_output(self.widget.Outputs.features)}, + {"petal length", "sepal width"} + ) + + def test_send_report(self): + """Check report""" + self.send_signal(self.widget.Inputs.data, self.iris) + self.widget.report_button.click() + self.wait_until_stop_blocking() + self.send_signal(self.widget.Inputs.data, None) + self.widget.report_button.click() + + def test_compute_score(self): + self.widget.scorer = InteractionScorer(self.zoo) + npt.assert_almost_equal(self.widget.compute_score((1, 0)), + [-0.0771, 0.3003, 0.3307], 4) + + def test_row_for_state(self): + row = self.widget.row_for_state((-0.2, 0.2, 0.1), (1, 0)) + self.assertListEqual(row, [-0.2, 0.1, 1, 0]) + + def test_iterate_states(self): + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertListEqual(list(self.widget._iterate_all(None)), + [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) + self.assertListEqual(list(self.widget._iterate_all((1, 0))), + [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)]) + self.assertListEqual(list(self.widget._iterate_all((2, 1))), + [(2, 1), (3, 0), (3, 1), (3, 2)]) + self.widget.feature_index = 2 + self.assertListEqual(list(self.widget._iterate_by_feature(None)), + [(2, 0), (2, 1), (2, 3)]) + self.assertListEqual(list(self.widget._iterate_by_feature((2, 0))), + [(2, 0), (2, 1), (2, 3)]) + self.assertListEqual(list(self.widget._iterate_by_feature((2, 3))), + [(2, 3)]) + + def test_state_count(self): + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertEqual(self.widget.state_count(), 6) + self.widget.feature_index = 2 + self.assertEqual(self.widget.state_count(), 3) + + +class TestInteractionScorer(unittest.TestCase): + def test_compute_score(self): + """Check score calculation""" + x = np.array([[1, 1], [0, 1], [1, 1], [0, 0]]) + y = np.array([0, 1, 1, 1]) + domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) + data = Table(domain, x, y) + self.scorer = InteractionScorer(data) + npt.assert_almost_equal(self.scorer(0, 1), -0.1226, 4) + npt.assert_almost_equal(self.scorer.class_entropy, 0.8113, 4) + npt.assert_almost_equal(self.scorer.information_gain[0], 0.3113, 4) + npt.assert_almost_equal(self.scorer.information_gain[1], 0.1226, 4) + + def test_nans(self): + """Check score calculation with nans""" + x = np.array([[1, 1], [0, 1], [1, 1], [0, 0], [1, np.nan], [np.nan, 0], [np.nan, np.nan]]) + y = np.array([0, 1, 1, 1, 0, 0, 1]) + domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3")) + data = Table(domain, x, y) + self.scorer = InteractionScorer(data) + npt.assert_almost_equal(self.scorer(0, 1), 0.0167, 4) + npt.assert_almost_equal(self.scorer.class_entropy, 0.9852, 4) + npt.assert_almost_equal(self.scorer.information_gain[0], 0.4343, 4) + npt.assert_almost_equal(self.scorer.information_gain[1], 0.0343, 4) + + +class TestHeuristic(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.zoo = Table("zoo") + + def test_heuristic(self): + """Check attribute pairs returned by heuristic""" + scorer = InteractionScorer(self.zoo) + heuristic = Heuristic(scorer.information_gain, + type=Heuristic.INFO_GAIN) + self.assertListEqual(list(heuristic.get_states(None))[:9], + [(14, 6), (14, 10), (14, 15), (6, 10), + (14, 5), (6, 15), (14, 11), (6, 5), (10, 15)]) + + states = heuristic.get_states(None) + _ = next(states) + self.assertListEqual(list(heuristic.get_states(next(states)))[:8], + [(14, 10), (14, 15), (6, 10), (14, 5), + (6, 15), (14, 11), (6, 5), (10, 15)]) + + +if __name__ == "__main__": + unittest.main()