diff --git a/python/cuml/cuml/fil/fil.pyx b/python/cuml/cuml/fil/fil.pyx
index 170f18992e..44c656bef4 100644
--- a/python/cuml/cuml/fil/fil.pyx
+++ b/python/cuml/cuml/fil/fil.pyx
@@ -356,7 +356,7 @@ cdef class ForestInference_impl():
         return None
 
     def get_dtype(self):
-        dtype_array = [np.float32, np.float64]
+        dtype_array = ["float32", "float64"]
         return dtype_array[self.forest_data.index()]
 
     def get_algo(self, algo_str):
@@ -477,14 +477,24 @@ cdef class ForestInference_impl():
         cdef uintptr_t preds_ptr
         preds_ptr = preds.ptr
 
-        if fil_dtype == np.float32:
+        if fil_dtype == "float32":
+            if self.get_forest32() == NULL:
+                raise RuntimeError(
+                    "Cannot call predict() with empty forest. "
+                    "Please load the forest first with load() or "
+                    "load_from_sklearn()")
             predict(handle_[0],
                     self.get_forest32(),
                     preds_ptr,
                     X_ptr,
                     n_rows,
                     predict_proba)
-        elif fil_dtype == np.float64:
+        elif fil_dtype == "float64":
+            if self.get_forest64() == NULL:
+                raise RuntimeError(
+                    "Cannot call predict() with empty forest. "
+                    "Please load the forest first with load() or "
+                    "load_from_sklearn()")
             predict(handle_[0],
                     self.get_forest64(),
                     preds_ptr,
@@ -493,7 +503,7 @@ cdef class ForestInference_impl():
                     predict_proba)
         else:
             # should not reach here
-            assert False, 'invalid fil_dtype, must be np.float32 or np.float64'
+            assert False, 'invalid fil_dtype, must be float32 or float64'
 
         self.handle.sync()
 
@@ -557,15 +567,15 @@ cdef class ForestInference_impl():
     def __dealloc__(self):
         cdef handle_t* handle_ = self.handle.getHandle()
         fil_dtype = self.get_dtype()
-        if fil_dtype == np.float32:
+        if fil_dtype == "float32":
             if self.get_forest32() != NULL:
                 free[float](handle_[0], self.get_forest32())
-        elif fil_dtype == np.float64:
+        elif fil_dtype == "float64":
             if self.get_forest64() != NULL:
                 free[double](handle_[0], self.get_forest64())
         else:
             # should not reach here
-            assert False, 'invalid fil_dtype, must be np.float32 or np.float64'
+            assert False, 'invalid fil_dtype, must be float32 or float64'
 
 
 class ForestInference(Base,
@@ -747,7 +757,7 @@ class ForestInference(Base,
             Optional 'out' location to store inference results
 
         safe_dtype_conversion : bool (default = False)
-            FIL converts data to np.float32 when needed. Set this parameter to
+            FIL converts data to float32 when needed. Set this parameter to
             True to enable checking for information loss during that
             conversion, but note that this check can have a significant
             performance penalty. Parameter will be dropped in a future
@@ -776,7 +786,7 @@ class ForestInference(Base,
             Optional 'out' location to store inference results
 
         safe_dtype_conversion : bool (default = False)
-            FIL converts data to np.float32 when needed. Set this parameter to
+            FIL converts data to float32 when needed. Set this parameter to
             True to enable checking for information loss during that
             conversion, but note that this check can have a significant
             performance penalty. Parameter will be dropped in a future
diff --git a/python/cuml/cuml/tests/dask/test_dask_random_forest.py b/python/cuml/cuml/tests/dask/test_dask_random_forest.py
index 38596b2e69..e8311c9868 100644
--- a/python/cuml/cuml/tests/dask/test_dask_random_forest.py
+++ b/python/cuml/cuml/tests/dask/test_dask_random_forest.py
@@ -653,7 +653,7 @@ def test_rf_broadcast(model_type, fit_broadcast, transform_broadcast, client):
 
     if model_type == "classification":
         X, y = make_classification(
-            n_samples=n_workers * 1000,
+            n_samples=n_workers * 10000,
             n_features=20,
             n_informative=15,
             n_classes=4,
@@ -663,7 +663,7 @@ def test_rf_broadcast(model_type, fit_broadcast, transform_broadcast, client):
         y = y.astype(np.int32)
     else:
         X, y = make_regression(
-            n_samples=n_workers * 1000,
+            n_samples=n_workers * 10000,
             n_features=20,
             n_informative=5,
             random_state=123,