Skip to content

Commit

Permalink
fix selection and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
noahnovsak authored and markotoplak committed Nov 22, 2023
1 parent 6703236 commit 4355650
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 17 deletions.
19 changes: 12 additions & 7 deletions orangecontrib/prototypes/ranktablemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ class ArrayTableModel(PyTableModel):
Other, unlisted methods aren't guaranteed to work and should be used with care.
Also requires access to private members of ``AbstractSortTableModel`` directly;
``__sortInd`` is used for sorting and potential filtering, and ``__init__`` because the
parent implementation immediately wraps a list, which this model does not use.
``__sortInd`` is needed to append new unsorted data, and ``__init__`` is used
because the parent implementation immediately wraps a list, which this model
does not have.
"""
def __init__(self, *args, **kwargs):
super(PyTableModel, self).__init__(*args, **kwargs)
Expand All @@ -31,8 +32,12 @@ def __init__(self, *args, **kwargs):

self._data = None # type: np.ndarray
self._columns = 0
self._rows = 0
self._max_display_rows = self._max_data_rows = MAX_ROWS
self._rows = 0 # current number of rows containing data
self._max_view_rows = MAX_ROWS # maximum number of rows the model/view will display
self._max_data_rows = MAX_ROWS # maximum allowed size for the `_data` array
# ``__len__`` returns _rows: amount of existing data in the model
# ``rowCount`` returns the lowest of `_rows` and `_max_view_rows`:
# how large the model/view thinks it is

def sortInd(self):
return self._AbstractSortTableModel__sortInd
Expand All @@ -53,7 +58,7 @@ def extendSortFrom(self, sorted_rows: int):
self.setSortIndices(indices)

def rowCount(self, parent=QModelIndex()):
return 0 if parent.isValid() else min(self._rows, self._max_display_rows)
return 0 if parent.isValid() else min(self._rows, self._max_view_rows)

def columnCount(self, parent=QModelIndex()):
return 0 if parent.isValid() else self._columns
Expand Down Expand Up @@ -91,10 +96,10 @@ def append(self, rows: list[list[float]]):
if n_rows == 0:
return
n_data = len(self._data)
insert = self._rows < self._max_display_rows
insert = self._rows < self._max_view_rows

if insert:
self.beginInsertRows(QModelIndex(), self._rows, min(self._max_display_rows, self._rows + n_rows) - 1)
self.beginInsertRows(QModelIndex(), self._rows, min(self._max_view_rows, self._rows + n_rows) - 1)

if self._rows + n_rows >= n_data:
n_data = min(max(n_data + n_rows, 2 * n_data), self._max_data_rows)
Expand Down
36 changes: 27 additions & 9 deletions orangecontrib/prototypes/widgets/owinteractions_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __init__(self):
def set_data(self, data):
self.closeContext()
self.clear_messages()
self.selection = []
self.selection = {}
self.data = data
self.pp_data = None
self.n_attrs = 0
Expand Down Expand Up @@ -333,24 +333,40 @@ def toggle(self):
if not self.keep_running:
self.button.setText("Pause")
self.button.repaint()
self.progressBarInit()
self.filter.setEnabled(False)
self.progressBarInit()
self.start(run, self.compute_score, self.row_for_state,
self.iterate_states, self.saved_state,
self.progress, self.state_count())
else:
self.button.setText("Continue")
self.button.repaint()
self.cancel()
self.progressBarFinished()
self.filter.setEnabled(True)
self.cancel()
self._stopped()

def _stopped(self):
self.progressBarFinished()
self._select_default()

def _select_default(self):
n_rows = self.model.rowCount()
if not n_rows:
return

if self.selection:
for i in range(n_rows):
names = {self.model.data(self.model.index(i, 2)),
self.model.data(self.model.index(i, 3))}
if names == self.selection:
self.rank_table.selectRow(i)
break

def _select_first_if_none(self):
if not self.rank_table.selectedIndexes():
self.rank_table.selectRow(0)

def on_selection_changed(self, selected):
self.selection = [self.model.data(ind) for ind in selected.indexes()[-2:]]
self.selection = {self.model.data(ind) for ind in selected.indexes()[-2:]}
self.commit()

def on_filter_changed(self, text):
Expand Down Expand Up @@ -396,9 +412,9 @@ def _iterate_by_feature(self, initial_state):
yield self.feature_index, j

def state_count(self):
if self.feature is None:
if self.feature_index is None:
return self.n_attrs * (self.n_attrs - 1) // 2
return self.n_attrs
return self.n_attrs - 1

def on_partial_result(self, result):
add_to_model, latest_state = result
Expand All @@ -412,7 +428,9 @@ def on_done(self, result):
self.button.setText("Finished")
self.button.setEnabled(False)
self.filter.setEnabled(True)
self._select_first_if_none()
self.keep_running = False
self.saved_state = None
self._stopped()

def send_report(self):
self.report_table("Interactions", self.rank_table)
Expand Down
146 changes: 145 additions & 1 deletion orangecontrib/prototypes/widgets/tests/test_owinteractions_new.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import unittest
from unittest.mock import Mock

import numpy as np
import numpy.testing as npt

from AnyQt.QtCore import QItemSelection

from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.utils import simulate
from Orange.widgets.widget import AttributeList

from orangecontrib.prototypes.widgets.owinteractions_new import OWInteractions
from orangecontrib.prototypes.widgets.owinteractions_new import OWInteractions, Heuristic
from orangecontrib.prototypes.interactions import InteractionScorer


class TestOWInteractions(WidgetTest):
Expand Down Expand Up @@ -126,6 +132,24 @@ def test_input_changed(self):
self.process_events()
self.widget.commit.assert_called_once()

def test_saved_selection(self):
"""Check row selection"""
self.send_signal(self.widget.Inputs.data, self.iris)
self.wait_until_finished()
self.process_events()
selection = QItemSelection()
selection.select(self.widget.model.index(2, 0),
self.widget.model.index(2, 3))
self.widget.on_selection_changed(selection)
settings = self.widget.settingsHandler.pack_data(self.widget)

w = self.create_widget(OWInteractions, stored_settings=settings)
self.send_signal(self.widget.Inputs.data, self.iris, widget=w)
self.wait_until_finished(w)
self.process_events()
sel_row = w.rank_table.selectionModel().selectedRows()[0].row()
self.assertEqual(sel_row, 2)

def test_feature_combo(self):
"""Check feature combobox"""
feature_combo = self.widget.controls.feature
Expand All @@ -135,3 +159,123 @@ def test_feature_combo(self):
self.wait_until_stop_blocking()
self.send_signal(self.widget.Inputs.data, self.zoo)
self.assertEqual(len(feature_combo.model()), 17)

def test_select_feature(self):
"""Check feature selection"""
feature_combo = self.widget.controls.feature
self.send_signal(self.widget.Inputs.data, self.iris)
self.wait_until_finished()
self.process_events()
self.assertEqual(self.widget.model.rowCount(), 6)
self.assertSetEqual(
{a.name for a in self.get_output(self.widget.Outputs.features)},
{"sepal width", "sepal length"}
)

simulate.combobox_activate_index(feature_combo, 3)
self.wait_until_finished()
self.process_events()
self.assertEqual(self.widget.model.rowCount(), 3)
self.assertSetEqual(
{a.name for a in self.get_output(self.widget.Outputs.features)},
{"petal length", "sepal width"}
)

simulate.combobox_activate_index(feature_combo, 0)
self.wait_until_finished()
self.process_events()
self.assertEqual(self.widget.model.rowCount(), 6)
self.assertSetEqual(
{a.name for a in self.get_output(self.widget.Outputs.features)},
{"petal length", "sepal width"}
)

def test_send_report(self):
"""Check report"""
self.send_signal(self.widget.Inputs.data, self.iris)
self.widget.report_button.click()
self.wait_until_stop_blocking()
self.send_signal(self.widget.Inputs.data, None)
self.widget.report_button.click()

def test_compute_score(self):
self.widget.scorer = InteractionScorer(self.zoo)
npt.assert_almost_equal(self.widget.compute_score((1, 0)),
[-0.0771, 0.3003, 0.3307], 4)

def test_row_for_state(self):
row = self.widget.row_for_state((-0.2, 0.2, 0.1), (1, 0))
self.assertListEqual(row, [-0.2, 0.1, 1, 0])

def test_iterate_states(self):
self.send_signal(self.widget.Inputs.data, self.iris)
self.assertListEqual(list(self.widget._iterate_all(None)),
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
self.assertListEqual(list(self.widget._iterate_all((1, 0))),
[(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2)])
self.assertListEqual(list(self.widget._iterate_all((2, 1))),
[(2, 1), (3, 0), (3, 1), (3, 2)])
self.widget.feature_index = 2
self.assertListEqual(list(self.widget._iterate_by_feature(None)),
[(2, 0), (2, 1), (2, 3)])
self.assertListEqual(list(self.widget._iterate_by_feature((2, 0))),
[(2, 0), (2, 1), (2, 3)])
self.assertListEqual(list(self.widget._iterate_by_feature((2, 3))),
[(2, 3)])

def test_state_count(self):
self.send_signal(self.widget.Inputs.data, self.iris)
self.assertEqual(self.widget.state_count(), 6)
self.widget.feature_index = 2
self.assertEqual(self.widget.state_count(), 3)


class TestInteractionScorer(unittest.TestCase):
def test_compute_score(self):
"""Check score calculation"""
x = np.array([[1, 1], [0, 1], [1, 1], [0, 0]])
y = np.array([0, 1, 1, 1])
domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3"))
data = Table(domain, x, y)
self.scorer = InteractionScorer(data)
npt.assert_almost_equal(self.scorer(0, 1), -0.1226, 4)
npt.assert_almost_equal(self.scorer.class_entropy, 0.8113, 4)
npt.assert_almost_equal(self.scorer.information_gain[0], 0.3113, 4)
npt.assert_almost_equal(self.scorer.information_gain[1], 0.1226, 4)

def test_nans(self):
"""Check score calculation with nans"""
x = np.array([[1, 1], [0, 1], [1, 1], [0, 0], [1, np.nan], [np.nan, 0], [np.nan, np.nan]])
y = np.array([0, 1, 1, 1, 0, 0, 1])
domain = Domain([DiscreteVariable(str(i)) for i in range(2)], DiscreteVariable("3"))
data = Table(domain, x, y)
self.scorer = InteractionScorer(data)
npt.assert_almost_equal(self.scorer(0, 1), 0.0167, 4)
npt.assert_almost_equal(self.scorer.class_entropy, 0.9852, 4)
npt.assert_almost_equal(self.scorer.information_gain[0], 0.4343, 4)
npt.assert_almost_equal(self.scorer.information_gain[1], 0.0343, 4)


class TestHeuristic(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.zoo = Table("zoo")

def test_heuristic(self):
"""Check attribute pairs returned by heuristic"""
scorer = InteractionScorer(self.zoo)
heuristic = Heuristic(scorer.information_gain,
type=Heuristic.INFO_GAIN)
self.assertListEqual(list(heuristic.get_states(None))[:9],
[(14, 6), (14, 10), (14, 15), (6, 10),
(14, 5), (6, 15), (14, 11), (6, 5), (10, 15)])

states = heuristic.get_states(None)
_ = next(states)
self.assertListEqual(list(heuristic.get_states(next(states)))[:8],
[(14, 10), (14, 15), (6, 10), (14, 5),
(6, 15), (14, 11), (6, 5), (10, 15)])


if __name__ == "__main__":
unittest.main()

0 comments on commit 4355650

Please sign in to comment.