diff --git a/python/cuml/cuml/internals/base.pyx b/python/cuml/cuml/internals/base.pyx
index c00ed17f98..6df868459a 100644
--- a/python/cuml/cuml/internals/base.pyx
+++ b/python/cuml/cuml/internals/base.pyx
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -469,6 +469,55 @@ class Base(TagsMixin,
                     func = nvtx_annotate(message=msg, domain="cuml_python")(func)
                     setattr(self, func_name, func)
 
+    def to_sklearn(self,
+                   protocol: str = "pickle",
+                   filename: Optional[str] = None) -> None:
+        """
+        Export the estimator as a serialized Scikit-learn model.
+
+        The base implementation always raises; estimators that support
+        the export (see ``UniversalBase``) override this method.
+
+        Parameters
+        ----------
+        protocol : str, optional
+            The serialization protocol to use. Defaults to 'pickle'.
+        filename : str, optional
+            The name of the file where the model will be saved.
+
+        Raises
+        ------
+        NotImplementedError
+            Always, for estimators that do not override this method.
+        """
+        raise NotImplementedError("Estimator does not support exporting to "
+                                  "Scikit-learn yet.")
+
+    @classmethod
+    def from_sklearn(cls,
+                     filename: str,
+                     protocol: str = "pickle") -> 'Base':
+        """
+        Create an estimator from a serialized Scikit-learn model.
+
+        The base implementation always raises; estimators that support
+        the import (see ``UniversalBase``) override this method.
+
+        Parameters
+        ----------
+        filename : str
+            The name of the file from which to load the model.
+        protocol : str, optional
+            The serialization protocol to use. Defaults to 'pickle'.
+
+        Raises
+        ------
+        NotImplementedError
+            Always, for estimators that do not override this method.
+        """
+        raise NotImplementedError("Estimator does not support importing from "
+                                  "Scikit-learn yet.")
+
 
 # Internal, non class owned helper functions
 def _check_output_type_str(output_str):
@@ -723,3 +772,123 @@ class UniversalBase(Base):
 
             # return function result
             return res
+
+    @staticmethod
+    def _get_serializer(protocol: str) -> Any:
+        """
+        Get the appropriate serializer based on the specified protocol.
+
+        Parameters
+        ----------
+        protocol : str
+            Either 'pickle' or 'joblib'.
+
+        Returns
+        -------
+        module
+            The imported module, which exposes ``dump`` and ``load``.
+
+        Raises
+        ------
+        TypeError
+            If the protocol is not supported.
+        """
+        # Import lazily so joblib is only required when it is requested.
+        if protocol == "pickle":
+            import pickle as serializer
+        elif protocol == "joblib":
+            import joblib as serializer
+        else:
+            raise TypeError(f"Protocol {protocol} not supported.")
+        return serializer
+
+    def to_sklearn(self,
+                   protocol: str = "pickle",
+                   filename: Optional[str] = None) -> None:
+        """
+        Serialize the estimator to a Scikit-learn compatible file using the
+        specified protocol.
+
+        Parameters
+        ----------
+        protocol : str, optional
+            The serialization protocol to use, either 'pickle' or 'joblib'.
+            Defaults to 'pickle'.
+        filename : str, optional
+            The name of the file where the model will be saved. If not
+            provided, it defaults to the class name with '_sklearn'
+            appended.
+
+        Raises
+        ------
+        TypeError
+            If the protocol is not supported.
+        """
+        if filename is None:
+            filename = self.__class__.__name__ + "_sklearn"
+
+        # Resolve the serializer first so that an unsupported protocol
+        # fails before any model conversion or file creation happens.
+        serializer = self._get_serializer(protocol)
+
+        # A missing CPU model is built here rather than treated as an error.
+        if not hasattr(self, '_cpu_model'):
+            self.import_cpu_model()
+            self.build_cpu_model()
+        # Sync on every export so that a CPU model created earlier is not
+        # written out stale.
+        # NOTE(review): assumes gpu_to_cpu copies the fitted state onto
+        # `_cpu_model` - confirm against its definition.
+        self.gpu_to_cpu()
+
+        with open(filename, "wb") as f:
+            serializer.dump(self._cpu_model, f)
+
+    @classmethod
+    def from_sklearn(cls,
+                     filename: str,
+                     protocol: str = "pickle") -> 'UniversalBase':
+        """
+        Create a cuML estimator from a pickle or joblib serialized
+        Scikit-learn model.
+
+        .. warning::
+            Loading a pickle or joblib file can execute arbitrary code.
+            Only load files that come from a trusted source.
+
+        Parameters
+        ----------
+        filename : str
+            The name of the file from which to load the model.
+        protocol : str, optional
+            The serialization protocol to use, either 'pickle' or 'joblib'.
+            Defaults to 'pickle'.
+
+        Returns
+        -------
+        UniversalBase
+            An instance of the class with the loaded model.
+
+        Raises
+        ------
+        TypeError
+            If the protocol is not supported.
+        """
+        # Resolve the serializer first so that an unsupported protocol
+        # fails before an estimator is constructed.
+        serializer = cls._get_serializer(protocol)
+        estimator = cls()
+
+        with open(filename, "rb") as f:
+            state = serializer.load(f)
+
+        estimator.import_cpu_model()
+        estimator._cpu_model = state
+        estimator.cpu_to_gpu()
+
+        # we need to set an output type here since
+        # we cannot infer from training args.
+        # Setting to numpy seems like a reasonable default
+        estimator.output_type = "numpy"
+        estimator.output_mem_type = MemoryType.host
+        return estimator
diff --git a/python/cuml/cuml/tests/test_sklearn_import_export.py b/python/cuml/cuml/tests/test_sklearn_import_export.py
new file mode 100644
index 0000000000..d97871aafa
--- /dev/null
+++ b/python/cuml/cuml/tests/test_sklearn_import_export.py
@@ -0,0 +1,167 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pytest
+import tempfile
+import os
+from cuml import (
+    KMeans,
+    DBSCAN,
+    PCA,
+    TruncatedSVD,
+    KernelRidge,
+    LinearRegression,
+    LogisticRegression,
+    ElasticNet,
+    Ridge,
+    Lasso,
+    TSNE,
+    NearestNeighbors,
+    KNeighborsClassifier,
+    KNeighborsRegressor,
+)
+
+
+# List of estimators with their types, parameters, and data-specific parameters
+estimators = [
+    (KMeans, 'clusterer', {'n_clusters': 3, 'random_state': 42}, {}),
+    (DBSCAN, 'clusterer', {}, {}),
+    (PCA, 'transformer', {'n_components': 5}, {}),
+    (TruncatedSVD, 'transformer', {'n_components': 5}, {}),
+    (LinearRegression, 'regressor', {}, {}),
+    (ElasticNet, 'regressor', {'max_iter': 1000}, {}),
+    (Ridge, 'regressor', {}, {}),
+    (Lasso, 'regressor', {'max_iter': 1000}, {}),
+    (
+        LogisticRegression,
+        'classifier',
+        {'random_state': 42, 'solver': 'liblinear', 'max_iter': 1000},
+        {'n_classes': 2},
+    ),
+    (TSNE, 'transformer', {'n_components': 2, 'random_state': 42}, {}),
+    (NearestNeighbors, 'neighbors', {'n_neighbors': 5}, {}),
+    (KNeighborsClassifier, 'classifier', {'n_neighbors': 5}, {'n_classes': 3}),
+    (KNeighborsRegressor, 'regressor', {'n_neighbors': 5}, {}),
+]
+
+
+def get_y(estimator_type: str, n_samples: int, data_params: dict):
+    """
+    Generate random targets matching the estimator type.
+
+    Classifiers get integer labels in [0, n_classes), regressors get
+    floats in [0, 1), and every other type gets None.
+    """
+    if estimator_type == 'classifier':
+        n_classes = data_params.get('n_classes', 2)
+        return np.random.randint(0, n_classes, size=n_samples)
+    if estimator_type == 'regressor':
+        return np.random.rand(n_samples)
+    return None  # Unsupervised methods don't use y
+
+
+def predict_transform(estimator, estimator_type, X):
+    """
+    Return the output used to compare an estimator before and after a
+    round trip, chosen by estimator type.
+
+    Raises
+    ------
+    ValueError
+        If the estimator type is unknown, or if a transformer exposes
+        neither ``transform`` nor ``embedding_``.
+    """
+    if estimator_type in ['regressor', 'classifier']:
+        output = estimator.predict(X)
+    elif estimator_type == 'clusterer':
+        if hasattr(estimator, 'predict'):
+            output = estimator.predict(X)
+        else:
+            output = estimator.labels_
+    elif estimator_type == 'transformer':
+        if hasattr(estimator, 'transform'):
+            output = estimator.transform(X)
+        elif hasattr(estimator, 'embedding_'):
+            output = estimator.embedding_
+        else:
+            # Without this branch `output` would be unbound at the return.
+            raise ValueError(
+                f"{type(estimator).__name__} has no transform or embedding_"
+            )
+    elif estimator_type == 'neighbors':
+        output = estimator.kneighbors(X)
+    else:
+        raise ValueError(f"Unknown estimator type: {estimator_type}")
+
+    return output
+
+
+@pytest.mark.parametrize(
+    "Estimator, estimator_type, est_params, data_params", estimators
+)
+def test_estimator_to_from_sklearn(
+    Estimator, estimator_type, est_params, data_params
+):
+    """Round-trip each estimator through to_sklearn/from_sklearn."""
+    # Generate data based on estimator type
+    np.random.seed(42)
+    n_samples = 100
+    n_features = 10
+    X = np.random.rand(n_samples, n_features)
+    y = get_y(estimator_type, n_samples, data_params)
+
+    # Instantiate estimator
+    est = Estimator(**est_params)
+
+    # Fit estimator
+    if y is not None:
+        est.fit(X, y)
+    elif (
+        estimator_type == 'transformer'
+        and hasattr(est, 'fit_transform')
+        and not hasattr(est, 'transform')
+    ):
+        # For TSNE; its output is read from embedding_ afterwards
+        est.fit_transform(X)
+    else:
+        est.fit(X)
+
+    # Make predictions or transformations
+    output1 = predict_transform(est, estimator_type, X)
+
+    # Save and load the estimator using temporary file
+    with tempfile.NamedTemporaryFile(suffix='.pickle', delete=False) as tmp:
+        filename = tmp.name
+    try:
+        est.to_sklearn(filename=filename)
+        est2 = Estimator.from_sklearn(filename=filename)
+    finally:
+        # Clean up the temporary file
+        os.remove(filename)
+
+    # Make predictions or transformations with the loaded estimator
+    output2 = predict_transform(est2, estimator_type, X)
+
+    # Compare outputs
+    if estimator_type in ['regressor', 'transformer']:
+        assert np.allclose(output1, output2)
+    elif estimator_type in ['classifier', 'clusterer']:
+        assert np.array_equal(output1, output2)
+    elif estimator_type == 'neighbors':
+        distances1, indices1 = output1
+        distances2, indices2 = output2
+        assert np.allclose(distances1, distances2)
+        assert np.array_equal(indices1, indices2)