Skip to content

Commit

Permalink
Merge pull request #4261 from janezd/comparison-of-models
Browse files Browse the repository at this point in the history
[ENH] Test & Score: Add comparison of models
  • Loading branch information
VesnaT authored Jan 24, 2020
2 parents d9edad9 + a27cce6 commit 64f0e48
Show file tree
Hide file tree
Showing 4 changed files with 439 additions and 8 deletions.
215 changes: 209 additions & 6 deletions Orange/widgets/evaluate/owtestlearners.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint doesn't understand the Settings magic
# pylint: disable=invalid-sequence-index
# pylint: disable=too-many-lines,too-many-instance-attributes
import abc
import enum
import logging
Expand All @@ -9,14 +10,17 @@

from concurrent.futures import Future
from collections import OrderedDict, namedtuple
from itertools import count
from typing import Any, Optional, List, Dict, Callable

import numpy as np
import baycomp

from AnyQt import QtGui
from AnyQt.QtGui import QStandardItem
from AnyQt.QtCore import Qt, QSize, QThread
from AnyQt.QtCore import pyqtSlot as Slot
from AnyQt.QtGui import QStandardItem, QDoubleValidator
from AnyQt.QtWidgets import QHeaderView, QTableWidget, QLabel

from Orange.base import Learner
import Orange.classification
Expand All @@ -35,7 +39,7 @@
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.utils.concurrent import ThreadExecutor, TaskState
from Orange.widgets.widget import OWWidget, Msg, Input, Output

from orangewidget.utils.itemmodels import PyListModel

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -175,6 +179,10 @@ class Outputs:
fold_feature = settings.ContextSetting(None)
fold_feature_selected = settings.ContextSetting(False)

use_rope = settings.Setting(False)
rope = settings.Setting(0.1)
comparison_criterion = settings.Setting(0, schema_only=True)

TARGET_AVERAGE = "(Average over classes)"
class_selection = settings.ContextSetting(TARGET_AVERAGE)

Expand Down Expand Up @@ -216,6 +224,7 @@ def __init__(self):
self.train_data_missing_vals = False
self.test_data_missing_vals = False
self.scorers = []
self.__pending_comparison_criterion = self.comparison_criterion

#: An Ordered dictionary with current inputs and their testing results.
self.learners = OrderedDict() # type: Dict[Any, Input]
Expand Down Expand Up @@ -275,13 +284,55 @@ def __init__(self):
callback=self._on_target_class_changed,
contentsLength=8)

self.modcompbox = box = gui.vBox(self.controlArea, "Model Comparison")
gui.comboBox(
box, self, "comparison_criterion", model=PyListModel(),
callback=self.update_comparison_table)

hbox = gui.hBox(box)
gui.checkBox(hbox, self, "use_rope",
"Negligible difference: ",
callback=self._on_use_rope_changed)
gui.lineEdit(hbox, self, "rope", validator=QDoubleValidator(),
controlWidth=70, callback=self.update_comparison_table,
alignment=Qt.AlignRight)
self.controls.rope.setEnabled(self.use_rope)

gui.rubber(self.controlArea)
self.score_table = ScoreTable(self)
self.score_table.shownScoresChanged.connect(self.update_stats_model)
view = self.score_table.view
view.setSizeAdjustPolicy(view.AdjustToContents)

box = gui.vBox(self.mainArea, "Evaluation Results")
box.layout().addWidget(self.score_table.view)

self.compbox = box = gui.vBox(self.mainArea, box="Model comparison")
table = self.comparison_table = QTableWidget(
wordWrap=False, editTriggers=QTableWidget.NoEditTriggers,
selectionMode=QTableWidget.NoSelection)
table.setSizeAdjustPolicy(table.AdjustToContents)
header = table.verticalHeader()
header.setSectionResizeMode(QHeaderView.Fixed)
header.setSectionsClickable(False)

