Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adaboost: adapt to scikit-learn's 1.4 deprecation of base_estimator #6637

Merged
merged 2 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions Orange/ensembles/ada_boost.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import sklearn.ensemble as skl_ensemble

from Orange.base import SklLearner
Expand All @@ -7,6 +9,8 @@
from Orange.regression.base_regression import (
SklLearnerRegression, SklModelRegression
)
from Orange.util import OrangeDeprecationWarning


__all__ = ['SklAdaBoostClassificationLearner', 'SklAdaBoostRegressionLearner']

Expand All @@ -15,21 +19,32 @@
pass


def base_estimator_deprecation():
    """Emit the deprecation warning for the old `base_estimator` keyword.

    Intended to be called directly from a learner's ``__init__``; the
    stack level of 3 skips this helper and that ``__init__`` so the
    warning is attributed to the code that constructed the learner.
    """
    message = (
        "`base_estimator` is deprecated (to be removed in 3.39): "
        "use `estimator` instead."
    )
    warnings.warn(message, category=OrangeDeprecationWarning, stacklevel=3)


class SklAdaBoostClassificationLearner(SklLearnerClassification):
__wraps__ = skl_ensemble.AdaBoostClassifier
__returns__ = SklAdaBoostClassifier
supports_weights = True

def __init__(self, base_estimator=None, n_estimators=50, learning_rate=1.,
algorithm='SAMME.R', random_state=None, preprocessors=None):
def __init__(self, estimator=None, n_estimators=50, learning_rate=1.,
algorithm='SAMME.R', random_state=None, preprocessors=None,
base_estimator="deprecated"):
if base_estimator != "deprecated":
base_estimator_deprecation()
estimator = base_estimator
del base_estimator
from Orange.modelling import Fitter
# If fitter, get the appropriate Learner instance
if isinstance(base_estimator, Fitter):
base_estimator = base_estimator.get_learner(
base_estimator.CLASSIFICATION)
if isinstance(estimator, Fitter):
estimator = estimator.get_learner(
estimator.CLASSIFICATION)
# If sklearn learner, get the underlying sklearn representation
if isinstance(base_estimator, SklLearner):
base_estimator = base_estimator.__wraps__(**base_estimator.params)
if isinstance(estimator, SklLearner):
estimator = estimator.__wraps__(**estimator.params)
super().__init__(preprocessors=preprocessors)
self.params = vars()

Expand All @@ -43,15 +58,20 @@
__returns__ = SklAdaBoostRegressor
supports_weights = True

def __init__(self, base_estimator=None, n_estimators=50, learning_rate=1.,
loss='linear', random_state=None, preprocessors=None):
def __init__(self, estimator=None, n_estimators=50, learning_rate=1.,
loss='linear', random_state=None, preprocessors=None,
base_estimator="deprecated"):
if base_estimator != "deprecated":
base_estimator_deprecation()
estimator = base_estimator

Check warning on line 66 in Orange/ensembles/ada_boost.py

View check run for this annotation

Codecov / codecov/patch

Orange/ensembles/ada_boost.py#L65-L66

Added lines #L65 - L66 were not covered by tests
del base_estimator
from Orange.modelling import Fitter
# If fitter, get the appropriate Learner instance
if isinstance(base_estimator, Fitter):
base_estimator = base_estimator.get_learner(
base_estimator.REGRESSION)
if isinstance(estimator, Fitter):
estimator = estimator.get_learner(
estimator.REGRESSION)
# If sklearn learner, get the underlying sklearn representation
if isinstance(base_estimator, SklLearner):
base_estimator = base_estimator.__wraps__(**base_estimator.params)
if isinstance(estimator, SklLearner):
estimator = estimator.__wraps__(**estimator.params)
super().__init__(preprocessors=preprocessors)
self.params = vars()
29 changes: 23 additions & 6 deletions Orange/tests/test_ada_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
# pylint: disable=missing-docstring

import unittest
from distutils.version import LooseVersion

import numpy as np

import Orange
from Orange.data import Table
from Orange.classification import SklTreeLearner
from Orange.regression import SklTreeRegressionLearner
Expand All @@ -11,6 +15,7 @@
SklAdaBoostRegressionLearner,
)
from Orange.evaluation import CrossValidation, CA, RMSE
from Orange.util import OrangeDeprecationWarning


class TestSklAdaBoostLearner(unittest.TestCase):
Expand All @@ -27,14 +32,14 @@ def test_adaboost(self):
self.assertGreater(ca, 0.9)
self.assertLess(ca, 0.99)

