From 1231f472ed693c6dc12ea50698d8415c4b68c8ec Mon Sep 17 00:00:00 2001 From: Marko Toplak Date: Fri, 17 Nov 2023 13:22:26 +0100 Subject: [PATCH] adaboost: adapt to scikit-learn's 1.4 deprecation of base_estimator --- Orange/ensembles/ada_boost.py | 48 +++++++++++++++++++++--------- Orange/tests/test_ada_boost.py | 23 ++++++++++---- Orange/widgets/model/owadaboost.py | 2 +- 3 files changed, 52 insertions(+), 21 deletions(-) diff --git a/Orange/ensembles/ada_boost.py b/Orange/ensembles/ada_boost.py index f0a5ef6ebf6..1d3c689aa42 100644 --- a/Orange/ensembles/ada_boost.py +++ b/Orange/ensembles/ada_boost.py @@ -1,3 +1,5 @@ +import warnings + import sklearn.ensemble as skl_ensemble from Orange.base import SklLearner @@ -7,6 +9,8 @@ from Orange.regression.base_regression import ( SklLearnerRegression, SklModelRegression ) +from Orange.util import OrangeDeprecationWarning + __all__ = ['SklAdaBoostClassificationLearner', 'SklAdaBoostRegressionLearner'] @@ -15,21 +19,32 @@ class SklAdaBoostClassifier(SklModelClassification): pass +def base_estimator_deprecation(): + warnings.warn( + "`base_estimator` is deprecated: use `estimator` instead.", + OrangeDeprecationWarning, stacklevel=3) + + class SklAdaBoostClassificationLearner(SklLearnerClassification): __wraps__ = skl_ensemble.AdaBoostClassifier __returns__ = SklAdaBoostClassifier supports_weights = True - def __init__(self, base_estimator=None, n_estimators=50, learning_rate=1., - algorithm='SAMME.R', random_state=None, preprocessors=None): + def __init__(self, estimator=None, n_estimators=50, learning_rate=1., + algorithm='SAMME.R', random_state=None, preprocessors=None, + base_estimator="deprecated"): + if base_estimator != "deprecated": + base_estimator_deprecation() + estimator = base_estimator + del base_estimator from Orange.modelling import Fitter # If fitter, get the appropriate Learner instance - if isinstance(base_estimator, Fitter): - base_estimator = base_estimator.get_learner( - base_estimator.CLASSIFICATION) + if isinstance(estimator, Fitter): + estimator = estimator.get_learner( + estimator.CLASSIFICATION) # If sklearn learner, get the underlying sklearn representation - if isinstance(base_estimator, SklLearner): - base_estimator = base_estimator.__wraps__(**base_estimator.params) + if isinstance(estimator, SklLearner): + estimator = estimator.__wraps__(**estimator.params) super().__init__(preprocessors=preprocessors) self.params = vars() @@ -43,15 +58,20 @@ class SklAdaBoostRegressionLearner(SklLearnerRegression): __returns__ = SklAdaBoostRegressor supports_weights = True - def __init__(self, base_estimator=None, n_estimators=50, learning_rate=1., - loss='linear', random_state=None, preprocessors=None): + def __init__(self, estimator=None, n_estimators=50, learning_rate=1., + loss='linear', random_state=None, preprocessors=None, + base_estimator="deprecated"): + if base_estimator != "deprecated": + base_estimator_deprecation() + estimator = base_estimator + del base_estimator from Orange.modelling import Fitter # If fitter, get the appropriate Learner instance - if isinstance(base_estimator, Fitter): - base_estimator = base_estimator.get_learner( - base_estimator.REGRESSION) + if isinstance(estimator, Fitter): + estimator = estimator.get_learner( + estimator.REGRESSION) # If sklearn learner, get the underlying sklearn representation - if isinstance(base_estimator, SklLearner): - base_estimator = base_estimator.__wraps__(**base_estimator.params) + if isinstance(estimator, SklLearner): + estimator = estimator.__wraps__(**estimator.params) super().__init__(preprocessors=preprocessors) self.params = vars() diff --git a/Orange/tests/test_ada_boost.py b/Orange/tests/test_ada_boost.py index c10f3af63b9..379f26c1594 100644 --- a/Orange/tests/test_ada_boost.py +++ b/Orange/tests/test_ada_boost.py @@ -2,7 +2,11 @@ # pylint: disable=missing-docstring import unittest +from distutils.version import LooseVersion + import numpy as np + +import Orange from Orange.data import Table from Orange.classification import SklTreeLearner from Orange.regression import SklTreeRegressionLearner @@ -27,14 +31,14 @@ def test_adaboost(self): self.assertGreater(ca, 0.9) self.assertLess(ca, 0.99) - def test_adaboost_base_estimator(self): + def test_adaboost_estimator(self): np.random.seed(0) stump_estimator = SklTreeLearner(max_depth=1) tree_estimator = SklTreeLearner() stump = SklAdaBoostClassificationLearner( - base_estimator=stump_estimator, n_estimators=5) + estimator=stump_estimator, n_estimators=5) tree = SklAdaBoostClassificationLearner( - base_estimator=tree_estimator, n_estimators=5) + estimator=tree_estimator, n_estimators=5) cv = CrossValidation(k=4) results = cv(self.iris, [stump, tree]) ca = CA(results) @@ -68,12 +72,12 @@ def test_adaboost_reg(self): results = cv(self.housing, [learn]) _ = RMSE(results) - def test_adaboost_reg_base_estimator(self): + def test_adaboost_reg_estimator(self): np.random.seed(0) stump_estimator = SklTreeRegressionLearner(max_depth=1) tree_estimator = SklTreeRegressionLearner() - stump = SklAdaBoostRegressionLearner(base_estimator=stump_estimator) - tree = SklAdaBoostRegressionLearner(base_estimator=tree_estimator) + stump = SklAdaBoostRegressionLearner(estimator=stump_estimator) + tree = SklAdaBoostRegressionLearner(estimator=tree_estimator) cv = CrossValidation(k=3) results = cv(self.housing, [stump, tree]) rmse = RMSE(results) @@ -103,3 +107,10 @@ def test_predict_numpy_reg(self): def test_adaboost_adequacy_reg(self): learner = SklAdaBoostRegressionLearner() self.assertRaises(ValueError, learner, self.iris) + + def test_remove_deprecation(self): + if LooseVersion(Orange.__version__) >= LooseVersion("3.39"): + self.fail( + "`base_estimator` was deprecated in " + "version 3.37. Please remove everything related to it." + ) diff --git a/Orange/widgets/model/owadaboost.py b/Orange/widgets/model/owadaboost.py index 785938e1d74..12836889b0c 100644 --- a/Orange/widgets/model/owadaboost.py +++ b/Orange/widgets/model/owadaboost.py @@ -80,7 +80,7 @@ def create_learner(self): if self.base_estimator is None: return None return self.LEARNER( - base_estimator=self.base_estimator, + estimator=self.base_estimator, n_estimators=self.n_estimators, learning_rate=self.learning_rate, random_state=self.random_seed,