header = table.horizontalHeader()
header.setTextElideMode(Qt.ElideRight)
header.setDefaultAlignment(Qt.AlignCenter)
header.setSectionsClickable(False)
header.setStretchLastSection(False)
header.setSectionResizeMode(QHeaderView.ResizeToContents)
avg_width = self.fontMetrics().averageCharWidth()
header.setMinimumSectionSize(8 * avg_width)
header.setMaximumSectionSize(15 * avg_width)
header.setDefaultSectionSize(15 * avg_width)
box.layout().addWidget(table)
box.layout().addWidget(QLabel(
"<small>Table shows probabilities that the score for the model in "
"the row is higher than that of the model in the column. "
"Small numbers show the probability that the difference is "
"negligible.</small>", wordWrap=True))

@staticmethod
def sizeHint():
    """Return the preferred size of the widget: 780 wide, nominal height 1."""
    preferred_width, preferred_height = 780, 1
    return QSize(preferred_width, preferred_height)
Expand Down Expand Up @@ -436,10 +487,32 @@ def _which_missing_data(self):
# - we don't gain much with it
# - it complicates the unit tests
def _update_scorers(self):
    """Recompute the usable scorers and sync the comparison criterion.

    The scorers are taken from ``usable_scorers`` for the class variable of
    the current data, or set to an empty list when there is no data or no
    class variable. The criterion combo model is refilled only when the
    list of scorers actually changes. A criterion index stored in the
    pending attribute (initialised from the saved setting) is applied
    once, provided it is still a valid index, and then discarded. Finally
    the comparison box title is refreshed.
    """
    if self.data and self.data.domain.class_var:
        new_scorers = usable_scorers(self.data.domain.class_var)
    else:
        new_scorers = []
    # Don't unnecessarily reset the model because this would always reset
    # comparison_criterion; we also set it explicitly, though, for clarity
    if new_scorers != self.scorers:
        self.scorers = new_scorers
        labels = [scorer.long_name or scorer.name for scorer in self.scorers]
        self.controls.comparison_criterion.model()[:] = labels
        self.comparison_criterion = 0
    if self.__pending_comparison_criterion is not None:
        # Check for the unlikely case that some scorers have been removed
        # from modules
        if self.__pending_comparison_criterion < len(self.scorers):
            self.comparison_criterion = self.__pending_comparison_criterion
        # The pending value is consumed on the first call, valid or not
        self.__pending_comparison_criterion = None
    self._update_compbox_title()

def _update_compbox_title(self):
criterion = self.comparison_criterion
if criterion < len(self.scorers):
scorer = self.scorers[criterion]()
self.compbox.setTitle(f"Model Comparison by {scorer.name}")
else:
self.compbox.setTitle(f"Model Comparison")

@Inputs.preprocessor
def set_preprocessor(self, preproc):
Expand All @@ -453,6 +526,7 @@ def handleNewSignals(self):
"""Reimplemented from OWWidget.handleNewSignals."""
self._update_class_selection()
self.score_table.update_header(self.scorers)
self._update_view_enabled()
self.update_stats_model()
if self.__needupdate:
self.__update()
Expand All @@ -470,9 +544,19 @@ def shuffle_split_changed(self):
self._param_changed()

def _param_changed(self):
    """React to a change of the resampling parameters.

    Enables the model-comparison controls only for k-fold resampling,
    refreshes the enabled state of the views, invalidates the current
    results and triggers an update.
    """
    uses_kfold = self.resampling == OWTestLearners.KFold
    self.modcompbox.setEnabled(uses_kfold)
    self._update_view_enabled()
    self._invalidate()
    self.__update()

def _update_view_enabled(self):
    """Enable or disable the score view and the comparison table.

    The score view needs data; the comparison table additionally needs
    k-fold resampling and more than one learner.
    """
    has_data = self.data is not None
    can_compare = (
        self.resampling == OWTestLearners.KFold
        and len(self.learners) > 1
        and has_data
    )
    self.comparison_table.setEnabled(can_compare)
    self.score_table.view.setEnabled(has_data)

