Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Test and Score: Sort numerically, not alphabetically #3951

Merged
merged 1 commit into from
Aug 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Orange/widgets/evaluate/owtestlearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def update_stats_model(self):
for stat, scorer in zip(stats, self.scorers):
item = QStandardItem()
if stat.success:
item.setText("{:.3f}".format(stat.value[0]))
item.setData(float(stat.value[0]), Qt.DisplayRole)
else:
item.setToolTip(str(stat.exception))
if scorer.name in self.score_table.shown_scores:
Expand Down
13 changes: 7 additions & 6 deletions Orange/widgets/evaluate/tests/test_owtestlearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ def __call__(self, data):
# Ensure that the click on header caused an ascending sort
# Ascending sort means that wrong model should be listed first
self.assertEqual(header.sortIndicatorOrder(), Qt.AscendingOrder)
self.assertEqual(view.model().item(0, 0).text(), "VersicolorLearner")
self.assertEqual(view.model().index(0, 0).data(), "VersicolorLearner")

self.send_signal(self.widget.Inputs.test_data, versicolor, wait=5000)
self.assertEqual(view.model().item(0, 0).text(), "SetosaLearner")
self.assertEqual(view.model().index(0, 0).data(), "SetosaLearner")

self.widget.hide()

Expand Down Expand Up @@ -365,10 +365,11 @@ def test_scores_log_reg_advanced(self):
[1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yynnn"))
)

self.assertTupleEqual(self._test_scores(
table_train, table_test, LogisticRegressionLearner(),
OWTestLearners.TestOnTest, None),
(0.667, 0.8, 0.8, 0.867, 0.8))
np.testing.assert_almost_equal(
self._test_scores(table_train, table_test,
LogisticRegressionLearner(),
OWTestLearners.TestOnTest, None),
(2 / 3, 0.8, 0.8, 13 / 15, 0.8))

def test_scores_cross_validation(self):
"""
Expand Down
48 changes: 47 additions & 1 deletion Orange/widgets/evaluate/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import unittest
import collections

import numpy as np

from AnyQt.QtWidgets import QMenu
from AnyQt.QtCore import QPoint
from AnyQt.QtGui import QStandardItem
from AnyQt.QtCore import QPoint, Qt

from Orange.widgets.evaluate.utils import ScoreTable
from Orange.widgets.tests.base import GuiTest
Expand Down Expand Up @@ -70,5 +73,48 @@ def test_update_shown_columns(self):
not header.isSectionHidden(i),
msg="error in section {}({})".format(i, name))

def test_sorting(self):
def order(n=5):
return "".join(model.index(i, 0).data() for i in range(n))

score_table = ScoreTable(None)

data = [
["D", 11.0, 15.3],
["C", 5.0, -15.4],
["b", 20.0, np.nan],
["A", None, None],
["E", "", 0.0]
]
for data_row in data:
row = []
for x in data_row:
item = QStandardItem()
if x is not None:
item.setData(x, Qt.DisplayRole)
row.append(item)
score_table.model.appendRow(row)

model = score_table.view.model()

model.sort(0, Qt.AscendingOrder)
self.assertEqual(order(), "AbCDE")

model.sort(0, Qt.DescendingOrder)
self.assertEqual(order(), "EDCbA")

model.sort(1, Qt.AscendingOrder)
self.assertEqual(order(3), "CDb")

model.sort(1, Qt.DescendingOrder)
self.assertEqual(order(3), "bDC")

model.sort(2, Qt.AscendingOrder)
self.assertEqual(order(3), "CED")

model.sort(2, Qt.DescendingOrder)
self.assertEqual(order(3), "DEC")


if __name__ == "__main__":
unittest.main()
39 changes: 37 additions & 2 deletions Orange/widgets/evaluate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from AnyQt.QtWidgets import QHeaderView, QStyledItemDelegate, QMenu
from AnyQt.QtGui import QStandardItemModel, QStandardItem
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal, \
QSortFilterProxyModel
from sklearn.exceptions import UndefinedMetricWarning

from Orange.data import Variable, DiscreteVariable, ContinuousVariable
Expand Down Expand Up @@ -98,6 +99,32 @@ def thunked():
return thunked


class ScoreModel(QSortFilterProxyModel):
def lessThan(self, left, right):
def is_bad(x):
return not isinstance(x, (int, float, str)) \
or isinstance(x, float) and np.isnan(x)

left = left.data()
right = right.data()
is_ascending = self.sortOrder() == Qt.AscendingOrder

# bad entries go below; if both are bad, left remains above
if is_bad(left) or is_bad(right):
return is_bad(right) == is_ascending

# for data of different types, numbers are at the top
if type(left) is not type(right):
return isinstance(left, float) == is_ascending

# case insensitive comparison for strings
if isinstance(left, str):
return left.upper() < right.upper()

# otherwise, compare numbers
return left < right


class ScoreTable(OWComponent, QObject):
shown_scores = \
Setting(set(chain(*BUILTIN_SCORERS_ORDER.values())))
Expand All @@ -109,6 +136,12 @@ def sizeHint(self, *args):
size = super().sizeHint(*args)
return QSize(size.width(), size.height() + 6)

def displayText(self, value, locale):
if isinstance(value, float):
return f"{value:.3f}"
else:
return super().displayText(value, locale)

def __init__(self, master):
QObject.__init__(self)
OWComponent.__init__(self, master)
Expand All @@ -125,7 +158,9 @@ def __init__(self, master):

self.model = QStandardItemModel(master)
self.model.setHorizontalHeaderLabels(["Method"])
self.view.setModel(self.model)
self.sorted_model = ScoreModel()
self.sorted_model.setSourceModel(self.model)
self.view.setModel(self.sorted_model)
self.view.setItemDelegate(self.ItemDelegate())

def _column_names(self):
Expand Down