Skip to content

Commit

Permalink
Merge pull request #6956 from janezd/logistic-no-multiclass
Browse files Browse the repository at this point in the history
LogisticRegressionLearner: Remove deprecated argument 'multi_class'
  • Loading branch information
janezd authored Jan 10, 2025
2 parents 4912a63 + 6a7efa1 commit 876acb2
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 16 deletions.
19 changes: 17 additions & 2 deletions Orange/classification/logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import warnings

import numpy as np
import sklearn.linear_model as skl_linear_model

from Orange.classification import SklLearner, SklModel
from Orange.preprocess import Normalize
from Orange.preprocess.score import LearnerScorer
from Orange.data import Variable, DiscreteVariable
from Orange.util import OrangeDeprecationWarning


__all__ = ["LogisticRegressionLearner"]

Expand Down Expand Up @@ -38,18 +42,29 @@ class LogisticRegressionLearner(SklLearner, _FeatureScorerMixin):
def __init__(self, penalty="l2", dual=False, tol=0.0001, C=1.0,
fit_intercept=True, intercept_scaling=1, class_weight=None,
random_state=None, solver="auto", max_iter=100,
multi_class="auto", verbose=0, n_jobs=1, preprocessors=None):
multi_class="deprecated", verbose=0, n_jobs=1, preprocessors=None):
if multi_class != "deprecated":
warnings.warn("The multi_class parameter was "
"deprecated in scikit-learn 1.5. Using it with "
"scikit-learn 1.7 will lead to a crash.",
OrangeDeprecationWarning,
stacklevel=2)
super().__init__(preprocessors=preprocessors)
self.params = vars()

def _initialize_wrapped(self):
params = self.params.copy()

multi_class = params.pop("multi_class")
if multi_class != "deprecated":
params["multi_class"] = multi_class

# The default scikit-learn solver `lbfgs` (v0.22) does not support the
# l1 penalty.
solver, penalty = params.pop("solver"), params.get("penalty")
if solver == "auto":
if penalty == "l1":
solver = "liblinear"
solver = "saga"
else:
solver = "lbfgs"
params["solver"] = solver
Expand Down
4 changes: 2 additions & 2 deletions Orange/evaluation/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
--------
>>> import Orange
>>> data = Orange.data.Table('iris')
>>> learner = Orange.classification.LogisticRegressionLearner(solver="liblinear")
>>> learner = Orange.classification.LogisticRegressionLearner()
>>> results = Orange.evaluation.TestOnTrainingData(data, [learner])
"""
Expand Down Expand Up @@ -296,7 +296,7 @@ class LogLoss(ClassificationScore):
Examples
--------
>>> Orange.evaluation.LogLoss(results)
array([0.3...])
array([0.1...])
"""
__wraps__ = skl_metrics.log_loss
Expand Down
12 changes: 11 additions & 1 deletion Orange/tests/test_logistic_regression.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring

from datetime import datetime
import unittest

import numpy as np
Expand All @@ -9,6 +10,7 @@
from Orange.data import Table, ContinuousVariable, Domain
from Orange.classification import LogisticRegressionLearner, Model
from Orange.evaluation import CrossValidation, CA
from Orange.util import OrangeDeprecationWarning


class TestLogisticRegressionLearner(unittest.TestCase):
Expand Down Expand Up @@ -149,8 +151,16 @@ def test_auto_solver(self):
# liblinear is default for l2 penalty
lr = LogisticRegressionLearner(penalty="l1", solver="auto")
skl_clf = lr._initialize_wrapped()
self.assertEqual(skl_clf.solver, "liblinear")
self.assertEqual(skl_clf.solver, "saga")
self.assertEqual(skl_clf.penalty, "l1")

def test_supports_weights(self):
self.assertTrue(LogisticRegressionLearner().supports_weights)

def test_multi_class_deprecation(self):
with self.assertWarns(OrangeDeprecationWarning):
LogisticRegressionLearner(penalty="l1", multi_class="multinomial")
now = datetime.now()
if (now.year, now.month) >= (2026, 1):
raise Exception("If Orange depends on scikit-learn >= 1.7, remove this test "
"and any mention of multi_class in LogisticRegressionLearner.")
4 changes: 2 additions & 2 deletions Orange/widgets/evaluate/tests/test_owpredictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ def test_output_error_cls(self):
log_reg = LogisticRegressionLearner()
self.send_signal(self.widget.Inputs.predictors, log_reg(data), 0)
self.send_signal(self.widget.Inputs.predictors,
LogisticRegressionLearner(penalty="l1")(data), 1)
LogisticRegressionLearner(penalty="l1", max_iter=1000)(data), 1)
with data.unlocked(data.Y):
data.Y[1] = np.nan
self.send_signal(self.widget.Inputs.data, data)
Expand All @@ -1316,7 +1316,7 @@ def test_output_error_cls(self):
names = [f"{log_reg.name}{x}" for x in names]
self.assertEqual(names, [m.name for m in pred.domain.metas])
self.assertAlmostEqual(pred.metas[0, 4], 0.018, 3)
self.assertAlmostEqual(pred.metas[0, 9], 0.113, 3)
self.assertAlmostEqual(pred.metas[0, 9], 0.008, 3)
self.assertTrue(np.isnan(pred.metas[1, 4]))
self.assertTrue(np.isnan(pred.metas[1, 9]))

Expand Down
6 changes: 2 additions & 4 deletions Orange/widgets/visualize/tests/test_ownomogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,8 @@ def test_nomogram_nb_multiclass(self):
def test_nomogram_lr_multiclass(self):
"""Check probabilities for logistic regression classifier for various
values of classes and radio buttons for multiclass data"""
cls = LogisticRegressionLearner(
multi_class="ovr", solver="liblinear"
)(self.lenses)
self._test_helper(cls, [9, 45, 52])
cls = LogisticRegressionLearner(max_iter=100)(self.lenses)
self._test_helper(cls, [18, 56, 78])

def test_nomogram_with_instance_nb(self):
"""Check initialized marker values and feature sorting for naive bayes
Expand Down
12 changes: 7 additions & 5 deletions i18n/si/msgs.jaml
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,18 @@ classification/logistic_regression.py:
def `__init__`:
l2: false
auto: false
deprecated: false
'The multi_class parameter was ': false
'deprecated in scikit-learn 1.5. Using it with ': false
scikit-learn 1.7 will lead to a crash.: false
def `_initialize_wrapped`:
multi_class: false
deprecated: false
solver: false
penalty: false
auto: false
l1: false
liblinear: false
saga: false
lbfgs: false
classification/majority.py:
MajorityLearner: false
Expand Down Expand Up @@ -3096,9 +3102,6 @@ projection/pca.py:
auto: false
def `fit`:
n_components: false
svd_solver: false
auto: false
arpack: false
class `SparsePCA`:
Sparse PCA: false
def `__init__`:
Expand Down Expand Up @@ -15437,7 +15440,6 @@ widgets/visualize/utils/error_bars_dialog.py:
Upper:: Zgornje:
Lower:: Spodnje:
__main__: false
Error Bars: false
Open: false
iris: false
widgets/visualize/utils/heatmap.py:
Expand Down

0 comments on commit 876acb2

Please sign in to comment.