Commit

Fix segfault and other errors in ForestInference.load_from_sklearn (#5973)

Closes #5551

* Replace `np.float32` with `"float32"` so that we don't reference the `np` module. By the time the `__dealloc__` method is called, modules may have already been unloaded (see the sketch below).
* Improve the user experience by raising a helpful error when the user attempts to predict with an empty forest (user-visible behavior sketched after the fil.pyx diff below).
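
For context, a minimal illustrative sketch (plain Python, not the cuML code; `__del__` stands in for Cython's `__dealloc__`) of why finalizers should avoid module globals: during interpreter shutdown, globals such as `np` may already have been cleared, so comparing against the plain string `"float32"` sidesteps the module lookup entirely.

```python
import numpy as np


class Holder:
    """Toy stand-in for an object whose finalizer may run at interpreter exit."""

    def __init__(self):
        # Capture the dtype as a plain string up front, while numpy is loaded.
        self.dtype_name = np.dtype(np.float32).name   # "float32"

    def __del__(self):
        # Risky pattern (what the fix removes): referencing the module global
        # `np` here can fail during interpreter shutdown, when module globals
        # may already have been cleared:
        #     if self.dtype == np.float32: ...
        #
        # Safe pattern (what the fix uses): compare plain strings instead.
        if self.dtype_name == "float32":
            pass  # release float32-specific resources here


holder = Holder()  # if still alive at interpreter exit, __del__ runs very late
```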

Authors:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5973
hcho3 authored Jul 28, 2024
1 parent ba072b0 commit bc6f9b6
Showing 2 changed files with 21 additions and 11 deletions.
python/cuml/cuml/fil/fil.pyx (28 changes: 19 additions & 9 deletions)
@@ -356,7 +356,7 @@ cdef class ForestInference_impl():
         return None
 
     def get_dtype(self):
-        dtype_array = [np.float32, np.float64]
+        dtype_array = ["float32", "float64"]
         return dtype_array[self.forest_data.index()]
 
     def get_algo(self, algo_str):
@@ -477,14 +477,24 @@ cdef class ForestInference_impl():
         cdef uintptr_t preds_ptr
         preds_ptr = preds.ptr
 
-        if fil_dtype == np.float32:
+        if fil_dtype == "float32":
+            if self.get_forest32() == NULL:
+                raise RuntimeError(
+                    "Cannot call predict() with empty forest. "
+                    "Please load the forest first with load() or "
+                    "load_from_sklearn()")
             predict(handle_[0],
                     self.get_forest32(),
                     <float*> preds_ptr,
                     <float*> X_ptr,
                     <size_t> n_rows,
                     <bool> predict_proba)
-        elif fil_dtype == np.float64:
+        elif fil_dtype == "float64":
+            if self.get_forest64() == NULL:
+                raise RuntimeError(
+                    "Cannot call predict() with empty forest. "
+                    "Please load the forest first with load() or "
+                    "load_from_sklearn()")
             predict(handle_[0],
                     self.get_forest64(),
                     <double*> preds_ptr,
@@ -493,7 +503,7 @@
                     <bool> predict_proba)
         else:
             # should not reach here
-            assert False, 'invalid fil_dtype, must be np.float32 or np.float64'
+            assert False, 'invalid fil_dtype, must be float32 or float64'
 
         self.handle.sync()
 
@@ -557,15 +567,15 @@ cdef class ForestInference_impl():
     def __dealloc__(self):
         cdef handle_t* handle_ = <handle_t*><size_t>self.handle.getHandle()
         fil_dtype = self.get_dtype()
-        if fil_dtype == np.float32:
+        if fil_dtype == "float32":
             if self.get_forest32() != NULL:
                 free[float](handle_[0], self.get_forest32())
-        elif fil_dtype == np.float64:
+        elif fil_dtype == "float64":
             if self.get_forest64() != NULL:
                 free[double](handle_[0], self.get_forest64())
         else:
             # should not reach here
-            assert False, 'invalid fil_dtype, must be np.float32 or np.float64'
+            assert False, 'invalid fil_dtype, must be float32 or float64'
 
 
 class ForestInference(Base,
@@ -747,7 +757,7 @@ class ForestInference(Base,
             Optional 'out' location to store inference results
 
         safe_dtype_conversion : bool (default = False)
-            FIL converts data to np.float32 when needed. Set this parameter to
+            FIL converts data to float32 when needed. Set this parameter to
             True to enable checking for information loss during that
             conversion, but note that this check can have a significant
             performance penalty. Parameter will be dropped in a future
@@ -776,7 +786,7 @@ class ForestInference(Base,
             Optional 'out' location to store inference results
 
         safe_dtype_conversion : bool (default = False)
-            FIL converts data to np.float32 when needed. Set this parameter to
+            FIL converts data to float32 when needed. Set this parameter to
             True to enable checking for information loss during that
             conversion, but note that this check can have a significant
             performance penalty. Parameter will be dropped in a future
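
For illustration, a hypothetical sketch of the user-visible effect of the NULL-forest guards added above. The construction path here is an assumption (in practice the forest ends up empty when `load()` / `load_from_sklearn()` was never called or failed partway through); the error text matches the new `RuntimeError`.

```python
import numpy as np
import cuml

# Assumption for the sketch: a ForestInference object exists whose forest was
# never populated (e.g. a failed load_from_sklearn() left it empty).
fil = cuml.ForestInference()
X = np.random.rand(8, 20).astype(np.float32)

try:
    fil.predict(X)          # previously this could dereference a NULL forest pointer
except RuntimeError as err:
    print(err)              # "Cannot call predict() with empty forest. ..."
```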
python/cuml/cuml/tests/dask/test_dask_random_forest.py (4 changes: 2 additions & 2 deletions)
@@ -653,7 +653,7 @@ def test_rf_broadcast(model_type, fit_broadcast, transform_broadcast, client):
 
     if model_type == "classification":
         X, y = make_classification(
-            n_samples=n_workers * 1000,
+            n_samples=n_workers * 10000,
             n_features=20,
             n_informative=15,
             n_classes=4,
@@ -663,7 +663,7 @@ def test_rf_broadcast(model_type, fit_broadcast, transform_broadcast, client):
         y = y.astype(np.int32)
     else:
         X, y = make_regression(
-            n_samples=n_workers * 1000,
+            n_samples=n_workers * 10000,
             n_features=20,
             n_informative=5,
             random_state=123,