def update_stats_model(self):
# Update the results_model with up to date scores.
# Note: The target class specific scores (if requested) are
Expand All @@ -494,8 +578,10 @@ def update_stats_model(self):
errors = []
has_missing_scores = False

names = []
for key, slot in self.learners.items():
name = learner_name(slot.learner)
names.append(name)
head = QStandardItem(name)
head.setData(key, Qt.UserRole)
results = slot.results
Expand Down Expand Up @@ -558,10 +644,123 @@ def update_stats_model(self):
header.sortIndicatorSection(),
header.sortIndicatorOrder()
)
self._set_comparison_headers(names)

self.error("\n".join(errors), shown=bool(errors))
self.Warning.scores_not_computed(shown=has_missing_scores)

def _on_use_rope_changed(self):
    """Callback of the "use rope" check box.

    The rope line edit is enabled only while the check box is ticked;
    afterwards the comparison table is recomputed.
    """
    rope_edit = self.controls.rope
    rope_edit.setEnabled(self.use_rope)
    self.update_comparison_table()

def update_comparison_table(self):
    """Rebuild the pairwise model comparison table.

    The table contents are always cleared first. Headers are set whenever
    there is at least one successful slot and at least one scorer; the
    cells are filled only for k-fold resampling.
    """
    self.comparison_table.clearContents()
    slots = self._successful_slots()
    if not slots or not self.scorers:
        return

    names = [learner_name(slot.learner) for slot in slots]
    self._set_comparison_headers(names)
    if self.resampling != OWTestLearners.KFold:
        return

    fold_scores = self._scores_by_folds(slots)
    self._fill_table(names, fold_scores)

def _successful_slots(self):
    """Return the slots with successful results, in displayed row order.

    Rows of the sorted proxy model are mapped back to the source model,
    whose first-column ``Qt.UserRole`` data is the key into
    ``self.learners``.
    """
    source = self.score_table.model
    proxy = self.score_table.sorted_model

    successful = []
    for row in range(proxy.rowCount()):
        source_index = proxy.mapToSource(proxy.index(row, 0))
        key = source.data(source_index, Qt.UserRole)
        slot = self.learners[key]
        if slot.results is not None and slot.results.success:
            successful.append(slot)
    return successful

def _set_comparison_headers(self, names):
    """Resize the comparison table to ``len(names)`` squared and label it.

    Columns are stretched when there are more than two names and fixed
    otherwise. Repainting is suspended while the table is modified and is
    re-enabled even if one of the calls fails.
    """
    table = self.comparison_table
    try:
        # Prevent glitching during update
        table.setUpdatesEnabled(False)
        resize_mode = (
            QHeaderView.Stretch if len(names) > 2 else QHeaderView.Fixed)
        table.horizontalHeader().setSectionResizeMode(resize_mode)
        size = len(names)
        table.setRowCount(size)
        table.setColumnCount(size)
        for set_labels in (table.setVerticalHeaderLabels,
                           table.setHorizontalHeaderLabels):
            set_labels(names)
    finally:
        table.setUpdatesEnabled(True)

def _scores_by_folds(self, slots):
    """Compute per-fold scores of the selected criterion for each slot.

    For binary scorers the score is computed either for the selected
    target class or as a weighted average over classes. Returns a list
    with one flattened array per slot, or ``None`` where scoring failed;
    in the latter case the "scores not computed" warning is shown.
    """
    scorer = self.scorers[self.comparison_criterion]()
    self._update_compbox_title()

    if not scorer.is_binary:
        kw = {}
    elif self.class_selection == self.TARGET_AVERAGE:
        kw = {"average": "weighted"}
    else:
        class_values = self.data.domain.class_var.values
        kw = {"target": class_values.index(self.class_selection)}

    scores = []
    for slot in slots:
        # Bind the slot's results as a default argument so that each thunk
        # scores its own results
        attempt = Try(
            lambda results=slot.results:
            scorer.scores_by_folds(results.value, **kw).flatten())
        scores.append(attempt.value if attempt.success else None)

    # A plain `None in scores` doesn't work because the entries are np.arrays
    if any(score is None for score in scores):
        self.Warning.scores_not_computed()
    return scores

