Add robust metric #122

Open: wants to merge 16 commits into main
6 changes: 6 additions & 0 deletions doc/api.rst
@@ -44,3 +44,9 @@ Robust
   robust.RobustWeightedClassifier
   robust.RobustWeightedRegressor
   robust.RobustWeightedKMeans

.. autosummary::
   :toctree: generated/
   :template: function.rst

   robust.make_huber_metric
2 changes: 2 additions & 0 deletions doc/changelog.rst
@@ -4,6 +4,8 @@ Changelog
Unreleased
----------

- Add `make_huber_metric`, which transforms a non-robust metric into a robust
  metric using the Huber estimator.
- Add a stopping criterion and parameter tuning heuristic for Huber robust mean
  estimator.
- Add `CLARA` (Clustering for Large Applications) which extends k-medoids to
43 changes: 43 additions & 0 deletions doc/modules/robust.rst
@@ -142,6 +142,45 @@ This algorithm has been studied in the context of "mom" weights in the
article [1]_, while the context of "huber" weights has been mentioned in [2]_.
Both weighting schemes can be seen as special cases of the algorithm in [3]_.


.. _make_huber_metric:

Robust model selection
----------------------

One of the big challenges of robust machine learning is that the usual scoring
schemes (cross-validation with mean squared error, for instance) are not robust.
Indeed, if the dataset has some outliers, then the test sets in cross-validation
may also contain outliers, and the cross-validation MSE would report a huge
error for our robust algorithm on any corrupted data.

To solve this problem, one can use a robust scoring method during
cross-validation, constructed with `make_huber_metric`. See the following
example:

`An example of a robust cross-validation evaluation in regression <../auto_examples/robust/plot_robust_cv_example.html>`_

This type of robust cross-validation was mentioned for instance in [4]_.
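
In practice, the metric returned by `make_huber_metric` can be wrapped with
scikit-learn's `make_scorer` and passed to `cross_val_score`. A minimal sketch,
mirroring the example above (the data and the choice `c=9` are only
illustrative):

.. code-block:: python

    import numpy as np
    from sklearn.linear_model import HuberRegressor
    from sklearn.metrics import make_scorer, mean_squared_error
    from sklearn.model_selection import cross_val_score
    from sklearn_extra.robust import make_huber_metric

    rng = np.random.RandomState(42)
    X = rng.uniform(size=100)[:, np.newaxis]
    y = 3 * X.ravel()
    y[[42 // 2, 42, 42 * 2]] = 200  # corrupt three targets with outliers

    robust_mse = make_huber_metric(mean_squared_error, c=9)
    scores = cross_val_score(
        HuberRegressor(), X, y, scoring=make_scorer(robust_mse)
    )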


Here is what `make_huber_metric` computes. Suppose that we would usually
compute a loss function as follows:

.. math::

    \widehat L = \frac{1}{n}\sum_{i=1}^n \ell(Y_i, f(X_i))

`make_huber_metric` proposes to replace this computation with

.. math::

    \widehat L_{rob} = \widehat{\mathrm{Hub}}\left(\ell(Y_i, f(X_i))\right)

where :math:`\widehat{\mathrm{Hub}}` is the Huber estimator of location. It is a
robust estimator of the mean (a similar result can also be obtained using the
trimmed mean), and :math:`\widehat{L}_{rob}` is robust in the sense
that an especially large value of :math:`\ell(Y_i, f(X_i))` would not change the
result by much. The constant `c` used when tuning
:math:`\widehat{\mathrm{Hub}}` plays the same robustness-tuning role as in
regression and classification with Huber weights.
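
For illustration, here is a minimal sketch of this computation for a
squared-error loss, using the `huber` location estimator exported by
`sklearn_extra.robust` (the value `c=5` is only an example):

.. code-block:: python

    import numpy as np
    from sklearn_extra.robust import huber

    y_true = np.hstack([np.zeros(98), 20 * np.ones(2)])  # two corrupted values
    y_pred = np.zeros(100)

    losses = (y_true - y_pred) ** 2   # per-sample losses ell(Y_i, f(X_i))
    robust_loss = huber(losses, c=5)  # Huber location estimate, not the mean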

Comparison with other robust estimators
---------------------------------------

@@ -203,3 +242,7 @@ the example with California housing real dataset, for further discussion.
.. [3] Stanislav Minsker and Timothée Mathieu.
   `"Excess risk bounds in robust empirical risk minimization" <https://arxiv.org/abs/1910.07485>`_
   arXiv preprint (2019). arXiv:1910.07485.

.. [4] Elvezio Ronchetti, Christopher Field and Wade Blanchard.
   `"Robust Linear Model Selection by Cross-Validation" <https://www.tandfonline.com/doi/abs/10.1080/01621459.1997.10474057>`_
   Journal of the American Statistical Association (1997).
62 changes: 62 additions & 0 deletions examples/robust/plot_robust_cv_example.py
@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
"""
================================================================
An example of a robust cross-validation evaluation in regression
================================================================
In this example we compare `LinearRegression` (OLS) with `HuberRegressor` from
scikit-learn using cross-validation.

We show that a robust cross-validation scheme gives a better
evaluation of the generalisation error in a corrupted dataset.

In this example, we do robust cross-validation by using an alternative to the
empirical mean to aggregate the errors. This alternative is a robust estimator
of the mean (the trimmed mean is an example of such a robust estimator, but here
we use Huber's estimator). This robust estimator of the mean is used on each
fold of the cross-validation, and then we return the empirical mean of the
resulting robust scores as the final score.
"""
print(__doc__)

import numpy as np
from sklearn.metrics import mean_squared_error, make_scorer
from sklearn.model_selection import cross_val_score
from sklearn_extra.robust import make_huber_metric
from sklearn.linear_model import LinearRegression, HuberRegressor

robust_mse = make_huber_metric(mean_squared_error, c=9)
rng = np.random.RandomState(42)

X = rng.uniform(size=100)[:, np.newaxis]
y = 3 * X.ravel()
# Note: the clean targets satisfy y <= 3.

y[[42 // 2, 42, 42 * 2]] = 200  # corrupt three targets with outliers

print("Non robust error:")
for reg in [LinearRegression(), HuberRegressor()]:
print(
reg,
" mse : %.2F"
% (
np.mean(
cross_val_score(
reg, X, y, scoring=make_scorer(mean_squared_error)
)
)
),
)


print("\n")
print("Robust error:")
for reg in [LinearRegression(), HuberRegressor()]:
print(
reg,
" mse : %.2F"
% (
np.mean(
cross_val_score(reg, X, y, scoring=make_scorer(robust_mse))
)
),
)
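
# With the corrupted targets above, the non-robust cross-validation MSE is
# expected to be large for both estimators (the outliers dominate the average
# of the fold errors), while the robust score should stay small for
# HuberRegressor, whose fit is not driven by the three outliers.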
3 changes: 3 additions & 0 deletions sklearn_extra/robust/__init__.py
@@ -3,9 +3,12 @@
    RobustWeightedKMeans,
    RobustWeightedRegressor,
)
from sklearn_extra.robust.mean_estimators import huber, make_huber_metric

__all__ = [
    "RobustWeightedClassifier",
    "RobustWeightedKMeans",
    "RobustWeightedRegressor",
    "huber",
    "make_huber_metric",
]
75 changes: 72 additions & 3 deletions sklearn_extra/robust/mean_estimators.py
@@ -4,6 +4,8 @@
# License: BSD 3 clause

import numpy as np
from scipy.stats import iqr
from sklearn.metrics import mean_squared_error


def block_mom(X, k, random_state):
@@ -88,7 +90,7 @@ def median_of_means(X, k, random_state=np.random.RandomState(42)):
    return median_of_means_blocked(x, blocks)[0]


def huber(X, c=None, T=20, tol=1e-3):
def huber(X, c=None, n_iter=20, tol=1e-3):
"""Compute the Huber estimator of location of X with parameter c

    Parameters
@@ -104,7 +106,7 @@ def huber(X, c=None, T=20, tol=1e-3):
        if c is None, the interquartile range (IQR) is used
        as heuristic.

    T : int, default = 20
    n_iter : int, default = 20
        Number of iterations of the algorithm.

    tol : float, default=1e-3
@@ -138,7 +140,7 @@ def psisx(x, c):
    last_mu = mu

    # Run the iterative reweighting algorithm to compute M-estimator.
    for t in range(T):
    for t in range(n_iter):
        # Compute the weights
        w = psisx(x - mu, c_numeric)

@@ -156,3 +158,70 @@ def psisx(x, c):
        last_mu = mu

    return mu


def make_huber_metric(
    score_func=mean_squared_error, sample_weight=None, c=None, n_iter=20
):
    """
    Make a robust metric using the Huber estimator.

    Read more in the :ref:`User Guide <make_huber_metric>`.

    Parameters
    ----------
    score_func : callable
        Score function (or loss function) with signature
        ``score_func(y, y_pred, **kwargs)``.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    c : float > 0, default=None
        Parameter that controls the robustness of the estimator.
        As c goes to zero, the behavior is close to that of the median;
        as c goes to infinity, the behavior is close to that of the sample
        mean. If c is None, the IQR (interquartile range) is used as a
        heuristic.

    n_iter : int, default=20
        Number of iterations of the algorithm.

    Returns
    -------
    metric : callable
        Robust metric function with signature ``metric(y_true, y_pred)``.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import mean_squared_error
    >>> from sklearn_extra.robust import make_huber_metric
    >>> robust_mse = make_huber_metric(mean_squared_error, c=5)
    >>> y_true = np.hstack([np.zeros(98), 20*np.ones(2)])  # corrupted test values
    >>> np.random.shuffle(y_true)  # shuffle them
    >>> y_pred = np.zeros(100)  # predicted values
    >>> result = robust_mse(y_true, y_pred)
    """

    def metric(y_true, y_pred):
        # Pass y_true and y_pred as a single multioutput sample so that
        # score_func returns one value per sample via multioutput="raw_values".
        y1 = [y_true]
        y2 = [y_pred]
        values = score_func(
            y1, y2, sample_weight=sample_weight, multioutput="raw_values"
        )
        if c is None:
            c_ = iqr(values)
        else:
            c_ = c
        if c_ == 0:
            # Degenerate scale: fall back to the median of the values.
            return np.median(values)
        else:
            return huber(values, c_, n_iter)

    return metric
36 changes: 34 additions & 2 deletions sklearn_extra/robust/tests/test_mean_estimators.py
@@ -1,8 +1,14 @@
import numpy as np
import pytest

from sklearn_extra.robust.mean_estimators import median_of_means, huber

from sklearn_extra.robust.mean_estimators import (
    median_of_means,
    huber,
    make_huber_metric,
)
from sklearn.metrics import mean_squared_error, make_scorer
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import HuberRegressor

rng = np.random.RandomState(42)

@@ -30,3 +36,29 @@ def test_huber():
    mu = huber(X, c=0.5)
    assert len(record) == 0
    assert np.abs(mu) < 0.1


def test_robust_metric():
    robust_mse = make_huber_metric(mean_squared_error, c=5)
    y_true = np.hstack([np.zeros(95), 20 * np.ones(5)])
    np.random.shuffle(y_true)
    y_pred = np.zeros(100)

    assert robust_mse(y_true, y_pred) < 1


def test_check_robust_cv():
    robust_mse = make_huber_metric(mean_squared_error, c=9)
    rng = np.random.RandomState(42)

    X = rng.uniform(size=100)[:, np.newaxis]
    y = 3 * X.ravel()

    y[[42 // 2, 42, 42 * 2]] = 200  # outliers

    huber_reg = HuberRegressor()
    error_huber_reg = np.mean(
        cross_val_score(huber_reg, X, y, scoring=make_scorer(robust_mse))
    )
    assert error_huber_reg < 1