[ENH] Calibration plot (add performance curves) and a new Calibrated Learner widget #3881

Merged 21 commits, Jul 12, 2019

Commits
35a6f4b  Calibration plot: Add plots of ca, sens/spec, prec/recall, ppv/npv (janezd, Jun 13, 2019)
2fa1750  Calibration plot: Add threshold line (janezd, Jun 13, 2019)
d47b68b  Calibration plot: Refactor computation of metrics (janezd, Jun 13, 2019)
585feb2  Testing: Keep 2d array of models when splitting Results by models (janezd, Jun 13, 2019)
7b876e6  Test Learners: Store models when there is just one; properly stack them (janezd, Jun 13, 2019)
93b7a72  classification: Add ModelWithThreshold (janezd, Jun 13, 2019)
ff67b49  Calibration plot: Output selected model (janezd, Jun 13, 2019)
a4424fb  Orange.evaluation.performance_curves: Add module for computation of p… (janezd, Jun 16, 2019)
6024897  Calibration plot: Use Orange.evaluation.testing.performance_curves to… (janezd, Jun 16, 2019)
1cfbeec  Calibration plot: Fix selected model output (janezd, Jun 17, 2019)
f742ff9  OWLearnerWidget: Let default name appear as placeholder. This allows … (janezd, Jun 17, 2019)
c5d070d  evaluations.testing: Minor fixes in unit tests (janezd, Jun 17, 2019)
557fa2e  OWTestLearners: Skip inactive signals (e.g. learner widget outputs None) (janezd, Jun 17, 2019)
1a8b013  Calibrated Learner: Add widget (janezd, Jun 17, 2019)
6ac1db1  Calibration plot: Add context settings (janezd, Jun 17, 2019)
2edcb39  OWCalibrationPlot: Unit tests and some fixes (janezd, Jun 18, 2019)
2049afa  Calibration plot: Test missing probabilities and single classes (janezd, Jun 19, 2019)
04d05f4  Calibration plot: Minor fixes (janezd, Jun 24, 2019)
6695ee9  Calibrated Learner: Fix report (janezd, Jun 28, 2019)
65c69e2  Calibrated Learner: Add icon (janezd, Jun 28, 2019)
864d7b5  Calibration plot: Nicer report (janezd, Jun 28, 2019)
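Two of these commits carry the core idea the rest of the PR builds on: 93b7a72 adds a ModelWithThreshold, a classifier wrapped so that its decision threshold can differ from 0.5, and ff67b49 lets the plot output such a model for the threshold the user selects. A minimal sketch of the concept, assuming a scikit-learn-style predict_proba; the class name and interface here are illustrative, not Orange's actual ModelWithThreshold API:

```python
import numpy as np

class ThresholdedModel:
    """Wrap a fitted probabilistic classifier and classify an instance as
    the target class when its predicted probability reaches the threshold."""

    def __init__(self, model, threshold=0.5, target=1):
        self.model = model          # any object with predict_proba(X)
        self.threshold = threshold  # decision threshold in [0, 1]
        self.target = target        # column index of the target class

    def predict(self, X):
        probs = self.model.predict_proba(X)[:, self.target]
        return (probs >= self.threshold).astype(int)
```

Dragging the widget's threshold line then amounts to constructing such a wrapper with a new threshold and re-emitting it on the output.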
141 changes: 106 additions & 35 deletions Orange/widgets/evaluate/owcalibrationplot.py
@@ -1,9 +1,10 @@
from collections import namedtuple
from functools import partial

import numpy as np

from AnyQt.QtCore import Qt
from AnyQt.QtWidgets import QListWidget
from AnyQt.QtCore import Qt, QSize
from AnyQt.QtWidgets import QListWidget, QSizePolicy

import pyqtgraph as pg

@@ -16,6 +17,10 @@
from Orange.widgets.widget import Input
from Orange.widgets import report

metric_definition = namedtuple(
"metric_definition",
("name", "function", "short_names", "explanation"))


class OWCalibrationPlot(widget.OWWidget):
name = "Calibration Plot"
@@ -36,29 +41,40 @@ class Warning(widget.OWWidget.Warning):
score = settings.Setting(0)
fold_curves = settings.Setting(False)
display_rug = settings.Setting(True)
threshold = settings.Setting(0.5)

graph_name = "plot"

def __init__(self):
super().__init__()

self.results = None
self.scores = None
self.classifier_names = []
self.colors = []

box = gui.vBox(self.controlArea, "Target Class")
box = gui.vBox(self.controlArea, box="Settings")
self.target_cb = gui.comboBox(
box, self, "target_index", callback=self._replot, contentsLength=8)
gui.checkBox(box, self, "display_rug", "Show rug",
callback=self._on_display_rug_changed)
box, self, "target_index", label="Target:",
orientation=Qt.Horizontal, callback=self._replot, contentsLength=8)
gui.checkBox(
box, self, "display_rug", "Show rug",
callback=self._on_display_rug_changed)
gui.checkBox(
box, self, "fold_curves", "Curves for individual folds",
callback=self._replot)

self.classifiers_list_box = gui.listBox(
self.controlArea, self, "selected_classifiers", "classifier_names",
box="Classifier", selectionMode=QListWidget.ExtendedSelection,
sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred),
sizeHint=QSize(150, 40),
callback=self._replot)

box = gui.vBox(self.controlArea, "Metrics")
combo = gui.comboBox(
box, self, "score", items=(x[0] for x in self.Metrics),
box, self, "score", items=(metric.name for metric in self.Metrics),
callback=self.score_changed)
gui.checkBox(
box, self, "fold_curves", "Curves for individual folds",
callback=self._replot)

self.explanation = gui.widgetLabel(
box, wordWrap=True, fixedWidth=combo.sizeHint().width())
@@ -67,16 +83,18 @@ def __init__(self):
font.setPointSizeF(0.85 * font.pointSizeF())
self.explanation.setFont(font)

self.classifiers_list_box = gui.listBox(
self.controlArea, self, "selected_classifiers", "classifier_names",
box="Classifier", selectionMode=QListWidget.ExtendedSelection,
callback=self._replot)
box = gui.widgetBox(self.controlArea, "Info")
self.info_label = gui.widgetLabel(box)

self.plotview = pg.GraphicsView(background="w")
self.plot = pg.PlotItem(enableMenu=False)
self.plot.setMouseEnabled(False, False)
self.plot.hideButtons()

for axis_name in ("bottom", "left"):
axis = self.plot.getAxis(axis_name)
axis.setPen(pg.mkPen(color=0.0))

self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0), padding=0.05)
self.plotview.setCentralItem(self.plot)

@@ -110,7 +128,7 @@ def score_changed(self):
self._replot()

def _set_explanation(self):
explanation = self.Metrics[self.score][2]
explanation = self.Metrics[self.score].explanation
if explanation:
self.explanation.setText(explanation)
self.explanation.show()
@@ -122,7 +140,7 @@ def _set_explanation(self):
else "Threshold probability to classify as positive")

axis = self.plot.getAxis("left")
axis.setLabel(self.Metrics[self.score][0])
axis.setLabel(self.Metrics[self.score].name)

def _initialize(self, results):
N = len(results.predicted)
Expand All @@ -149,7 +167,7 @@ def plot_metrics(ytrue, probs, metrics, pen_args):
probs = probs[sortind]
ytrue = ytrue[sortind]
fn = np.cumsum(ytrue)
metrics(ytrue, probs, fn, pen_args)
return probs, metrics(ytrue, probs, fn, pen_args)

def _rug(self, ytrue, probs, _fn, pen_args):
color = pen_args["pen"].color()
@@ -208,6 +226,7 @@ def _sens_spec_curve(self, ytrue, probs, fn, pen_args):
spec = (np.arange(1, n + 1) - fn) / real_neg
self.plot.plot(probs, sens, **pen_args)
self.plot.plot(probs, spec, **pen_args)
return sens, spec

def _pr_curve(self, ytrue, probs, fn, pen_args):
# precision = tp / pred_pos = (real_pos - fn[i]) / (n - i)
@@ -219,6 +238,7 @@ def _pr_curve(self, ytrue, probs, fn, pen_args):
recall = 1 - fn / real_pos
self.plot.plot(probs[:-1], prec, **pen_args)
self.plot.plot(probs[:-1], recall, **pen_args)
return prec, recall

def _ppv_npv_curve(self, ytrue, probs, fn, pen_args):
# ppv = tp / pred_pos = (real_pos - fn[i]) / (n - i)
@@ -230,12 +250,14 @@ def _ppv_npv_curve(self, ytrue, probs, fn, pen_args):
npv = 1 - fn / np.arange(1, n)
self.plot.plot(probs[:-1], ppv, **pen_args)
self.plot.plot(probs[:-1], npv, **pen_args)
return ppv, npv

def _setup_plot(self):
target = self.target_index
results = self.results
metrics = partial(self.Metrics[self.score][1], self)
metrics = partial(self.Metrics[self.score].function, self)
plot_folds = self.fold_curves and results.folds is not None
self.scores = []

ytrue = results.actual == target
for clsf in self.selected_classifiers:
@@ -246,7 +268,9 @@ def _setup_plot(self):
shadowPen=pg.mkPen(color.lighter(160),
width=3 + 5 * plot_folds),
antiAlias=True)
self.plot_metrics(ytrue, probs, metrics, pen_args)
self.scores.append(
(self.classifier_names[clsf],
self.plot_metrics(ytrue, probs, metrics, pen_args)))

if self.display_rug:
self.plot_metrics(ytrue, probs, self._rug, pen_args)
@@ -265,10 +289,54 @@ def _replot(self):
self.plot.clear()
if self.results is not None:
self._setup_plot()
self.line = pg.InfiniteLine(
pos=self.threshold, movable=True,
pen=pg.mkPen(color="k", style=Qt.DashLine, width=2),
hoverPen=pg.mkPen(color="k", style=Qt.DashLine, width=3),
bounds=(0, 1),
)
self.line.sigPositionChanged.connect(self.threshold_change)
self.line.sigPositionChangeFinished.connect(self.threshold_change_done)
self.plot.addItem(self.line)
self._update_info()


def _on_display_rug_changed(self):
self._replot()

def threshold_change(self):
self.threshold = round(self.line.pos().x(), 2)
self.line.setPos(self.threshold)
self._update_info()

def _update_info(self):

text = f"""<table>
<tr>
<th align='right'>Threshold: p=</th>
<td colspan='4'>{self.threshold:.2f}<br/></td>
</tr>"""
if self.scores is not None:
short_names = self.Metrics[self.score].short_names
if short_names:
text += f"""<tr>
<th></th>
{"<td></td>".join(f"<td align='right'>{n}</td>"
for n in short_names)}
</tr>"""
for name, (probs, curves) in self.scores:
ind = min(np.searchsorted(probs, self.threshold),
len(probs) - 1)
text += f"<tr><th align='right'>{name}:</th>"
text += "<td>/</td>".join(f'<td>{curve[ind]:.3f}</td>'
for curve in curves)
text += "</tr>"
text += "<table>"
self.info_label.setText(text)

def threshold_change_done(self):
...

def send_report(self):
if self.results is None:
return
@@ -278,22 +346,25 @@ def send_report(self):
self.report_plot()
self.report_caption(caption)

Metrics = [
("Actual probability", _prob_curve, ""),
("Classification accuracy", _ca_curve, ""),
("Sensitivity & Specificity", _sens_spec_curve,
"<b>Sensitivity</b> (falling) is the proportion of correctly detected "
"positive instances (TP / P), and <b>specificity</b> (rising) is the "
"proportion of detected negative instances (TP / N)."),
("Precision & Recall", _pr_curve,
"<b>Precision</b> (rising) is the fraction of retrieved instances "
"that are relevant, TP / (TP + FP), and <b>recall</b> (falling) is "
"the proportion of discovered relevant instances, TP / P."),
("Pos & Neg predictive value", _ppv_npv_curve,
"<b>Positive predictive value</b> (rising) is the proportion of "
"correct positives, TP / (TP + FP), and <b>negative predictive "
"value</b> the proportion of correct negatives, TN / (TN + FN)."),
]
Metrics = [metric_definition(*args) for args in (
("Actual probability", _prob_curve, (), ""),
("Classification accuracy", _ca_curve, (), ""),
("Sensitivity and specificity", _sens_spec_curve, ("sens", "spec"),
"<p><b>Sensitivity</b> (falling) is the proportion of correctly "
"detected positive instances (TP / P).</p>"
"<p><b>Specificity</b> (rising) is the proportion of detected "
"negative instances (TP / N).</p>"),
("Precision and recall", _pr_curve, ("prec", "recall"),
"<p><b>Precision</b> (rising) is the fraction of retrieved instances "
"that are relevant, TP / (TP + FP).</p>"
"<p><b>Recall</b> (falling) is the proportion of discovered relevant "
"instances, TP / P.</p>"),
("Pos and neg predictive value", _ppv_npv_curve, ("PPV", "TPV"),
"<p><b>Positive predictive value</b> (rising) is the proportion of "
"correct positives, TP / (TP + FP).</p>"
"<p><b>Negative predictive value</b> is the proportion of correct "
"negatives, TN / (TN + FN).</p>"),
)]


def gaussian_smoother(x, y, sigma=1.0):
Expand Down
3 changes: 3 additions & 0 deletions Orange/widgets/gui.py
@@ -1783,6 +1783,9 @@ def __init__(self, master, enableDragDrop=False, dragDropCallback=None,
def sizeHint(self):
return self.size_hint

def minimumSizeHint(self):
return self.size_hint

Comment from janezd (Contributor, Author):
I added this because the list view remained too large. This is the only way I managed to reduce its size; setMinimumSizeHint just didn't work...
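A self-contained sketch of that workaround, assuming a plain QListWidget subclass; the class and attribute names are illustrative (the real change lives in the listBox widget in Orange.widgets.gui):

```python
from AnyQt.QtCore import QSize
from AnyQt.QtWidgets import QListWidget

class SmallListWidget(QListWidget):
    """List widget whose layout minimum matches its size hint, so a layout
    can actually shrink it to the requested size."""

    def __init__(self, size_hint=QSize(150, 40), parent=None):
        super().__init__(parent)
        self.size_hint = size_hint

    def sizeHint(self):
        return self.size_hint

    def minimumSizeHint(self):
        # Layouts consult minimumSizeHint() when shrinking a widget;
        # overriding sizeHint() alone leaves Qt's larger default minimum
        # in force, so the list never gets smaller.
        return self.size_hint
```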

def dragEnterEvent(self, event):
super().dragEnterEvent(event)
if self.valid_data_callback: