Skip to content

Commit

Permalink
Merge pull request #5138 from janezd/roc-point-fixes
Browse files Browse the repository at this point in the history
[FIX] ROC shows all points, including the last
  • Loading branch information
BlazZupan authored Jan 21, 2021
2 parents 0a83c6c + 1eca4d6 commit 9c593cc
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 13 deletions.
64 changes: 52 additions & 12 deletions Orange/widgets/evaluate/owrocanalysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
"""
ROC Analysis Widget
-------------------
"""
import operator
from functools import reduce, wraps
from collections import namedtuple, deque, OrderedDict
Expand All @@ -11,7 +6,7 @@
import sklearn.metrics as skl_metrics

from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction, \
QToolTip, QSizePolicy
QToolTip
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont, \
QCursor, QFontMetrics
from AnyQt.QtCore import Qt, QSize
Expand All @@ -21,13 +16,15 @@
from Orange.widgets import widget, gui, settings
from Orange.widgets.evaluate.contexthandlers import \
EvaluationResultsContextHandler
from Orange.widgets.evaluate.utils import \
check_results_adequacy, results_for_preview
from Orange.widgets.evaluate.utils import check_results_adequacy
from Orange.widgets.utils import colorpalettes
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import Input
from Orange.widgets import report

from Orange.widgets.evaluate.utils import results_for_preview
from Orange.evaluation.testing import Results


