Skip to content

Commit

Permalink
Test and Score: Sort numerically, not alphabetically
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Aug 17, 2019
1 parent 40916e2 commit 0887a2a
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Orange/widgets/evaluate/owtestlearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def update_stats_model(self):
for stat, scorer in zip(stats, self.scorers):
item = QStandardItem()
if stat.success:
item.setText("{:.3f}".format(stat.value[0]))
item.setData(float(stat.value[0]), Qt.DisplayRole)
else:
item.setToolTip(str(stat.exception))
if scorer.name in self.score_table.shown_scores:
Expand Down
13 changes: 7 additions & 6 deletions Orange/widgets/evaluate/tests/test_owtestlearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ def __call__(self, data):
# Ensure that the click on header caused an ascending sort
# Ascending sort means that wrong model should be listed first
self.assertEqual(header.sortIndicatorOrder(), Qt.AscendingOrder)
self.assertEqual(view.model().item(0, 0).text(), "VersicolorLearner")
self.assertEqual(view.model().index(0, 0).data(), "VersicolorLearner")

self.send_signal(self.widget.Inputs.test_data, versicolor, wait=5000)
self.assertEqual(view.model().item(0, 0).text(), "SetosaLearner")
self.assertEqual(view.model().index(0, 0).data(), "SetosaLearner")

self.widget.hide()

Expand Down Expand Up @@ -365,10 +365,11 @@ def test_scores_log_reg_advanced(self):
[1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yynnn"))
)

self.assertTupleEqual(self._test_scores(
table_train, table_test, LogisticRegressionLearner(),
OWTestLearners.TestOnTest, None),
(0.667, 0.8, 0.8, 0.867, 0.8))
np.testing.assert_almost_equal(
self._test_scores(table_train, table_test,
LogisticRegressionLearner(),
OWTestLearners.TestOnTest, None),
(2 / 3, 0.8, 0.8, 13 / 15, 0.8))

def test_scores_cross_validation(self):
"""
Expand Down
48 changes: 47 additions & 1 deletion Orange/widgets/evaluate/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import unittest
import collections

import numpy as np

from AnyQt.QtWidgets import QMenu
from AnyQt.QtCore import QPoint
from AnyQt.QtGui import QStandardItem
from AnyQt.QtCore import QPoint, Qt

from Orange.widgets.evaluate.utils import ScoreTable
from Orange.widgets.tests.base import GuiTest
Expand Down Expand Up @@ -70,5 +73,48 @@ def test_update_shown_columns(self):
not header.isSectionHidden(i),
msg="error in section {}({})".format(i, name))

def test_sorting(self):
def order(n=5):
return "".join(model.index(i, 0).data() for i in range(n))

score_table = ScoreTable(None)

data = [
["D", 11.0, 15.3],
["C", 5.0, -15.4],
["b", 20.0, np.nan],
["A", None, None],
["E", "", 0.0]
]
for data_row in data:
row = []
for x in data_row:
item = QStandardItem()
if x is not None:
item.setData(x, Qt.DisplayRole)
row.append(item)
score_table.model.appendRow(row)

model = score_table.view.model()

model.sort(0, Qt.AscendingOrder)
self.assertEqual(order(), "AbCDE")

model.sort(0, Qt.DescendingOrder)
self.assertEqual(order(), "EDCbA")

model.sort(1, Qt.AscendingOrder)
self.assertEqual(order(3), "CDb")

model.sort(1, Qt.DescendingOrder)
self.assertEqual(order(3), "bDC")

model.sort(2, Qt.AscendingOrder)
self.assertEqual(order(3), "CED")

model.sort(2, Qt.DescendingOrder)
self.assertEqual(order(3), "DEC")


if __name__ == "__main__":
unittest.main()
39 changes: 37 additions & 2 deletions Orange/widgets/evaluate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from AnyQt.QtWidgets import QHeaderView, QStyledItemDelegate, QMenu
from AnyQt.QtGui import QStandardItemModel, QStandardItem
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal, \
QSortFilterProxyModel
from sklearn.exceptions import UndefinedMetricWarning

from Orange.data import Variable, DiscreteVariable, ContinuousVariable
Expand Down Expand Up @@ -98,6 +99,32 @@ def thunked():
return thunked


class ScoreModel(QSortFilterProxyModel):
def lessThan(self, left, right):
def is_bad(x):
return not isinstance(x, (int, float, str)) \
or isinstance(x, float) and np.isnan(x)

left = left.data()
right = right.data()
is_ascending = self.sortOrder() == Qt.AscendingOrder

# bad entries go below; if both are bad, left remains above
if is_bad(left) or is_bad(right):
return is_bad(right) == is_ascending

# for data of different types, numbers are at the top
if type(left) is not type(right):
return isinstance(left, float) == is_ascending

# case insensitive comparison for strings
if isinstance(left, str):
return left.upper() < right.upper()

# otherwise, compare numbers
return left < right


class ScoreTable(OWComponent, QObject):
shown_scores = \
Setting(set(chain(*BUILTIN_SCORERS_ORDER.values())))
Expand All @@ -109,6 +136,12 @@ def sizeHint(self, *args):
size = super().sizeHint(*args)
return QSize(size.width(), size.height() + 6)

def displayText(self, value, locale):
if isinstance(value, float):
return f"{value:.3f}"
else:
return super().displayText(value, locale)

def __init__(self, master):
QObject.__init__(self)
OWComponent.__init__(self, master)
Expand All @@ -125,7 +158,9 @@ def __init__(self, master):

self.model = QStandardItemModel(master)
self.model.setHorizontalHeaderLabels(["Method"])
self.view.setModel(self.model)
self.sorted_model = ScoreModel()
self.sorted_model.setSourceModel(self.model)
self.view.setModel(self.sorted_model)
self.view.setItemDelegate(self.ItemDelegate())

def _column_names(self):
Expand Down

0 comments on commit 0887a2a

Please sign in to comment.