-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
269 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from Orange.classification import CalibratedLearner, ThresholdLearner, \ | ||
NaiveBayesLearner | ||
from Orange.data import Table | ||
from Orange.modelling import Learner | ||
from Orange.widgets import gui | ||
from Orange.widgets.widget import Input | ||
from Orange.widgets.settings import Setting | ||
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner | ||
from Orange.widgets.utils.widgetpreview import WidgetPreview | ||
|
||
|
||
class OWCalibratedLearner(OWBaseLearner): | ||
name = "Calibrated Learner" | ||
description = "Wraps another learner with probability calibration and " \ | ||
"decision threshold optimization" | ||
icon = "icons/CalibratedLearner.svg" | ||
priority = 20 | ||
keywords = ["calibration", "threshold"] | ||
|
||
LEARNER = CalibratedLearner | ||
|
||
SigmoidCalibration, IsotonicCalibration, NoCalibration = range(3) | ||
CalibrationOptions = ("Sigmoid calibration", | ||
"Isotonic calibration", | ||
"No calibration") | ||
CalibrationShort = ("Sigmoid", "Isotonic", "") | ||
CalibrationMap = { | ||
SigmoidCalibration: CalibratedLearner.Sigmoid, | ||
IsotonicCalibration: CalibratedLearner.Isotonic} | ||
|
||
OptimizeCA, OptimizeF1, NoThresholdOptimization = range(3) | ||
ThresholdOptions = ("Optimize classification accuracy", | ||
"Optimize F1 score", | ||
"No threshold optimization") | ||
ThresholdShort = ("CA", "F1", "") | ||
ThresholdMap = { | ||
OptimizeCA: ThresholdLearner.OptimizeCA, | ||
OptimizeF1: ThresholdLearner.OptimizeF1} | ||
|
||
learner_name = Setting("", schema_only=True) | ||
calibration = Setting(SigmoidCalibration) | ||
threshold = Setting(OptimizeCA) | ||
|
||
class Inputs(OWBaseLearner.Inputs): | ||
base_learner = Input("Base Learner", Learner) | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.base_learner = None | ||
|
||
def add_main_layout(self): | ||
gui.radioButtons( | ||
self.controlArea, self, "calibration", self.CalibrationOptions, | ||
box="Probability calibration", | ||
callback=self.calibration_options_changed) | ||
gui.radioButtons( | ||
self.controlArea, self, "threshold", self.ThresholdOptions, | ||
box="Decision threshold optimization", | ||
callback=self.calibration_options_changed) | ||
|
||
@Inputs.base_learner | ||
def set_learner(self, learner): | ||
self.base_learner = learner | ||
self._set_default_name() | ||
self.unconditional_apply() | ||
|
||
def _set_default_name(self): | ||
if self.base_learner is None: | ||
self.name = "Calibrated learner" | ||
else: | ||
self.name = " + ".join(part for part in ( | ||
self.base_learner.name.title(), | ||
self.CalibrationShort[self.calibration], | ||
self.ThresholdShort[self.threshold]) if part) | ||
self.controls.learner_name.setPlaceholderText(self.name) | ||
|
||
def calibration_options_changed(self): | ||
self._set_default_name() | ||
self.apply() | ||
|
||
def create_learner(self): | ||
class IdentityWrapper(Learner): | ||
def fit_storage(self, data): | ||
return self.base_learner.fit_storage(data) | ||
|
||
if self.base_learner is None: | ||
return None | ||
learner = self.base_learner | ||
if self.calibration != self.NoCalibration: | ||
learner = CalibratedLearner(learner, | ||
self.CalibrationMap[self.calibration]) | ||
if self.threshold != self.NoThresholdOptimization: | ||
learner = ThresholdLearner(learner, | ||
self.ThresholdMap[self.threshold]) | ||
if self.preprocessors: | ||
if learner is self.base_learner: | ||
learner = IdentityWrapper() | ||
learner.preprocessors = (self.preprocessors, ) | ||
return learner | ||
|
||
def get_learner_parameters(self): | ||
return (("Calibrate probabilities", | ||
self.CalibrationOptions[self.calibrate]), | ||
("Threshold optimization", | ||
self.ThresholdOptions[self.threshold])) | ||
|
||
|
||
if __name__ == "__main__": # pragma: no cover | ||
WidgetPreview(OWCalibratedLearner).run( | ||
Table("heart_disease"), | ||
set_learner=NaiveBayesLearner()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
from unittest.mock import Mock | ||
|
||
from Orange.classification import ThresholdLearner, CalibratedLearner, \ | ||
NaiveBayesLearner, ThresholdClassifier, CalibratedClassifier | ||
from Orange.classification.base_classification import ModelClassification, \ | ||
LearnerClassification | ||
from Orange.classification.naive_bayes import NaiveBayesModel | ||
from Orange.data import Table | ||
from Orange.widgets.model.owcalibratedlearner import OWCalibratedLearner | ||
from Orange.widgets.tests.base import WidgetTest, WidgetLearnerTestMixin, \ | ||
datasets | ||
|
||
|
||
class TestOWCalibratedLearner(WidgetTest, WidgetLearnerTestMixin): | ||
def setUp(self): | ||
self.widget = self.create_widget( | ||
OWCalibratedLearner, stored_settings={"auto_apply": False}) | ||
self.send_signal(self.widget.Inputs.base_learner, NaiveBayesLearner()) | ||
|
||
self.data = Table("heart_disease") | ||
self.valid_datasets = (self.data,) | ||
self.inadequate_dataset = (Table(datasets.path("testing_dataset_reg")),) | ||
self.learner_class = LearnerClassification | ||
self.model_class = ModelClassification | ||
self.model_name = 'Calibrated classifier' | ||
self.parameters = [] | ||
|
||
def test_output_learner(self): | ||
"""Check if learner is on output after apply""" | ||
# Overridden to change the output type in the last test | ||
initial = self.get_output("Learner") | ||
self.assertIsNotNone(initial, "Does not initialize the learner output") | ||
self.widget.apply_button.button.click() | ||
newlearner = self.get_output("Learner") | ||
self.assertIsNot(initial, newlearner, | ||
"Does not send a new learner instance on `Apply`.") | ||
self.assertIsNotNone(newlearner) | ||
self.assertIsInstance( | ||
newlearner, | ||
(CalibratedLearner, ThresholdLearner, NaiveBayesLearner)) | ||
|
||
def test_output_model(self): | ||
"""Check if model is on output after sending data and apply""" | ||
# Overridden to change the output type in the last two test | ||
self.assertIsNone(self.get_output(self.widget.Outputs.model)) | ||
self.widget.apply_button.button.click() | ||
self.assertIsNone(self.get_output(self.widget.Outputs.model)) | ||
self.send_signal('Data', self.data) | ||
self.widget.apply_button.button.click() | ||
self.wait_until_stop_blocking() | ||
model = self.get_output(self.widget.Outputs.model) | ||
self.assertIsNotNone(model) | ||
self.assertIsInstance( | ||
model, (CalibratedClassifier, ThresholdClassifier, NaiveBayesModel)) | ||
|
||
def test_create_learner(self): | ||
widget = self.widget #: OWCalibratedLearner | ||
self.widget.base_learner = Mock() | ||
|
||
widget.calibration = widget.SigmoidCalibration | ||
widget.threshold = widget.OptimizeF1 | ||
learner = self.widget.create_learner() | ||
self.assertIsInstance(learner, ThresholdLearner) | ||
self.assertEqual(learner.threshold_criterion, learner.OptimizeF1) | ||
cal_learner = learner.base_learner | ||
self.assertIsInstance(cal_learner, CalibratedLearner) | ||
self.assertEqual(cal_learner.calibration_method, cal_learner.Sigmoid) | ||
self.assertIs(cal_learner.base_learner, self.widget.base_learner) | ||
|
||
widget.calibration = widget.IsotonicCalibration | ||
widget.threshold = widget.OptimizeCA | ||
learner = self.widget.create_learner() | ||
self.assertIsInstance(learner, ThresholdLearner) | ||
self.assertEqual(learner.threshold_criterion, learner.OptimizeCA) | ||
cal_learner = learner.base_learner | ||
self.assertIsInstance(cal_learner, CalibratedLearner) | ||
self.assertEqual(cal_learner.calibration_method, cal_learner.Isotonic) | ||
self.assertIs(cal_learner.base_learner, self.widget.base_learner) | ||
|
||
widget.calibration = widget.NoCalibration | ||
widget.threshold = widget.OptimizeCA | ||
learner = self.widget.create_learner() | ||
self.assertIsInstance(learner, ThresholdLearner) | ||
self.assertEqual(learner.threshold_criterion, learner.OptimizeCA) | ||
self.assertIs(learner.base_learner, self.widget.base_learner) | ||
|
||
widget.calibration = widget.IsotonicCalibration | ||
widget.threshold = widget.NoThresholdOptimization | ||
learner = self.widget.create_learner() | ||
self.assertIsInstance(learner, CalibratedLearner) | ||
self.assertEqual(learner.calibration_method, cal_learner.Isotonic) | ||
self.assertIs(learner.base_learner, self.widget.base_learner) | ||
|
||
widget.calibration = widget.NoCalibration | ||
widget.threshold = widget.NoThresholdOptimization | ||
learner = self.widget.create_learner() | ||
self.assertIs(learner, self.widget.base_learner) | ||
|
||
widget.calibration = widget.SigmoidCalibration | ||
widget.threshold = widget.OptimizeF1 | ||
widget.base_learner = None | ||
learner = self.widget.create_learner() | ||
self.assertIsNone(learner) | ||
|
||
def test_preprocessors(self): | ||
widget = self.widget #: OWCalibratedLearner | ||
self.widget.base_learner = Mock() | ||
self.widget.base_learner.preprocessors = () | ||
|
||
widget.calibration = widget.SigmoidCalibration | ||
widget.threshold = widget.OptimizeF1 | ||
widget.preprocessors = Mock() | ||
learner = self.widget.create_learner() | ||
self.assertEqual(learner.preprocessors, (widget.preprocessors, )) | ||
self.assertEqual(learner.base_learner.preprocessors, ()) | ||
self.assertEqual(learner.base_learner.base_learner.preprocessors, ()) | ||
|
||
widget.calibration = widget.NoCalibration | ||
widget.threshold = widget.NoThresholdOptimization | ||
learner = self.widget.create_learner() | ||
self.assertIsNot(learner, self.widget.base_learner) | ||
self.assertFalse( | ||
isinstance(learner, (CalibratedLearner, ThresholdLearner))) | ||
self.assertEqual(learner.preprocessors, (widget.preprocessors, )) | ||
|
||
def test_set_learner_calls_unconditional_apply(self): | ||
widget = self.widget | ||
self.assertIsNotNone(self.get_output(widget.Outputs.learner)) | ||
|
||
widget.auto_apply = False | ||
self.send_signal(widget.Inputs.base_learner, None) | ||
self.assertIsNone(self.get_output(widget.Outputs.learner)) | ||
|
||
def test_name_changes(self): | ||
widget = self.widget | ||
widget.auto_apply = True | ||
learner = NaiveBayesLearner() | ||
learner.name = "foo" | ||
self.send_signal(widget.Inputs.base_learner, learner) | ||
|
||
widget.calibration = widget.IsotonicCalibration | ||
widget.threshold = widget.OptimizeCA | ||
widget.controls.calibration.group.buttonClicked[int].emit( | ||
widget.IsotonicCalibration) | ||
|
||
learner = self.get_output(widget.Outputs.learner) | ||
self.assertEqual(learner.name, "Foo + Isotonic + CA") | ||
|
||
widget.calibration = widget.NoCalibration | ||
widget.threshold = widget.OptimizeCA | ||
widget.controls.calibration.group.buttonClicked[int].emit( | ||
widget.NoCalibration) | ||
learner = self.get_output(widget.Outputs.learner) | ||
self.assertEqual(learner.name, "Foo + CA") | ||
|
||
self.send_signal(widget.Inputs.base_learner, None) | ||
self.assertEqual(widget.controls.learner_name.placeholderText(), | ||
"Calibrated learner") |