def test_adaboost_base_estimator(self):
def test_adaboost_estimator(self):
np.random.seed(0)
stump_estimator = SklTreeLearner(max_depth=1)
tree_estimator = SklTreeLearner()
stump = SklAdaBoostClassificationLearner(
base_estimator=stump_estimator, n_estimators=5)
estimator=stump_estimator, n_estimators=5)
tree = SklAdaBoostClassificationLearner(
base_estimator=tree_estimator, n_estimators=5)
estimator=tree_estimator, n_estimators=5)
cv = CrossValidation(k=4)
results = cv(self.iris, [stump, tree])
ca = CA(results)
Expand Down Expand Up @@ -68,12 +73,12 @@ def test_adaboost_reg(self):
results = cv(self.housing, [learn])
_ = RMSE(results)

def test_adaboost_reg_base_estimator(self):
def test_adaboost_reg_estimator(self):
np.random.seed(0)
stump_estimator = SklTreeRegressionLearner(max_depth=1)
tree_estimator = SklTreeRegressionLearner()
stump = SklAdaBoostRegressionLearner(base_estimator=stump_estimator)
tree = SklAdaBoostRegressionLearner(base_estimator=tree_estimator)
stump = SklAdaBoostRegressionLearner(estimator=stump_estimator)
tree = SklAdaBoostRegressionLearner(estimator=tree_estimator)
cv = CrossValidation(k=3)
results = cv(self.housing, [stump, tree])
rmse = RMSE(results)
Expand Down Expand Up @@ -103,3 +108,15 @@ def test_predict_numpy_reg(self):
def test_adaboost_adequacy_reg(self):
learner = SklAdaBoostRegressionLearner()
self.assertRaises(ValueError, learner, self.iris)

def test_remove_deprecation(self):
    """Check that the deprecated `base_estimator` keyword still warns.

    Also acts as a tripwire: once Orange reaches version 3.39 the test
    fails on purpose, reminding maintainers to delete the compatibility
    code for `base_estimator`.
    """
    if LooseVersion(Orange.__version__) >= LooseVersion("3.39"):
        self.fail(
            "`base_estimator` was deprecated in "
            "version 3.37. Please remove everything related to it."
        )
    # Both learners accept the deprecated keyword, so each one must be
    # checked; testing the classifier twice leaves the regression
    # learner's deprecation branch unexercised.
    with self.assertWarns(OrangeDeprecationWarning):
        SklAdaBoostClassificationLearner(
            base_estimator=SklTreeLearner(max_depth=1))
    with self.assertWarns(OrangeDeprecationWarning):
        SklAdaBoostRegressionLearner(
            base_estimator=SklTreeRegressionLearner(max_depth=1))
3 changes: 0 additions & 3 deletions Orange/tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from unittest.mock import MagicMock

import numpy as np
from sklearn import __version__ as sklearn_version
from sklearn.utils import check_random_state

from Orange.data import Table, Domain
Expand Down Expand Up @@ -155,8 +154,6 @@ def test_improved_randomized_pca_sparse_data(self):
pca.singular_values_, rpca.singular_values_, decimal=8
)

@unittest.skipIf(sklearn_version.startswith('0.20'),
"https://github.com/scikit-learn/scikit-learn/issues/12234")
def test_incremental_pca(self):
data = self.ionosphere
self.__ipca_test_helper(data, n_com=3, min_xpl_var=0.49)
Expand Down
13 changes: 0 additions & 13 deletions Orange/widgets/evaluate/tests/test_owliftcurve.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# pylint: disable=protected-access,duplicate-code
import copy
import pkg_resources
import unittest
from unittest.mock import Mock

import numpy as np
import sklearn

from AnyQt.QtGui import QFont, QPen

Expand All @@ -23,14 +21,6 @@
from Orange.tests import test_filename


# scikit-learn==1.1.1 does not support read the docs, therefore
# we can not make it a requirement for now. When the minimum required
# version is >=1.1.1, delete these exceptions.
OK_SKLEARN = pkg_resources.parse_version(sklearn.__version__) >= \
pkg_resources.parse_version("1.1.1")
SKIP_REASON = "Only test precision-recall with scikit-learn>=1.1.1"


class TestOWLiftCurve(EvaluateTest):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -304,7 +294,6 @@ def test_cumulative_gains_from_results():
assert_almost_equal(thresholds, [])

