
OWCalibration Plot: Unit tests and some fixes
janezd committed Jun 18, 2019
1 parent 6ac1db1 commit 4cc3a54
Showing 5 changed files with 581 additions and 38 deletions.
4 changes: 2 additions & 2 deletions Orange/evaluation/testing.py
@@ -171,7 +171,7 @@ def set_or_raise(value, exp_values, msg):
             "mismatching number of class values")
         nmethods = set_or_raise(
             nmethods, [learners is not None and len(learners),
-                       models is not None and len(models),
+                       models is not None and models.shape[1],
                        failed is not None and len(failed),
                        predicted is not None and predicted.shape[0],
                        probabilities is not None and probabilities.shape[0]],
@@ -365,7 +365,7 @@ def __new__(cls,
                 "and train_data are omitted")
             return self
 
-        warn("calling Validation's constructor with data and learners"
+        warn("calling Validation's constructor with data and learners "
             "is deprecated;\nconstruct an instance and call it",
             DeprecationWarning, stacklevel=2)
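Two fixes in this file: a Results object stores models as a two-dimensional array whose rows are folds and whose columns are fitted models, so the number of methods must be read from models.shape[1] rather than len(models), which counts rows; and the deprecation warning was missing a space where Python implicitly concatenates adjacent string literals. A minimal sketch of that concatenation pitfall:

# Adjacent string literals are joined with no separator; the line break
# contributes nothing, so the first literal must end with a space.
msg = ("calling Validation's constructor with data and learners"
       "is deprecated")   # -> "...learnersis deprecated"
msg = ("calling Validation's constructor with data and learners "
       "is deprecated")   # -> "...learners is deprecated"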
2 changes: 1 addition & 1 deletion Orange/tests/test_evaluation_testing.py
@@ -756,7 +756,7 @@ def setUp(self):
         self.row_indices = np.arange(100)
         self.folds = (range(50), range(10, 60)), (range(50, 100), range(50))
         self.learners = [MajorityLearner(), MajorityLearner()]
-        self.models = [Mock(), Mock()]
+        self.models = np.array([[Mock(), Mock()]])
         self.predicted = np.zeros((2, 100))
         self.probabilities = np.zeros((2, 100, 3))
         self.failed = [False, True]
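The fixture now mirrors the layout that set_or_raise checks above: models is a two-dimensional object array with one row per fold and one column per learner. A small sketch of the assumed shape, with Mock standing in for fitted models as in the test:

import numpy as np
from unittest.mock import Mock

# One row per fold, one column per learner; models.shape[1] is then the
# number of methods regardless of the number of folds.
models = np.array([[Mock(), Mock()]])
assert models.shape == (1, 2)
assert len(models) == 1        # len() counts folds, not learners
assert models.shape[1] == 2    # number of learners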
62 changes: 41 additions & 21 deletions Orange/widgets/evaluate/owcalibrationplot.py
@@ -14,8 +14,7 @@
 from Orange.widgets import widget, gui, settings
 from Orange.widgets.evaluate.contexthandlers import \
     EvaluationResultsContextHandler
-from Orange.widgets.evaluate.utils import \
-    check_results_adequacy, results_for_preview
+from Orange.widgets.evaluate.utils import results_for_preview
 from Orange.widgets.utils import colorpalette, colorbrewer
 from Orange.widgets.utils.widgetpreview import WidgetPreview
 from Orange.widgets.widget import Input, Output, Msg
@@ -72,7 +71,8 @@ class Inputs:
     class Outputs:
         calibrated_model = Output("Calibrated Model", Model)
 
-    class Warning(widget.OWWidget.Warning):
+    class Error(widget.OWWidget.Error):
+        non_discrete_target = Msg("Calibration plot requires a discrete target")
         empty_input = widget.Msg("Empty result on input. Nothing to display.")
 
     class Information(widget.OWWidget.Information):
@@ -84,7 +84,8 @@ class Information(widget.OWWidget.Information):
             "try testing on separate data or on training data")
         no_output_multiple_selected = Msg(
             no_out + "select a single model - the widget can output only one")
-        non_binary_class = Msg(no_out + "cannot calibrate non-binary classes")
+        no_output_non_binary_class = Msg(
+            no_out + "cannot calibrate non-binary classes")
 
     settingsHandler = EvaluationResultsContextHandler()
     target_index = settings.ContextSetting(0)
@@ -145,8 +146,8 @@ def __init__(self):
             btnLabels=("Sigmoid calibration", "Isotonic calibration"),
             label="Output model calibration", callback=self.apply)
 
-        box = gui.widgetBox(self.controlArea, "Info")
-        self.info_label = gui.widgetLabel(box)
+        self.info_box = gui.widgetBox(self.controlArea, "Info")
+        self.info_label = gui.widgetLabel(self.info_box)
 
         gui.auto_commit(
             self.controlArea, self, "auto_commit", "Apply", commit=self.apply)
@@ -159,6 +160,10 @@ def __init__(self):
         for axis_name in ("bottom", "left"):
             axis = self.plot.getAxis(axis_name)
             axis.setPen(pg.mkPen(color=0.0))
+            # Remove the condition (that is, allow setting this for bottom
+            # axis) when pyqtgraph is fixed
+            # Issue: https://github.com/pyqtgraph/pyqtgraph/issues/930
+            # Pull request: https://github.com/pyqtgraph/pyqtgraph/pull/932
             if axis_name != "bottom":  # remove if when pyqtgraph is fixed
                 axis.setStyle(stopAxisAtTick=(True, True))

@@ -172,11 +177,14 @@ def __init__(self):
     def set_results(self, results):
         self.closeContext()
         self.clear()
-        results = check_results_adequacy(results, self.Error, check_nan=False)
+        self.Error.clear()
+        self.Information.clear()
+        if results is not None and not results.domain.has_discrete_class:
+            self.Error.non_discrete_target()
+            results = None
         if results is not None and not results.actual.size:
-            self.Warning.empty_input()
-        else:
-            self.Warning.empty_input.clear()
+            self.Error.empty_input()
+            results = None
         self.results = results
         if self.results is not None:
             self._initialize(results)
@@ -219,8 +227,10 @@ def _set_explanation(self):
 
         if self.score == 0:
             self.controls.output_calibration.show()
+            self.info_box.hide()
         else:
             self.controls.output_calibration.hide()
+            self.info_box.show()
 
         axis = self.plot.getAxis("bottom")
         axis.setLabel("Predicted probability" if self.score == 0
@@ -230,23 +240,23 @@ def _set_explanation(self):
             axis.setLabel(Metrics[self.score].name)
 
     def _initialize(self, results):
-        N = len(results.predicted)
+        n = len(results.predicted)
         names = getattr(results, "learner_names", None)
         if names is None:
-            names = ["#{}".format(i + 1) for i in range(N)]
+            names = ["#{}".format(i + 1) for i in range(n)]
 
         self.classifier_names = names
         scheme = colorbrewer.colorSchemes["qualitative"]["Dark2"]
-        if N > len(scheme):
+        if n > len(scheme):
             scheme = colorpalette.DefaultRGBColors
-        self.colors = colorpalette.ColorPaletteGenerator(N, scheme)
+        self.colors = colorpalette.ColorPaletteGenerator(n, scheme)
 
-        for i in range(N):
+        for i in range(n):
             item = self.classifiers_list_box.item(i)
             item.setIcon(colorpalette.ColorPixmap(self.colors[i]))
 
-        self.selected_classifiers = list(range(N))
-        self.target_cb.addItems(results.data.domain.class_var.values)
+        self.selected_classifiers = list(range(n))
+        self.target_cb.addItems(results.domain.class_var.values)
 
     def _rug(self, data, pen_args):
         color = pen_args["pen"].color()
@@ -288,7 +298,6 @@ def _prob_curve(self, ytrue, probs, pen_args):
             y = np.full(100, xmax)
 
         self.plot.plot(x, y, symbol="+", symbolSize=4, **pen_args)
-        self.plot.plot([0, 1], [0, 1], antialias=True)
         return x, (y, )
 
     def _setup_plot(self):
@@ -326,6 +335,9 @@ def _setup_plot(self):
                 self.plot_metrics(Curves(fold_ytrue, fold_probs),
                                   metrics, pen_args)
 
+        if self.score == 0:
+            self.plot.plot([0, 1], [0, 1], antialias=True)
+
     def _replot(self):
         self.plot.clear()
         if self.results is not None:
@@ -379,7 +391,7 @@ def _update_info(self):
                 for curve in curves)
             text += "</tr>"
             text += "<table>"
-            self.info_label.setText(text)
+        self.info_label.setText(text)

     def threshold_change_done(self):
         self.apply()
@@ -395,7 +407,7 @@ def apply(self):
             info.no_output_no_models: results.models is None,
             info.no_output_multiple_selected:
                 len(self.selected_classifiers) != 1,
-            info.non_binary_class:
+            info.no_output_non_binary_class:
                 self.score != 0
                 and len(results.domain.class_var.values) != 2}
         if not any(problems.values()):
@@ -419,11 +431,19 @@ def apply(self):
     def send_report(self):
         if self.results is None:
             return
+        self.report_items((
+            ("Target class", self.target_cb.currentText()),
+            ("Output model calibration",
+             self.score == 0 and self.controls.score.currentText()),
+        ))
         caption = report.list_legend(self.classifiers_list_box,
                                      self.selected_classifiers)
-        self.report_items((("Target class", self.target_cb.currentText()),))
         self.report_plot()
         self.report_caption(caption)
-        self.report_caption(self.controls.score.currentText())
+
+        if self.score != 0:
+            self.report_raw(self.info_label.text())
 
 
 def gaussian_smoother(x, y, sigma=1.0):
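Taken together, the hunks above replace the check_results_adequacy helper with validation done in the widget itself: messages are cleared up front, a non-discrete target or empty input is reported through Error rather than Warning, and the results are dropped in both cases so nothing stale stays on screen. A condensed sketch of the resulting flow, reconstructed from the hunks rather than copied from the full source:

def set_results(self, results):
    self.closeContext()
    self.clear()
    self.Error.clear()
    self.Information.clear()
    # Unusable input is now an Error, and results are discarded so the
    # widget does not keep displaying a stale plot.
    if results is not None and not results.domain.has_discrete_class:
        self.Error.non_discrete_target()
        results = None
    if results is not None and not results.actual.size:
        self.Error.empty_input()
        results = None
    self.results = results
    if self.results is not None:
        self._initialize(results)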
2 changes: 1 addition & 1 deletion Orange/widgets/evaluate/tests/base.py
@@ -17,6 +17,6 @@ def test_many_evaluation_results(self):
             classification.NaiveBayesLearner(),
             classification.SGDClassificationLearner()
         ]
-        res = evaluation.CrossValidation(data, learners, k=2, store_data=True)
+        res = evaluation.CrossValidation(k=2, store_data=True)(data, learners)
         # this is a mixin; pylint: disable=no-member
         self.send_signal("Evaluation Results", res)
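This updates the test to the calling convention described by the deprecation warning fixed in testing.py: a Validation subclass such as CrossValidation is constructed with its parameters, and the resulting instance is then called with the data and learners. Both styles side by side, assuming data and learners are defined as in the test above:

from Orange import evaluation

# Deprecated: data and learners passed straight to the constructor
res = evaluation.CrossValidation(data, learners, k=2, store_data=True)

# Current: construct with parameters, then call the instance
res = evaluation.CrossValidation(k=2, store_data=True)(data, learners)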
