Support init arguments in MNMG LogisticRegression (#5519)
The init arguments are for L-BFGS (the only algorithm in the current MNMG LogisticRegression).

Once [PR 5516 for predict](#5516) is merged, the key code changes amount to only a few lines. They can be reviewed [here](https://github.com/rapidsai/cuml/pull/5519/files/d058d884c992661984224d0190c3bbcc0a23caf4..fbbaa5c6aef47ddc7100f5bea2a751851ca6d1b4).
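In practice, the MNMG estimator now accepts the same constructor arguments as the single-GPU version. A minimal sketch of what this enables (the argument values are illustrative only; the cluster setup mirrors the docstring example added below):

```python
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from cuml.dask.linear_model import LogisticRegression

cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="0,1")
client = Client(cluster)

# All of these init arguments are forwarded to the per-worker
# L-BFGS solver (values here are illustrative, not defaults).
model = LogisticRegression(
    tol=1e-6,              # stopping tolerance for L-BFGS
    C=1.5,                 # inverse of regularization strength
    fit_intercept=False,   # skip fitting the bias term
    max_iter=200,          # cap on outer L-BFGS iterations
    linesearch_max_iter=100,
    verbose=True,
    output_type="cudf",
)
# model.fit(X_ddf, y_ddf) as in the docstring example below
```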

Authors:
  - Jinfeng Li (https://github.com/lijinf2)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #5519
lijinf2 authored Aug 1, 2023
1 parent 5a3309d commit 6fb5bf9
Showing 4 changed files with 190 additions and 9 deletions.
110 changes: 107 additions & 3 deletions python/cuml/dask/linear_model/logistic_regression.py
@@ -33,8 +33,112 @@


class LogisticRegression(LinearRegression):
    def __init__(self, *, client=None, verbose=False, **kwargs):
        super().__init__(client=client, verbose=verbose, **kwargs)
"""
LogisticRegression is a linear model used to model the probability of
occurrence of certain events, for example the probability of success or
failure of an event.
cuML's dask Logistic Regression (multi-node multi-GPU) expects dask cuDF
DataFrames and provides an algorithm, L-BFGS, to fit the logistic model. It
currently supports single class, l2 regularization, and sigmoid loss.
Note that, just like in Scikit-learn, the bias will not be regularized.
Examples
--------
.. code-block:: python
>>> from dask_cuda import LocalCUDACluster
>>> from dask.distributed import Client
>>> import dask_cudf
>>> import cudf
>>> import numpy as np
>>> from cuml.dask.linear_model import LogisticRegression
>>> cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES="0,1")
>>> client = Client(cluster)
>>> X = cudf.DataFrame()
>>> X['col1'] = np.array([1,1,2,2], dtype = np.float32)
>>> X['col2'] = np.array([1,2,2,3], dtype = np.float32)
>>> y = cudf.Series(np.array([0.0, 0.0, 1.0, 1.0], dtype=np.float32))
>>> X_ddf = dask_cudf.from_cudf(X, npartitions=2)
>>> y_ddf = dask_cudf.from_cudf(y, npartitions=2)
>>> reg = LogisticRegression()
>>> reg.fit(X_ddf, y_ddf)
LogisticRegression()
>>> print(reg.coef_)
0 1
0 0.69861 0.570058
>>> print(reg.intercept_)
0 -2.188068
dtype: float32
>>> X_new = cudf.DataFrame()
>>> X_new['col1'] = np.array([1,5], dtype = np.float32)
>>> X_new['col2'] = np.array([2,5], dtype = np.float32)
>>> X_new_ddf = dask_cudf.from_cudf(X_new, npartitions=2)
>>> preds = reg.predict(X_new_ddf)
>>> print(preds.compute())
0 0.0
1 1.0
dtype: float32
Parameters
----------
tol : float (default = 1e-4)
Tolerance for stopping criteria.
The exact stopping conditions depend on the L-BFGS solver.
Check the solver's documentation for more details:
* :class:`Quasi-Newton (L-BFGS)<cuml.QN>`
C : float (default = 1.0)
Inverse of regularization strength; must be a positive float.
fit_intercept : boolean (default = True)
If True, the model tries to correct for the global mean of y.
If False, the model expects that you have centered the data.
max_iter : int (default = 1000)
Maximum number of iterations taken for the solvers to converge.
linesearch_max_iter : int (default = 50)
Max number of linesearch iterations per outer iteration used in the
solver.
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`.
See :ref:`verbosity-levels` for more info.
output_type : {'input', 'array', 'dataframe', 'series', 'df_obj', \
'numba', 'cupy', 'numpy', 'cudf', 'pandas'}, default=None
Return results and set estimator attributes to the indicated output
type. If None, the output type set at the module level
(`cuml.global_settings.output_type`) will be used. See
:ref:`output-data-type-configuration` for more info.
Attributes
----------
coef_: dev array, dim (n_classes, n_features) or (n_classes, n_features+1)
The estimated coefficients for the logistic regression model.
intercept_: device array (n_classes, 1)
The independent term. If `fit_intercept` is False, will be 0.
Notes
-----
cuML's LogisticRegression uses a different solver than the equivalent
Scikit-learn, except when there is no penalty and `solver=lbfgs` is
used in Scikit-learn. This can cause (smaller) differences in the
coefficients and predictions of the model, similar to
using different solvers in Scikit-learn.
For additional information, see `Scikit-learn's LogisticRegression
<https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html>`_.
"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def fit(self, X, y):
        """
@@ -64,7 +168,7 @@ def _create_model(sessionId, datatype, **kwargs):
)

    handle = get_raft_comm_state(sessionId, get_worker())["handle"]
    return LogisticRegressionMG(handle=handle)
    return LogisticRegressionMG(handle=handle, **kwargs)

    @staticmethod
    def _func_fit(f, data, n_rows, n_cols, partsToSizes, rank):
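The essential change in this file is the one-line forwarding of `**kwargs` in `_create_model`. A simplified, self-contained sketch of the pattern, with hypothetical class names standing in for the cuml classes:

```python
# Hypothetical stand-ins for LogisticRegressionMG and the dask wrapper:
# init kwargs accepted at the dask level are passed through unchanged to
# the per-worker MG estimator.
class WorkerModel:
    def __init__(self, handle=None, **kwargs):
        self.handle = handle
        self.params = kwargs  # e.g. {"tol": 1e-6, "C": 1.5, ...}


class DaskWrapper:
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def _create_model(self, handle):
        # The one-line change above: forward the stored kwargs instead of
        # constructing the worker model with only the handle.
        return WorkerModel(handle=handle, **self.kwargs)


model = DaskWrapper(tol=1e-6, C=1.5)._create_model(handle=None)
assert model.params == {"tol": 1e-6, "C": 1.5}
```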
2 changes: 1 addition & 1 deletion python/cuml/linear_model/logistic_regression.pyx
@@ -170,7 +170,7 @@ class LogisticRegression(UniversalBase,
Attributes
----------
coef_: dev array, dim (n_classes, n_features) or (n_classes, n_features+1)
The estimated coefficients for the linear regression model.
The estimated coefficients for the logistic regression model.
intercept_: device array (n_classes, 1)
The independent term. If `fit_intercept` is False, will be 0.
4 changes: 2 additions & 2 deletions python/cuml/linear_model/logistic_regression_mg.pyx
@@ -82,8 +82,8 @@ cdef extern from "cuml/linear_model/qn_mg.hpp" namespace "ML::GLM::opg" nogil:

class LogisticRegressionMG(MGFitMixin, LogisticRegression):

    def __init__(self, *, handle=None):
        super().__init__(handle=handle)
    def __init__(self, **kwargs):
        super(LogisticRegressionMG, self).__init__(**kwargs)

    @property
    @cuml.internals.api_base_return_array_skipall
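The same idea motivates the signature change here: replacing the explicit `handle=None` parameter with `**kwargs` lets every argument accepted by the base estimator flow through the MG subclass without per-argument plumbing. A toy sketch (hypothetical classes, not the cuml hierarchy):

```python
class Base:
    def __init__(self, handle=None, tol=1e-4, C=1.0):
        self.handle, self.tol, self.C = handle, tol, C


class MG(Base):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # no argument list to keep in sync


m = MG(handle="h", tol=1e-6, C=1.5)
assert (m.tol, m.C) == (1e-6, 1.5)
```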
83 changes: 80 additions & 3 deletions python/cuml/tests/dask/test_dask_logistic_regression.py
@@ -177,13 +177,90 @@ def imp():
    assert_array_equal(preds, y, strict=True)


def test_lbfgs_init(client):
    def imp():
        import cuml.comm.serialize  # NOQA

    client.run(imp)

    X = np.array([(1, 2), (1, 3), (2, 1), (3, 1)], dtype=np.float32)
    y = np.array([1.0, 1.0, 0.0, 0.0], dtype=np.float32)

    X_df, y_df = _prep_training_data(
        c=client, X_train=X, y_train=y, partitions_per_worker=2
    )

    from cuml.dask.linear_model.logistic_regression import (
        LogisticRegression as cumlLBFGS_dask,
    )

    def assert_params(
        tol,
        C,
        fit_intercept,
        max_iter,
        linesearch_max_iter,
        verbose,
        output_type,
    ):

        lr = cumlLBFGS_dask(
            tol=tol,
            C=C,
            fit_intercept=fit_intercept,
            max_iter=max_iter,
            linesearch_max_iter=linesearch_max_iter,
            verbose=verbose,
            output_type=output_type,
        )

        lr.fit(X_df, y_df)
        qnpams = lr.qnparams.params
        assert qnpams["grad_tol"] == tol
        assert qnpams["loss"] == 0  # "sigmoid" loss
        assert qnpams["penalty_l1"] == 0.0
        assert qnpams["penalty_l2"] == 1.0 / C
        assert qnpams["fit_intercept"] == fit_intercept
        assert qnpams["max_iter"] == max_iter
        assert qnpams["linesearch_max_iter"] == linesearch_max_iter
        assert qnpams["verbose"] == (
            5 if verbose is True else 4
        )  # cuml Verbosity Levels
        assert lr.output_type == (
            "input" if output_type is None else output_type
        )  # cuml.global_settings.output_type

    assert_params(
        tol=1e-4,
        C=1.0,
        fit_intercept=True,
        max_iter=1000,
        linesearch_max_iter=50,
        verbose=False,
        output_type=None,
    )

    assert_params(
        tol=1e-6,
        C=1.5,
        fit_intercept=False,
        max_iter=200,
        linesearch_max_iter=100,
        verbose=True,
        output_type="cudf",
    )


@pytest.mark.mg
@pytest.mark.parametrize("nrows", [1e5])
@pytest.mark.parametrize("ncols", [20])
@pytest.mark.parametrize("n_parts", [2, 23])
@pytest.mark.parametrize("fit_intercept", [False, True])
@pytest.mark.parametrize("datatype", [np.float32])
@pytest.mark.parametrize("delayed", [True, False])
def test_lbfgs(nrows, ncols, n_parts, datatype, delayed, client):
def test_lbfgs(
    nrows, ncols, n_parts, fit_intercept, datatype, delayed, client
):
    tolerance = 0.005

    def imp():
@@ -203,12 +280,12 @@ def imp():

    X_df, y_df = _prep_training_data(client, X, y, n_parts)

    lr = cumlLBFGS_dask()
    lr = cumlLBFGS_dask(fit_intercept=fit_intercept)
    lr.fit(X_df, y_df)
    lr_coef = lr.coef_.to_numpy()
    lr_intercept = lr.intercept_.to_numpy()

    sk_model = skLR()
    sk_model = skLR(fit_intercept=fit_intercept)
    sk_model.fit(X, y)
    sk_coef = sk_model.coef_
    sk_intercept = sk_model.intercept_
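For reference, the constructor-argument-to-solver-parameter mapping that `assert_params` checks, restated as a plain function (a sketch derived from the assertions above; `loss == 0` denotes sigmoid loss):

```python
def expected_qn_params(tol, C, fit_intercept, max_iter, linesearch_max_iter):
    return {
        "grad_tol": tol,
        "loss": 0,              # sigmoid loss
        "penalty_l1": 0.0,      # no l1 term in the current MNMG model
        "penalty_l2": 1.0 / C,  # C is the inverse regularization strength
        "fit_intercept": fit_intercept,
        "max_iter": max_iter,
        "linesearch_max_iter": linesearch_max_iter,
    }


print(expected_qn_params(1e-6, 1.5, False, 200, 100)["penalty_l2"])  # ~0.667
```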