@staticmethod
@unittest.skipUnless(OK_SKLEARN, SKIP_REASON)
def test_precision_recall_from_results():
y_true = np.array([1, 0, 1, 0, 0, 1])
y_scores = np.array([0.6, 0.5, 0.9, 0.4, 0.2, 0.4])
Expand All @@ -324,7 +313,6 @@ def test_precision_recall_from_results():
np.array([0.2, 0.4, 0.5, 0.6, 0.9, 1]))

@staticmethod
@unittest.skipUnless(OK_SKLEARN, SKIP_REASON)
def test_precision_recall_from_results_one():
y_true = np.array([1, 0, 1, 0, 0, 1])
y_scores = np.array([0.6, 0.5, 1, 0.4, 0.2, 0.4])
Expand All @@ -344,7 +332,6 @@ def test_precision_recall_from_results_one():
np.array([0.2, 0.4, 0.5, 0.6, 1]))

@staticmethod
@unittest.skipUnless(OK_SKLEARN, SKIP_REASON)
def test_precision_recall_from_results_multiclass():
y_true = np.array([1, 0, 1, 0, 2, 2])
y_scores = np.array([[0.3, 0.3, 0.4],
Expand Down
2 changes: 1 addition & 1 deletion Orange/widgets/model/owadaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def create_learner(self):
if self.base_estimator is None:
return None
return self.LEARNER(
base_estimator=self.base_estimator,
estimator=self.base_estimator,
n_estimators=self.n_estimators,
learning_rate=self.learning_rate,
random_state=self.random_seed,
Expand Down
4 changes: 2 additions & 2 deletions conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ requirements:
- catboost >=1.0.1
- chardet >=3.0.2
- httpx >=0.21
- joblib >=1.0.0
- joblib >=1.1.1
- keyring
- keyrings.alt
- networkx
Expand All @@ -64,7 +64,7 @@ requirements:
- python-louvain >=0.13
- pyyaml
- requests
- scikit-learn >=1.1.0,!=1.2.*,<1.4 # ignoring 1.2.*: scikit-learn/issues/26241
- scikit-learn >=1.3.0
- scipy >=1.9
- serverfiles
- setuptools >=51.0.0
Expand Down
4 changes: 4 additions & 0 deletions i18n/si.jaml
Original file line number Diff line number Diff line change
Expand Up @@ -2010,14 +2010,18 @@ distance/distance.py:
def `compute_distances`:
hamming: false
ensembles/ada_boost.py:
def `base_estimator_deprecation`:
'`base_estimator` is deprecated (to be removed in 3.39): use `estimator` instead.': false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be false. I think you should either provide a translation or leave it null so that someone else can translate it later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's OK, this message is not intended to be seen by the user. It should be in English.

SklAdaBoostClassificationLearner: false
SklAdaBoostRegressionLearner: false
class `SklAdaBoostClassificationLearner`:
def `__init__`:
SAMME.R: false
deprecated: false
class `SklAdaBoostRegressionLearner`:
def `__init__`:
linear: false
deprecated: false
ensembles/stack.py:
StackedLearner: false
StackedClassificationLearner: false
Expand Down
4 changes: 2 additions & 2 deletions requirements-core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ catboost>=1.0.1
chardet>=3.0.2
httpx>=0.21.0
# Multiprocessing abstraction
joblib>=1.0.0
joblib>=1.1.1
keyring
keyrings.alt # for alternative keyring implementations
networkx
Expand All @@ -17,7 +17,7 @@ pip>=18.0
python-louvain>=0.13
pyyaml
requests
scikit-learn>=1.1.0,!=1.2.*,<1.4 # ignoring 1.2.*: scikit-learn/issues/26241
scikit-learn>=1.3.0
scipy>=1.9
serverfiles # for Data Sets synchronization
setuptools>=51.0.0
Expand Down
4 changes: 2 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ deps =
oldest: catboost==1.0.1
oldest: chardet==3.0.2
oldest: httpx==0.21.0
oldest: joblib==1.0.0
oldest: joblib==1.1.1
# oldest: keyring
# oldest: keyrings.alt
# oldest: networkx
Expand All @@ -62,7 +62,7 @@ deps =
oldest: python-louvain==0.13
# oldest: pyyaml
# oldest: requests
oldest: scikit-learn==1.1.0
oldest: scikit-learn==1.3.0
oldest: scipy==1.9
# oldest: serverfiles
oldest: setuptools==51.0.0
Expand Down
Loading