Skip to content

Commit

Permalink
Allow cuML MNMG estimators to be serialized (#5571)
Browse files Browse the repository at this point in the history
This PR:
- Modifies the base MNMG class to allow the distributed model to be serialized
~- Edit the NB and TF-IDF estimators to prevent their model from being serialized as Dask futures~
- Additionally, edit TF-IDF estimator to allow its use in an ML pipeline
- Adds a test for MNMG estimator serialization

Authors:
  - Victor Lafargue (https://github.com/viclafargue)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Simon Adorf (https://github.com/csadorf)

URL: #5571
  • Loading branch information
viclafargue authored Oct 6, 2023
1 parent 0dbb8c1 commit e5d6bc2
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
14 changes: 14 additions & 0 deletions python/cuml/dask/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ def __init__(self, *, client=None, verbose=False, **kwargs):

self.internal_model = None

def __getstate__(self):
internal_model = self._get_internal_model().result()
state = {
"verbose": self.verbose,
"kwargs": self.kwargs,
"datatype": getattr(self, "datatype", None),
"internal_model": internal_model,
}
return state

def __setstate__(self, state):
self._set_internal_model(state.pop("internal_model"))
self.__dict__.update(state)

def get_combined_model(self):
"""
Return single-GPU model for serialization
Expand Down
7 changes: 3 additions & 4 deletions python/cuml/dask/feature_extraction/text/tfidf_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _set_idf_diag(model):
return model

@with_cupy_rmm
def fit(self, X):
def fit(self, X, y=None):

"""
Fit distributed TFIDF Transformer
Expand All @@ -135,7 +135,6 @@ def fit(self, X):
cuml.dask.feature_extraction.text.TfidfTransformer instance
"""

# Only Dask.Array supported for now
if not isinstance(X, dask.array.core.Array):
raise ValueError("Only dask.Array is supported for X")
Expand Down Expand Up @@ -179,7 +178,7 @@ def _get_part(parts, idx):
def _get_size(arrs):
return arrs.shape[0]

def fit_transform(self, X):
def fit_transform(self, X, y=None):
"""
Fit distributed TFIDFTransformer and then transform
the given set of data samples.
Expand All @@ -197,7 +196,7 @@ def fit_transform(self, X):
"""
return self.fit(X).transform(X)

def transform(self, X):
def transform(self, X, y=None):
"""
Use distributed TFIDFTransformer to transform the
given set of data samples.
Expand Down
18 changes: 18 additions & 0 deletions python/cuml/tests/dask/test_dask_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from distributed.protocol.serialize import serialize
from cuml.naive_bayes.naive_bayes import MultinomialNB
from cuml.internals.array_sparse import SparseCumlArray
from cuml.dask.linear_model import LinearRegression
from cuml.internals.safe_imports import gpu_only_import
from dask import array as da
from sklearn.datasets import make_regression
import numpy as np
import pickle

cp = gpu_only_import("cupy")
cupyx = gpu_only_import("cupyx")
Expand Down Expand Up @@ -62,3 +67,16 @@ def test_sparse_cumlarray_serialization():
stype, sbytes = serialize(X_m, serializers=["dask"])

assert stype["serializer"] == "dask"


def test_serialize_mnmg_model(client):
    """Round-trip a fitted MNMG estimator through pickle and verify that
    the restored model carries the same fitted coefficients."""
    X, y = make_regression(n_samples=1000, n_features=20, random_state=0)
    X, y = da.from_array(X), da.from_array(y)

    # `client` is keyword-only on the MNMG base class
    # (BaseEstimator.__init__(self, *, client=None, ...)), so it must be
    # passed by name — a positional call raises TypeError.
    model = LinearRegression(client=client)
    model.fit(X, y)

    pickled_model = pickle.dumps(model)
    unpickled_model = pickle.loads(pickled_model)

    assert np.allclose(unpickled_model.coef_, model.coef_)

0 comments on commit e5d6bc2

Please sign in to comment.