diff --git a/Orange/widgets/evaluate/owtestlearners.py b/Orange/widgets/evaluate/owtestlearners.py index 139f2cb9c0e..9a2d77de92c 100644 --- a/Orange/widgets/evaluate/owtestlearners.py +++ b/Orange/widgets/evaluate/owtestlearners.py @@ -533,7 +533,7 @@ def update_stats_model(self): for stat, scorer in zip(stats, self.scorers): item = QStandardItem() if stat.success: - item.setText("{:.3f}".format(stat.value[0])) + item.setData(float(stat.value[0]), Qt.DisplayRole) else: item.setToolTip(str(stat.exception)) if scorer.name in self.score_table.shown_scores: diff --git a/Orange/widgets/evaluate/tests/test_owtestlearners.py b/Orange/widgets/evaluate/tests/test_owtestlearners.py index 213a2e7e062..cdb9934c0d7 100644 --- a/Orange/widgets/evaluate/tests/test_owtestlearners.py +++ b/Orange/widgets/evaluate/tests/test_owtestlearners.py @@ -277,10 +277,10 @@ def __call__(self, data): # Ensure that the click on header caused an ascending sort # Ascending sort means that wrong model should be listed first self.assertEqual(header.sortIndicatorOrder(), Qt.AscendingOrder) - self.assertEqual(view.model().item(0, 0).text(), "VersicolorLearner") + self.assertEqual(view.model().index(0, 0).data(), "VersicolorLearner") self.send_signal(self.widget.Inputs.test_data, versicolor, wait=5000) - self.assertEqual(view.model().item(0, 0).text(), "SetosaLearner") + self.assertEqual(view.model().index(0, 0).data(), "SetosaLearner") self.widget.hide() @@ -365,10 +365,11 @@ def test_scores_log_reg_advanced(self): [1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yynnn")) ) - self.assertTupleEqual(self._test_scores( - table_train, table_test, LogisticRegressionLearner(), - OWTestLearners.TestOnTest, None), - (0.667, 0.8, 0.8, 0.867, 0.8)) + np.testing.assert_almost_equal( + self._test_scores(table_train, table_test, + LogisticRegressionLearner(), + OWTestLearners.TestOnTest, None), + (2 / 3, 0.8, 0.8, 13 / 15, 0.8)) def test_scores_cross_validation(self): """ diff --git a/Orange/widgets/evaluate/tests/test_utils.py b/Orange/widgets/evaluate/tests/test_utils.py index 4e3db7b3149..a5998f56cfb 100644 --- a/Orange/widgets/evaluate/tests/test_utils.py +++ b/Orange/widgets/evaluate/tests/test_utils.py @@ -3,8 +3,11 @@ import unittest import collections +import numpy as np + from AnyQt.QtWidgets import QMenu -from AnyQt.QtCore import QPoint +from AnyQt.QtGui import QStandardItem +from AnyQt.QtCore import QPoint, Qt from Orange.widgets.evaluate.utils import ScoreTable from Orange.widgets.tests.base import GuiTest @@ -70,5 +73,48 @@ def test_update_shown_columns(self): not header.isSectionHidden(i), msg="error in section {}({})".format(i, name)) + def test_sorting(self): + def order(n=5): + return "".join(model.index(i, 0).data() for i in range(n)) + + score_table = ScoreTable(None) + + data = [ + ["D", 11.0, 15.3], + ["C", 5.0, -15.4], + ["b", 20.0, np.nan], + ["A", None, None], + ["E", "", 0.0] + ] + for data_row in data: + row = [] + for x in data_row: + item = QStandardItem() + if x is not None: + item.setData(x, Qt.DisplayRole) + row.append(item) + score_table.model.appendRow(row) + + model = score_table.view.model() + + model.sort(0, Qt.AscendingOrder) + self.assertEqual(order(), "AbCDE") + + model.sort(0, Qt.DescendingOrder) + self.assertEqual(order(), "EDCbA") + + model.sort(1, Qt.AscendingOrder) + self.assertEqual(order(3), "CDb") + + model.sort(1, Qt.DescendingOrder) + self.assertEqual(order(3), "bDC") + + model.sort(2, Qt.AscendingOrder) + self.assertEqual(order(3), "CED") + + model.sort(2, Qt.DescendingOrder) + self.assertEqual(order(3), "DEC") + + if __name__ == "__main__": unittest.main() diff --git a/Orange/widgets/evaluate/utils.py b/Orange/widgets/evaluate/utils.py index ebe06032777..72c87b328be 100644 --- a/Orange/widgets/evaluate/utils.py +++ b/Orange/widgets/evaluate/utils.py @@ -6,7 +6,8 @@ from AnyQt.QtWidgets import QHeaderView, QStyledItemDelegate, QMenu from AnyQt.QtGui import QStandardItemModel, QStandardItem -from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal +from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal, \ + QSortFilterProxyModel from sklearn.exceptions import UndefinedMetricWarning from Orange.data import Variable, DiscreteVariable, ContinuousVariable @@ -98,6 +99,32 @@ def thunked(): return thunked +class ScoreModel(QSortFilterProxyModel): + def lessThan(self, left, right): + def is_bad(x): + return not isinstance(x, (int, float, str)) \ + or isinstance(x, float) and np.isnan(x) + + left = left.data() + right = right.data() + is_ascending = self.sortOrder() == Qt.AscendingOrder + + # bad entries go below; if both are bad, left remains above + if is_bad(left) or is_bad(right): + return is_bad(right) == is_ascending + + # for data of different types, numbers are at the top + if type(left) is not type(right): + return isinstance(left, float) == is_ascending + + # case insensitive comparison for strings + if isinstance(left, str): + return left.upper() < right.upper() + + # otherwise, compare numbers + return left < right + + class ScoreTable(OWComponent, QObject): shown_scores = \ Setting(set(chain(*BUILTIN_SCORERS_ORDER.values()))) @@ -109,6 +136,12 @@ def sizeHint(self, *args): size = super().sizeHint(*args) return QSize(size.width(), size.height() + 6) + def displayText(self, value, locale): + if isinstance(value, float): + return f"{value:.3f}" + else: + return super().displayText(value, locale) + def __init__(self, master): QObject.__init__(self) OWComponent.__init__(self, master) @@ -125,7 +158,9 @@ def __init__(self, master): self.model = QStandardItemModel(master) self.model.setHorizontalHeaderLabels(["Method"]) - self.view.setModel(self.model) + self.sorted_model = ScoreModel() + self.sorted_model.setSourceModel(self.model) + self.view.setModel(self.sorted_model) self.view.setItemDelegate(self.ItemDelegate()) def _column_names(self):