#: Points on a ROC curve
ROCPoints = namedtuple(
Expand Down Expand Up @@ -93,11 +90,11 @@ def roc_data_from_results(results, clf_index, target):
:rval ROCData:
An instance holding the computed curves.
"""
merged = roc_curve_for_fold(results, slice(0, -1), clf_index, target)
merged = roc_curve_for_fold(results, ..., clf_index, target)
merged_curve = ROCCurve(ROCPoints(*merged),
ROCPoints(*roc_curve_convex_hull(merged)))

folds = results.folds if results.folds is not None else [slice(0, -1)]
folds = results.folds if results.folds is not None else [...]
fold_curves = []
for fold in folds:
points = roc_curve_for_fold(results, fold, clf_index, target)
Expand Down Expand Up @@ -413,11 +410,13 @@ def __init__(self):
axis.setTickFont(tickfont)
axis.setPen(pen)
axis.setLabel("FP Rate (1-Specificity)")
axis.setGrid(16)

axis = self.plot.getAxis("left")
axis.setTickFont(tickfont)
axis.setPen(pen)
axis.setLabel("TP Rate (Sensitivity)")
axis.setGrid(16)

self.plot.showGrid(True, True, alpha=0.1)
self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0), padding=0.05)
Expand Down Expand Up @@ -621,6 +620,8 @@ def no_averaging():
if self.roc_averaging == OWROCAnalysis.Merge:
self._update_perf_line()

self._update_axes_ticks()

warning = ""
if not all(c.is_valid for c in hull_curves):
if any(c.is_valid for c in hull_curves):
Expand All @@ -629,6 +630,22 @@ def no_averaging():
warning = "All ROC curves are undefined"
self.warning(warning)

def _update_axes_ticks(self):
    """Put axis ticks at the exact rates found on the merged ROC curve.

    The merged curve of the first selected classifier (for the current
    target) supplies the values: every distinct FP rate becomes a tick on
    the bottom axis and every distinct TP rate a tick on the left axis,
    each labelled with two decimals.

    An axis with more than 15 distinct values receives ``None`` instead of
    explicit ticks. When no classifier is selected there is no curve to
    read, so both axes receive ``None`` as well (pyqtgraph documents
    ``setTicks(None)`` as falling back to its default tick system).
    """
    def enumticks(a):
        # Collapse duplicates; too many explicit labels would be
        # unreadable, so give up on explicit ticks above 15 values.
        a = np.unique(a)
        if len(a) > 15:
            return None
        # One tick level: (value, label) pairs, listed in descending order.
        return [[(x, f"{x:.2f}") for x in a[::-1]]]

    bottom_axis = self.plot.getAxis("bottom")
    left_axis = self.plot.getAxis("left")

    # Guard: indexing selected_classifiers[0] on an empty selection
    # would raise IndexError.
    if not self.selected_classifiers:
        bottom_axis.setTicks(None)
        left_axis.setTicks(None)
        return

    data = self.curve_data(self.target_index, self.selected_classifiers[0])
    points = data.merged.points

    bottom_axis.setTicks(enumticks(points.fpr))
    left_axis.setTicks(enumticks(points.tpr))

def _on_mouse_moved(self, pos):
target = self.target_index
selected = self.selected_classifiers
Expand Down Expand Up @@ -802,10 +819,19 @@ def roc_curve_for_fold(res, fold, clf_idx, target):
return np.array([]), np.array([]), np.array([])

fold_probs = res.probabilities[clf_idx][fold][:, target]
return skl_metrics.roc_curve(
fold_actual, fold_probs, pos_label=target
drop_intermediate = len(fold_actual) > 20
fpr, tpr, thresholds = skl_metrics.roc_curve(
fold_actual, fold_probs, pos_label=target,
drop_intermediate=drop_intermediate
)

# skl sets the first threshold to the highest threshold in the data + 1
# since we deal with probabilities, we (carefully) set it to 1
# Unrelated comparisons, thus pylint: disable=chained-comparison
if len(thresholds) > 1 and thresholds[1] <= 1:
thresholds[0] = 1
return fpr, tpr, thresholds


def roc_curve_vertical_average(curves, samples=10):
if not curves:
Expand Down Expand Up @@ -969,5 +995,19 @@ def roc_iso_performance_slope(fp_cost, fn_cost, p):
return (fp_cost * (1. - p)) / (fn_cost * p)


def _create_results():  # pragma: no cover
    """Build a small hand-made ``Results`` object for previewing the widget.

    The object describes a single classifier evaluated on the first 16
    rows of the ``heart_disease`` table, with fixed positive-class
    probabilities and fixed actual classes, so the resulting ROC curve has
    few, easily checked points.

    :rval Results:
        Results with ``probabilities`` of shape (1, 16, 2) and
        ``predicted`` of shape (1, 16).
    """
    # Probability of class 1 for each of the 16 rows, in descending order.
    probs1 = [0.984, 0.907, 0.881, 0.865, 0.815, 0.741, 0.735, 0.635,
              0.582, 0.554, 0.413, 0.317, 0.287, 0.225, 0.216, 0.183]
    # Shape (1, 16, 2): one classifier, 16 rows, [P(class 0), P(class 1)].
    probs = np.array([[[1 - x, x] for x in probs1]])
    # One predicted class per row, shape (1, 16): class 1 when its
    # probability exceeds 0.5. Thresholding the whole 3-D array instead
    # would give a (1, 16, 2) array rather than per-row predictions.
    preds = (probs[:, :, 1] > 0.5).astype(float)
    return Results(
        data=Orange.data.Table("heart_disease")[:16],
        row_indices=np.arange(16),
        actual=np.array(list(map(int, "1100111001001000"))),
        probabilities=probs, predicted=preds
    )


if __name__ == "__main__":  # pragma: no cover
    # To preview with the small hand-made example instead, pass
    # _create_results() to run() in place of results_for_preview().
    preview = WidgetPreview(OWROCAnalysis)
    preview.run(results_for_preview())
2 changes: 1 addition & 1 deletion Orange/widgets/evaluate/tests/test_owrocanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def test_tooltips(self):
pos = view.mapFromScene(pos)
mouseMove(view.viewport(), pos)
(_, text), _ = show_text.call_args
self.assertIn("(#1) 1.800\n(#2) 1.893", text)
self.assertIn("(#1) 1.000\n(#2) 1.000", text)

# test that cache is invalidated when changing averaging mode
self.widget.roc_averaging = OWROCAnalysis.Threshold
Expand Down

0 comments on commit 9c593cc

Please sign in to comment.