From cc22891d198d53452d048fee85b6a85d22a9ada7 Mon Sep 17 00:00:00 2001
From: Dante Gama Dessavre
Date: Wed, 4 Oct 2023 17:02:15 -0500
Subject: [PATCH] Fixes for CPU package (#5599)

This PR fixes a few issues encountered while testing the different
supported algorithms:

- Add pandas as a runtime dependency (it is already a dependency of the
  GPU package through cuDF).
- Use np.dtype as opposed to cp.dtype (the two are the same object, so
  the code touched by this PR behaves identically), along with other
  similar changes that avoid triggering cudf/cupy/numba calls at import
  time, which causes runtime issues in some algorithms.
- Other small fixes for issues that can surface at runtime.

Authors:
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Ray Douglass (https://github.com/raydouglass)
  - William Hicks (https://github.com/wphicks)

URL: https://github.com/rapidsai/cuml/pull/5599
---
 conda/recipes/cuml-cpu/meta.yaml         |  3 +-
 python/cuml/cluster/__init__.py          |  3 --
 python/cuml/common/__init__.py           |  4 ++-
 python/cuml/common/kernel_utils.py       |  4 ---
 python/cuml/common/sparsefuncs.py        |  2 +-
 python/cuml/internals/__init__.py        |  6 +++-
 python/cuml/internals/base.pyx           | 37 +++++++++++++++---------
 python/cuml/internals/global_settings.py |  2 --
 python/cuml/internals/input_utils.py     | 27 ++++++++++++-----
 python/cuml/internals/type_utils.py      | 10 +++++--
 python/cuml/linear_model/__init__.py     |  8 +++--
 python/cuml/manifold/__init__.py         |  5 +++-
 python/cuml/manifold/umap.pyx            | 21 +++++++++-----
 13 files changed, 83 insertions(+), 49 deletions(-)

diff --git a/conda/recipes/cuml-cpu/meta.yaml b/conda/recipes/cuml-cpu/meta.yaml
index e52c10d3b9..d4497a65fb 100644
--- a/conda/recipes/cuml-cpu/meta.yaml
+++ b/conda/recipes/cuml-cpu/meta.yaml
@@ -34,6 +34,7 @@ requirements:
   run:
     - python x.x
     - numpy
+    - pandas
     - scikit-learn=1.2
     - hdbscan<=0.8.30
     - umap-learn=0.5.3
@@ -41,7 +42,7 @@ requirements:
 
 tests:          # [linux64]
   imports:      # [linux64]
-    - cuml-cpu  # [linux64]
+    - cuml      # [linux64]
 
 about:
   home: http://rapids.ai/
diff --git a/python/cuml/cluster/__init__.py b/python/cuml/cluster/__init__.py
index c02ddc872e..41cb8176f5 100644
--- a/python/cuml/cluster/__init__.py
+++ b/python/cuml/cluster/__init__.py
@@ -14,9 +14,6 @@
 # limitations under the License.
 #
 
-from cuml.cluster.dbscan import DBSCAN
-from cuml.cluster.kmeans import KMeans
-from cuml.cluster.agglomerative import AgglomerativeClustering
 from cuml.internals.device_support import GPU_ENABLED
 from cuml.cluster.hdbscan import HDBSCAN
diff --git a/python/cuml/common/__init__.py b/python/cuml/common/__init__.py
index 6ab43e5757..6a46462878 100644
--- a/python/cuml/common/__init__.py
+++ b/python/cuml/common/__init__.py
@@ -17,6 +17,7 @@
 # from cuml.internals.array import CumlArray
 # from cuml.internals.array_sparse import SparseCumlArray
+from cuml.internals.available_devices import is_cuda_available
 from cuml.internals.array import CumlArray
 from cuml.internals.array_sparse import SparseCumlArray
@@ -39,7 +40,8 @@
 from cuml.common.device_selection import using_device_type
 
-from cuml.common.pointer_utils import device_of_gpu_matrix
+if is_cuda_available():
+    from cuml.common.pointer_utils import device_of_gpu_matrix
 
 
 # legacy to be removed after complete CumlAray migration
diff --git a/python/cuml/common/kernel_utils.py b/python/cuml/common/kernel_utils.py
index 356ee8c60e..86d6ad831a 100644
--- a/python/cuml/common/kernel_utils.py
+++ b/python/cuml/common/kernel_utils.py
@@ -26,10 +26,6 @@
 
 # Mapping of common PyData dtypes to their corresponding C-primitive
 dtype_str_map = {
-    cp.dtype("float32"): "float",
-    cp.dtype("float64"): "double",
-    cp.dtype("int32"): "int",
-    cp.dtype("int64"): "long long int",
     np.dtype("float32"): "float",
     np.dtype("float64"): "double",
     np.dtype("int32"): "int",
diff --git a/python/cuml/common/sparsefuncs.py b/python/cuml/common/sparsefuncs.py
index 000e80c33a..4648163dc6 100644
--- a/python/cuml/common/sparsefuncs.py
+++ b/python/cuml/common/sparsefuncs.py
@@ -141,7 +141,7 @@ def csr_diag_mul(X, y, inplace=True):
 
 @cuml.internals.api_return_any()
 def create_csr_matrix_from_count_df(
-    count_df, empty_doc_ids, n_doc, n_features, dtype=cp.float32
+    count_df, empty_doc_ids, n_doc, n_features, dtype=np.float32
 ):
     """
     Create a sparse matrix from the count of tokens by document
diff --git a/python/cuml/internals/__init__.py b/python/cuml/internals/__init__.py
index a0d082df25..e8d989fd8e 100644
--- a/python/cuml/internals/__init__.py
+++ b/python/cuml/internals/__init__.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
+from cuml.internals.available_devices import is_cuda_available
 from cuml.internals.base_helpers import BaseMetaClass, _tags_class_and_instance
 from cuml.internals.api_decorators import (
     _deprecate_pos_args,
@@ -36,5 +37,8 @@
     set_api_output_dtype,
     set_api_output_type,
 )
-from cuml.internals.internals import GraphBasedDimRedCallback
+
+if is_cuda_available():
+    from cuml.internals.internals import GraphBasedDimRedCallback
+
 from cuml.internals.constants import CUML_WRAPPED_FLAG
diff --git a/python/cuml/internals/base.pyx b/python/cuml/internals/base.pyx
index 889425994d..4fb03fdac9 100644
--- a/python/cuml/internals/base.pyx
+++ b/python/cuml/internals/base.pyx
@@ -30,10 +30,8 @@ nvtx_annotate = gpu_only_import_from("nvtx", "annotate", alt=null_decorator)
 
 import cuml
 import cuml.common
-import cuml.common.cuda
 import cuml.internals.logger as logger
 import cuml.internals
-import pylibraft.common.handle
 import cuml.internals.input_utils
 from cuml.internals.available_devices import is_cuda_available
 from cuml.internals.device_type import DeviceType
@@ -61,6 +59,11 @@ cp_ndarray = gpu_only_import_from('cupy', 'ndarray')
 cp = gpu_only_import('cupy')
 
 
+IF GPUBUILD == 1:
+    import pylibraft.common.handle
+    import cuml.common.cuda
+
+
 class Base(TagsMixin,
            metaclass=cuml.internals.BaseMetaClass):
     """
@@ -178,7 +181,7 @@ class Base(TagsMixin,
 
         # stream and handle example:
 
-        stream = cuml.cuda.Stream()
+        stream = cuml.common.cuda.Stream()
         handle = pylibraft.common.Handle(stream=stream)
 
         algo = MyAlgo(handle=handle)
@@ -201,17 +204,23 @@
         Constructor. All children must call init method of this base class.
 
         """
-        self.handle = pylibraft.common.handle.Handle() if handle is None \
-            else handle
-
-        # Internally, self.verbose follows the spdlog/c++ standard of
-        # 0 is most logging, and logging decreases from there.
-        # So if the user passes an int value for logging, we convert it.
-        if verbose is True:
-            self.verbose = logger.level_debug
-        elif verbose is False:
-            self.verbose = logger.level_info
-        else:
+        IF GPUBUILD == 1:
+            self.handle = pylibraft.common.handle.Handle() if handle is None \
+                else handle
+        ELSE:
+            self.handle = None
+
+        IF GPUBUILD == 1:
+            # Internally, self.verbose follows the spdlog/c++ standard of
+            # 0 is most logging, and logging decreases from there.
+            # So if the user passes an int value for logging, we convert it.
+            if verbose is True:
+                self.verbose = logger.level_debug
+            elif verbose is False:
+                self.verbose = logger.level_info
+            else:
+                self.verbose = verbose
+        ELSE:
             self.verbose = verbose
 
         self.output_type = _check_output_type_str(
diff --git a/python/cuml/internals/global_settings.py b/python/cuml/internals/global_settings.py
index adb9e0da44..ea899d91b1 100644
--- a/python/cuml/internals/global_settings.py
+++ b/python/cuml/internals/global_settings.py
@@ -18,7 +18,6 @@
 import threading
 from cuml.internals.available_devices import is_cuda_available
 from cuml.internals.device_type import DeviceType
-from cuml.internals.logger import warn
 from cuml.internals.mem_type import MemoryType
 from cuml.internals.safe_imports import cpu_only_import, gpu_only_import
@@ -38,7 +37,6 @@ def __init__(self):
             default_device_type = DeviceType.device
             default_memory_type = MemoryType.device
         else:
-            warn("GPU will not be used")
             default_device_type = DeviceType.host
             default_memory_type = MemoryType.host
         self.shared_state = {
diff --git a/python/cuml/internals/input_utils.py b/python/cuml/internals/input_utils.py
index e4c2f97738..bb9e8bc3e3 100644
--- a/python/cuml/internals/input_utils.py
+++ b/python/cuml/internals/input_utils.py
@@ -71,24 +71,35 @@
     CumlArray: "cuml",
     SparseCumlArray: "cuml",
     np_ndarray: "numpy",
-    cp_ndarray: "cupy",
-    CudfSeries: "cudf",
-    CudfDataFrame: "cudf",
     PandasSeries: "pandas",
     PandasDataFrame: "pandas",
-    NumbaDeviceNDArrayBase: "numba",
 }
 
+
+try:
+    _input_type_to_str[cp_ndarray] = "cupy"
+    _input_type_to_str[CudfSeries] = "cudf"
+    _input_type_to_str[CudfDataFrame] = "cudf"
+    _input_type_to_str[NumbaDeviceNDArrayBase] = "numba"
+except UnavailableError:
+    pass
+
+
 _input_type_to_mem_type = {
     np_ndarray: MemoryType.host,
-    cp_ndarray: MemoryType.device,
-    CudfSeries: MemoryType.device,
-    CudfDataFrame: MemoryType.device,
     PandasSeries: MemoryType.host,
     PandasDataFrame: MemoryType.host,
-    NumbaDeviceNDArrayBase: MemoryType.device,
 }
 
+
+try:
+    _input_type_to_mem_type[cp_ndarray] = MemoryType.device
+    _input_type_to_mem_type[CudfSeries] = MemoryType.device
+    _input_type_to_mem_type[CudfDataFrame] = MemoryType.device
+    _input_type_to_mem_type[NumbaDeviceNDArrayBase] = MemoryType.device
+except UnavailableError:
+    pass
+
 _SPARSE_TYPES = [SparseCumlArray]
 
 try:
diff --git a/python/cuml/internals/type_utils.py b/python/cuml/internals/type_utils.py
index 02002deb61..d889a4968a 100644
--- a/python/cuml/internals/type_utils.py
+++ b/python/cuml/internals/type_utils.py
@@ -16,12 +16,16 @@
 import functools
 import typing
 
-from cuml.internals.safe_imports import gpu_only_import
+from cuml.internals.safe_imports import gpu_only_import, UnavailableError
 
 cp = gpu_only_import("cupy")
 
-# Those are the only data types supported by cupyx.scipy.sparse matrices.
-CUPY_SPARSE_DTYPES = [cp.float32, cp.float64, cp.complex64, cp.complex128]
+
+try:
+    # Those are the only data types supported by cupyx.scipy.sparse matrices.
+    CUPY_SPARSE_DTYPES = [cp.float32, cp.float64, cp.complex64, cp.complex128]
+except UnavailableError:
+    CUPY_SPARSE_DTYPES = []
 
 # Use _DecoratorType as a type variable for decorators. See:
 # https://github.com/python/mypy/pull/8336/files#diff-eb668b35b7c0c4f88822160f3ca4c111f444c88a38a3b9df9bb8427131538f9cR260
diff --git a/python/cuml/linear_model/__init__.py b/python/cuml/linear_model/__init__.py
index 7f4a71a510..28a8bb6dc8 100644
--- a/python/cuml/linear_model/__init__.py
+++ b/python/cuml/linear_model/__init__.py
@@ -14,11 +14,13 @@
 # limitations under the License.
 #
-
+from cuml.internals.device_support import GPU_ENABLED
 from cuml.linear_model.elastic_net import ElasticNet
 from cuml.linear_model.lasso import Lasso
 from cuml.linear_model.linear_regression import LinearRegression
 from cuml.linear_model.logistic_regression import LogisticRegression
-from cuml.linear_model.mbsgd_classifier import MBSGDClassifier
-from cuml.linear_model.mbsgd_regressor import MBSGDRegressor
 from cuml.linear_model.ridge import Ridge
+
+if GPU_ENABLED:
+    from cuml.linear_model.mbsgd_classifier import MBSGDClassifier
+    from cuml.linear_model.mbsgd_regressor import MBSGDRegressor
diff --git a/python/cuml/manifold/__init__.py b/python/cuml/manifold/__init__.py
index db411421b7..931e5c33e0 100644
--- a/python/cuml/manifold/__init__.py
+++ b/python/cuml/manifold/__init__.py
@@ -14,5 +14,8 @@
 # limitations under the License.
 #
 
+from cuml.internals.available_devices import is_cuda_available
 from cuml.manifold.umap import UMAP
-from cuml.manifold.t_sne import TSNE
+
+if is_cuda_available():
+    from cuml.manifold.t_sne import TSNE
diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx
index 47af577529..cc9c492fba 100644
--- a/python/cuml/manifold/umap.pyx
+++ b/python/cuml/manifold/umap.pyx
@@ -36,19 +36,13 @@ import cuml.internals
 from cuml.internals.base import UniversalBase
 from cuml.common.doc_utils import generate_docstring
 from cuml.internals import logger
+from cuml.internals.available_devices import is_cuda_available
 from cuml.internals.input_utils import input_to_cuml_array
 from cuml.internals.array import CumlArray
 from cuml.internals.array_sparse import SparseCumlArray
 from cuml.internals.mixins import CMajorInputTagMixin
 from cuml.common.sparse_utils import is_sparse
 
-from cuml.manifold.simpl_set import fuzzy_simplicial_set  # no-cython-lint
-from cuml.manifold.simpl_set import simplicial_set_embedding  # no-cython-lint
-# TODO: These two symbols are considered part of the public API of this module
-# which is why imports should not be removed. The no-cython-lint markers can be
-# replaced with an explicit __all__ specifications once
-# https://github.com/MarcoGorelli/cython-lint/issues/80 is resolved.
-
 from cuml.common.array_descriptor import CumlArrayDescriptor
 from cuml.internals.api_decorators import device_interop_preparation
 from cuml.internals.api_decorators import enable_device_interop
@@ -58,6 +52,19 @@ rmm = gpu_only_import('rmm')
 
 from libc.stdint cimport uintptr_t
 
+if is_cuda_available():
+    from cuml.manifold.simpl_set import fuzzy_simplicial_set  # no-cython-lint
+    from cuml.manifold.simpl_set import simplicial_set_embedding  # no-cython-lint
+    # TODO: These two symbols are considered part of the public API of this module
+    # which is why imports should not be removed. The no-cython-lint markers can be
+    # replaced with an explicit __all__ specifications once
+    # https://github.com/MarcoGorelli/cython-lint/issues/80 is resolved.
+else:
+    # if no GPU is present, we import the UMAP equivalents
+    from umap.umap_ import fuzzy_simplicial_set  # no-cython-lint
+    from umap.umap_ import simplicial_set_embedding  # no-cython-lint
+
+
 IF GPUBUILD == 1:
     from libc.stdlib cimport free
     from cuml.manifold.umap_utils cimport *
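
Note on the pattern this patch applies throughout: GPU-only imports are
gated behind a runtime availability check, so a CPU-only install never
touches cupy/cudf/numba at import time. A minimal standalone sketch of
the idea follows; has_cuda and device_asarray are hypothetical names
used for illustration only, while cuML itself relies on
cuml.internals.available_devices.is_cuda_available and its safe_imports
helpers.

    # Sketch of the guarded-import pattern, assuming only numpy is
    # installed; numba and cupy are probed lazily and optionally.

    def has_cuda() -> bool:
        """Best-effort CUDA detection that degrades cleanly on CPU hosts."""
        try:
            from numba import cuda  # deferred import: probe only when asked
            return cuda.is_available()
        except ImportError:
            return False

    if has_cuda():
        # GPU path: safe to import CUDA-backed implementations.
        from cupy import asarray as device_asarray  # noqa: F401
    else:
        # CPU path: same symbol, host implementation, so callers never branch.
        from numpy import asarray as device_asarray  # noqa: F401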
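
On the np.dtype/cp.dtype swap in kernel_utils.py and sparsefuncs.py:
CuPy re-exports NumPy's dtype object rather than defining its own, so
the numpy spelling is a drop-in replacement that no longer requires
CuPy at import time. A quick self-contained check; the asserts in the
try block assume a machine where CuPy happens to be installed.

    import numpy as np

    # Keys built from np.dtype(...) are exactly the objects that
    # cp.dtype(...) would produce, so lookups keep working for arrays
    # coming from either library.
    dtype_str_map = {
        np.dtype("float32"): "float",
        np.dtype("float64"): "double",
        np.dtype("int32"): "int",
        np.dtype("int64"): "long long int",
    }

    try:
        import cupy as cp
        assert cp.dtype is np.dtype  # same object, as noted above
        assert dtype_str_map[cp.dtype("float32")] == "float"
    except ImportError:
        pass  # CPU-only environment: the mapping still works unchanged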
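
The IF GPUBUILD == 1 blocks in base.pyx and umap.pyx are Cython
compile-time conditionals: the branch is selected when the extension is
cythonized and the dead branch is compiled out, so the cuml-cpu binary
references no GPU symbols. A sketch of how such a constant is typically
injected; the file names here are illustrative, not cuML's actual build
wiring.

    # build_sketch.py -- pass the compile-time constant to Cython.
    from Cython.Build import cythonize

    extensions = cythonize(
        "example.pyx",
        # IF/ELSE in the .pyx source are evaluated against this constant
        # during translation, yielding a GPU build (GPUBUILD=1) or a
        # CPU-only build (GPUBUILD=0) from the same source.
        compile_time_env={"GPUBUILD": 0},
    )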