From 84732591ebf2b46fe92306d0d8bcc54ccbdd4077 Mon Sep 17 00:00:00 2001 From: Dante Gama Dessavre Date: Mon, 7 Oct 2024 21:20:05 -0500 Subject: [PATCH] FEA Add first version of to and from_sklearn APIs to estimators that support it --- python/cuml/cuml/internals/base.pyx | 91 ++++++++++++++++++++++++++++- 1 file changed, 89 insertions(+), 2 deletions(-) diff --git a/python/cuml/cuml/internals/base.pyx b/python/cuml/cuml/internals/base.pyx index c00ed17f98..91e5714641 100644 --- a/python/cuml/cuml/internals/base.pyx +++ b/python/cuml/cuml/internals/base.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -722,4 +722,91 @@ class UniversalBase(Base): return self # return function result - return res + return + + @staticmethod + def _get_serializer(protocol: str) -> Any: + """ + Get the appropriate serializer based on the specified protocol. + """ + if protocol == "pickle": + import pickle as serializer + elif protocol == "joblib": + import joblib as serializer + else: + raise TypeError(f"Protocol {protocol} not supported.") + return serializer + + def to_sklearn(self, + protocol: str = "pickle", + filename: Optional[str] = None) -> None: + """ + Serialize the estimator to a Scikit-learn compatible file using the + specified protocol. + + Parameters + ---------- + protocol : str, optional + The serialization protocol to use. Defaults to 'pickle'. + filename : str, optional + The name of the file where the model will be saved. If not provided, it defaults + to the class name with '_sklearn' appended. + + Raises + ------ + AttributeError + If the model does not have a `_cpu_model` attribute. + TypeError + If the protocol is not supported. + + """ + if filename is None: + filename = self.__class__.__name__ + "_sklearn" + + serializer = self._get_serializer(protocol) + + if not hasattr(self, '_cpu_model'): + self.import_cpu_model() + self.build_cpu_model() + self.gpu_to_cpu() + + with open(filename, "wb") as f: + serializer.dump(self._cpu_model, f) + + @classmethod + def from_sklearn(cls, + filename: str, + protocol: str = "pickle") -> 'Model': + """ + Create a cuML estimator from a pickle or joblib serialized + Scikit-learn model. + + Parameters + ---------- + filename : str + The name of the file from which to load the model. + protocol : str, optional + The serialization protocol to use. Defaults to 'pickle'. + + Returns + ------- + Model + An instance of the class with the loaded model. + + Raises + ------ + AttributeError + If the model does not have a `_cpu_model` attribute. + TypeError + If the protocol is not supported. + """ + estimator = cls() + serializer = cls._get_serializer(protocol) + + with open(filename, "rb") as f: + state = serializer.load(f) + + estimator._cpu_model = cls._cpu_model_class() + estimator._cpu_model.__dict__.update(state) + estimator.cpu_to_gpu() + return estimator