Skip to content

Commit

Permalink
Merge pull request #6917 from ales-erjavec/fixes/owcalibratedlearner-…
Browse files Browse the repository at this point in the history
…base-learner-modify

[FIX] Calibrated Learner: Prevent in place modification of base learner
  • Loading branch information
janezd authored Oct 25, 2024
2 parents e42b9b1 + 4e7e541 commit c27a803
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
12 changes: 5 additions & 7 deletions Orange/widgets/model/owcalibratedlearner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

from Orange.classification import CalibratedLearner, ThresholdLearner, \
NaiveBayesLearner
from Orange.data import Table
Expand Down Expand Up @@ -65,7 +67,6 @@ def set_learner(self, learner):
self.learner = self.model = None

def _set_default_name(self):

if self.base_learner is None:
self.set_default_learner_name("")
else:
Expand All @@ -80,10 +81,6 @@ def calibration_options_changed(self):
self.apply()

def create_learner(self):
class IdentityWrapper(Learner):
def fit_storage(self, data):
return self.base_learner.fit_storage(data)

if self.base_learner is None:
return None
learner = self.base_learner
Expand All @@ -93,10 +90,11 @@ def fit_storage(self, data):
if self.threshold != self.NoThresholdOptimization:
learner = ThresholdLearner(learner,
self.ThresholdMap[self.threshold])
if learner is self.base_learner:
learner = copy.deepcopy(learner)
if self.preprocessors:
if learner is self.base_learner:
learner = IdentityWrapper()
learner.preprocessors = (self.preprocessors, )
assert learner is not self.base_learner
return learner

def get_learner_parameters(self):
Expand Down
2 changes: 1 addition & 1 deletion Orange/widgets/model/tests/test_owcalibratedlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_create_learner(self):
widget.calibration = widget.NoCalibration
widget.threshold = widget.NoThresholdOptimization
learner = self.widget.create_learner()
self.assertIs(learner, self.widget.base_learner)
self.assertIsNot(learner, self.widget.base_learner)

widget.calibration = widget.SigmoidCalibration
widget.threshold = widget.OptimizeF1
Expand Down

0 comments on commit c27a803

Please sign in to comment.