def _fill_table(self, names, scores):
    """Fill the comparison table with pairwise probabilities.

    Each pair of models is compared once with ``baycomp.two_on_single``
    and both mirrored cells are written. With a non-zero rope enabled,
    the cells also show the probability that the difference is
    negligible. Pairs with missing scores are left empty and pairs whose
    result contains NaN are marked "NA".
    """
    table = self.comparison_table
    for row, (row_name, row_scores) in enumerate(zip(names, scores)):
        for col, col_name, col_scores in zip(range(row), names, scores):
            if row_scores is None or col_scores is None:
                continue

            if self.use_rope and self.rope:
                p0, rope, p1 = baycomp.two_on_single(
                    row_scores, col_scores, self.rope)
                if any(np.isnan(prob) for prob in (p0, rope, p1)):
                    self._set_cells_na(table, row, col)
                    continue
                # (cell row, cell column, first name, second name, p)
                directions = (
                    (row, col, row_name, col_name, p0),
                    (col, row, col_name, row_name, p1),
                )
                for cell_row, cell_col, first, second, prob in directions:
                    self._set_cell(
                        table, cell_row, cell_col,
                        f"{prob:.3f}<br/><small>{rope:.3f}</small>",
                        f"p({first} > {second}) = {prob:.3f}\n"
                        f"p({first} = {second}) = {rope:.3f}")
            else:
                p0, p1 = baycomp.two_on_single(row_scores, col_scores)
                if any(np.isnan(prob) for prob in (p0, p1)):
                    self._set_cells_na(table, row, col)
                    continue
                directions = (
                    (row, col, row_name, col_name, p0),
                    (col, row, col_name, row_name, p1),
                )
                for cell_row, cell_col, first, second, prob in directions:
                    self._set_cell(
                        table, cell_row, cell_col,
                        f"{prob:.3f}",
                        f"p({first} > {second}) = {prob:.3f}")

@classmethod
def _set_cells_na(cls, table, row, col):
    """Mark the cell at (row, col) and its mirror (col, row) as "NA"."""
    for cell_row, cell_col in ((row, col), (col, row)):
        cls._set_cell(
            table, cell_row, cell_col,
            "NA", "comparison cannot be computed")

@staticmethod
def _set_cell(table, row, col, label, tooltip):
    """Put a centered label with the given text and tooltip into a cell."""
    cell = QLabel(label)
    cell.setAlignment(Qt.AlignCenter)
    cell.setToolTip(tooltip)
    table.setCellWidget(row, col, cell)

def _update_class_selection(self):
self.class_selection_combo.setCurrentIndex(-1)
self.class_selection_combo.clear()
Expand All @@ -585,6 +784,7 @@ def _update_class_selection(self):

def _on_target_class_changed(self):
    """Refresh the score table and then the comparison table."""
    for refresh in (self.update_stats_model, self.update_comparison_table):
        refresh()

def _invalidate(self, which=None):
self.cancel()
Expand All @@ -611,6 +811,8 @@ def _invalidate(self, which=None):
item.setData(None, Qt.DisplayRole)
item.setData(None, Qt.ToolTipRole)

self.comparison_table.clearContents()

self.__needupdate = True

def commit(self):
Expand Down Expand Up @@ -866,6 +1068,7 @@ def __task_complete(self, f: 'Future[Results]'):

self.score_table.update_header(self.scorers)
self.update_stats_model()
self.update_comparison_table()

self.commit()

Expand Down
Loading

0 comments on commit 64f0e48

Please sign in to comment.