Skip to content

Commit

Permalink
FEA Add first version of to and from_sklearn APIs to estimators that …
Browse files Browse the repository at this point in the history
…support it
  • Loading branch information
dantegd committed Oct 8, 2024
1 parent 61f85a6 commit 8473259
Showing 1 changed file with 89 additions and 2 deletions.
91 changes: 89 additions & 2 deletions python/cuml/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -722,4 +722,91 @@ class UniversalBase(Base):
return self

# return function result
return res
return

@staticmethod
def _get_serializer(protocol: str) -> Any:
"""
Get the appropriate serializer based on the specified protocol.
"""
if protocol == "pickle":
import pickle as serializer
elif protocol == "joblib":
import joblib as serializer
else:
raise TypeError(f"Protocol {protocol} not supported.")
return serializer

def to_sklearn(self,
protocol: str = "pickle",
filename: Optional[str] = None) -> None:
"""
Serialize the estimator to a Scikit-learn compatible file using the
specified protocol.

Parameters
----------
protocol : str, optional
The serialization protocol to use. Defaults to 'pickle'.
filename : str, optional
The name of the file where the model will be saved. If not provided, it defaults
to the class name with '_sklearn' appended.

Raises
------
AttributeError
If the model does not have a `_cpu_model` attribute.
TypeError
If the protocol is not supported.

"""
if filename is None:
filename = self.__class__.__name__ + "_sklearn"

serializer = self._get_serializer(protocol)

if not hasattr(self, '_cpu_model'):
self.import_cpu_model()
self.build_cpu_model()
self.gpu_to_cpu()

with open(filename, "wb") as f:
serializer.dump(self._cpu_model, f)

@classmethod
def from_sklearn(cls,
filename: str,
protocol: str = "pickle") -> 'Model':
"""
Create a cuML estimator from a pickle or joblib serialized
Scikit-learn model.

Parameters
----------
filename : str
The name of the file from which to load the model.
protocol : str, optional
The serialization protocol to use. Defaults to 'pickle'.

Returns
-------
Model
An instance of the class with the loaded model.

Raises
------
AttributeError
If the model does not have a `_cpu_model` attribute.
TypeError
If the protocol is not supported.
"""
estimator = cls()
serializer = cls._get_serializer(protocol)

with open(filename, "rb") as f:
state = serializer.load(f)

estimator._cpu_model = cls._cpu_model_class()
estimator._cpu_model.__dict__.update(state)
estimator.cpu_to_gpu()
return estimator

0 comments on commit 8473259

Please sign in to comment.