Skip to content

Commit

Permalink
Test and Score: Sort numerically, not alphabetically
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Jul 31, 2019
1 parent 6e7e534 commit 2644ffc
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Orange/widgets/evaluate/owtestlearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def update_stats_model(self):
for stat, scorer in zip(stats, self.scorers):
item = QStandardItem()
if stat.success:
item.setText("{:.3f}".format(stat.value[0]))
item.setData(float(stat.value[0]), Qt.DisplayRole)
else:
item.setToolTip(str(stat.exception))
if scorer.name in self.score_table.shown_scores:
Expand Down
13 changes: 7 additions & 6 deletions Orange/widgets/evaluate/tests/test_owtestlearners.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ def __call__(self, data):
# Ensure that the click on header caused an ascending sort
# Ascending sort means that wrong model should be listed first
self.assertEqual(header.sortIndicatorOrder(), Qt.AscendingOrder)
self.assertEqual(view.model().item(0, 0).text(), "VersicolorLearner")
self.assertEqual(view.model().index(0, 0).data(), "VersicolorLearner")

self.send_signal(self.widget.Inputs.test_data, versicolor, wait=5000)
self.assertEqual(view.model().item(0, 0).text(), "SetosaLearner")
self.assertEqual(view.model().index(0, 0).data(), "SetosaLearner")

self.widget.hide()

Expand Down Expand Up @@ -365,10 +365,11 @@ def test_scores_log_reg_advanced(self):
[1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yynnn"))
)

self.assertTupleEqual(self._test_scores(
table_train, table_test, LogisticRegressionLearner(),
OWTestLearners.TestOnTest, None),
(0.667, 0.8, 0.8, 0.867, 0.8))
np.testing.assert_almost_equal(
self._test_scores(table_train, table_test,
LogisticRegressionLearner(),
OWTestLearners.TestOnTest, None),
(2 / 3, 0.8, 0.8, 13 / 15, 0.8))

def test_scores_cross_validation(self):
"""
Expand Down
46 changes: 45 additions & 1 deletion Orange/widgets/evaluate/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import collections

from AnyQt.QtWidgets import QMenu
from AnyQt.QtCore import QPoint
from AnyQt.QtGui import QStandardItem
from AnyQt.QtCore import QPoint, Qt

from Orange.widgets.evaluate.utils import ScoreTable
from Orange.widgets.tests.base import GuiTest
Expand Down Expand Up @@ -70,5 +71,48 @@ def test_update_shown_columns(self):
not header.isSectionHidden(i),
msg="error in section {}({})".format(i, name))

def test_sorting(self):
def order(n=5):
return "".join(model.index(i, 0).data() for i in range(n))

score_table = ScoreTable(None)

data = [
["D", 11.0, 15.3],
["C", 5.0, -15.4],
["b", 20.0, None],
["A", None, None],
["E", "", 0.0]
]
for data_row in data:
row = []
for x in data_row:
item = QStandardItem()
if x is not None:
item.setData(x, Qt.DisplayRole)
row.append(item)
score_table.model.appendRow(row)

model = score_table.view.model()

model.sort(0, Qt.AscendingOrder)
self.assertEqual(order(), "AbCDE")

model.sort(0, Qt.DescendingOrder)
self.assertEqual(order(), "EDCbA")

model.sort(1, Qt.AscendingOrder)
self.assertEqual(order(3), "CDb")

model.sort(1, Qt.DescendingOrder)
self.assertEqual(order(3), "bDC")

model.sort(2, Qt.AscendingOrder)
self.assertEqual(order(3), "CED")

model.sort(2, Qt.DescendingOrder)
self.assertEqual(order(3), "DEC")


if __name__ == "__main__":
unittest.main()
26 changes: 24 additions & 2 deletions Orange/widgets/evaluate/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from AnyQt.QtWidgets import QHeaderView, QStyledItemDelegate, QMenu
from AnyQt.QtGui import QStandardItemModel, QStandardItem
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal, \
QSortFilterProxyModel
from sklearn.exceptions import UndefinedMetricWarning

from Orange.data import Variable, DiscreteVariable, ContinuousVariable
Expand Down Expand Up @@ -98,6 +99,19 @@ def thunked():
return thunked


class ScoreModel(QSortFilterProxyModel):
def lessThan(self, left, right):
left = left.data()
right = right.data()
if type(left) is not type(right) or left is None or right is None:
# put the one which is not a number (= an error) at the bottom
return isinstance(left, float) == (
self.sortOrder() == Qt.AscendingOrder)
if isinstance(left, str):
return left.upper() < right.upper()
return left < right


class ScoreTable(OWComponent, QObject):
shown_scores = \
Setting(set(chain(*BUILTIN_SCORERS_ORDER.values())))
Expand All @@ -109,6 +123,12 @@ def sizeHint(self, *args):
size = super().sizeHint(*args)
return QSize(size.width(), size.height() + 6)

def displayText(self, value, locale):
if isinstance(value, float):
return f"{value:.3f}"
else:
return super().displayText(value, locale)

def __init__(self, master):
QObject.__init__(self)
OWComponent.__init__(self, master)
Expand All @@ -125,7 +145,9 @@ def __init__(self, master):

self.model = QStandardItemModel(master)
self.model.setHorizontalHeaderLabels(["Method"])
self.view.setModel(self.model)
self.sorted_model = ScoreModel()
self.sorted_model.setSourceModel(self.model)
self.view.setModel(self.sorted_model)
self.view.setItemDelegate(self.ItemDelegate())

def _column_names(self):
Expand Down

0 comments on commit 2644ffc

Please sign in to comment.