Survival Random Forest predict_survival_function does not scale with n_jobs #382

Open
PierrickPochelu opened this issue Jun 20, 2023 · 1 comment

PierrickPochelu commented Jun 20, 2023

Description

My laptop has a multi-core CPU. Increasing n_jobs improves the speed of fit and predict. However, it does not improve the speed of predict_survival_function.

Code Sample to Reproduce the Bug

import numpy as np
import pandas as pd
import time

np.random.seed(42)

def create_data(nb_events, nb_features):
    np_X=np.random.rand(nb_events, nb_features)
    np_time=np.random.rand(nb_events, 1)
    np_is_living=np_X[:,0] < np_time[:,0]
    y=np.empty(nb_events, dtype=[('event', '?'), ('time', '<f8')])  # float64; '<f16' is not available on every platform
    y['event']=np_is_living.reshape(-1)
    y['time']=np_time.reshape(-1)
    X=pd.DataFrame(np_X,columns=['f'+str(i) for i in range(1,nb_features+1)])
    return X, y

X_train,y_train=create_data(nb_events=150, nb_features=8)
X_test,y_test=create_data(nb_events=150, nb_features=8)

from sksurv.ensemble import RandomSurvivalForest
rsf=RandomSurvivalForest(random_state=42, n_jobs=8) #<------------- Increasing n_jobs does not improve predict_survival_function speed


print("Fitting ....")
st=time.time()
rsf.fit(X_train,y_train)
print(f"Fit time:{time.time()-st}")

st=time.time()
for i in range(100):
    pred=rsf.predict_survival_function(X_test)
print(f"Predict time:{time.time()-st}")

Expected Results
Compared to n_jobs=1, n_jobs=8 should in theory divide the computing time by 8; I would expect at least a division by 2.

Actual Results
n_jobs=8 is slower than n_jobs=1
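
To make the comparison above reproducible, a small timing helper could look like the sketch below. The names `time_call`, `compare_n_jobs`, and `make_predict` are hypothetical and not part of sksurv; the helper only assumes a factory that returns a zero-argument callable for a given n_jobs value.

```python
import statistics
import time


def time_call(fn, repeats=5):
    """Return the median wall-clock time (seconds) over `repeats` calls of fn()."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)


def compare_n_jobs(make_predict, n_jobs_values=(1, 2, 4, 8), repeats=5):
    """Map each n_jobs value to the median time of the callable built for it.

    make_predict is a hypothetical factory: it takes n_jobs and returns a
    zero-argument callable, e.g. one wrapping predict_survival_function.
    """
    return {n: time_call(make_predict(n), repeats) for n in n_jobs_values}


# Possible usage with the estimator from the code sample (not run here;
# assumes n_jobs is read at prediction time, as in scikit-learn forests):
#
# def make_predict(n_jobs):
#     rsf.set_params(n_jobs=n_jobs)
#     return lambda: rsf.predict_survival_function(X_test)
#
# print(compare_n_jobs(make_predict))
```

Using the median rather than a single total makes the numbers less sensitive to one-off costs such as worker start-up.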

Versions
System:
python: 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0]
executable: /home/pierrick/PycharmProjects/venv/bin/python
machine: Linux-5.19.0-45-generic-x86_64-with-glibc2.35

Python dependencies:
sklearn: 1.2.2
pip: 22.3.1
setuptools: 65.5.1
numpy: 1.24.3
scipy: 1.10.1
Cython: None
pandas: 2.0.2
matplotlib: 3.7.1
joblib: 1.2.0
threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
user_api: openmp
internal_api: openmp
prefix: libgomp
filepath: /home/pierrick/PycharmProjects/venv/lib/python3.10/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
version: None
num_threads: 20

user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/pierrick/PycharmProjects/venv/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-15028c96.3.21.so
version: 0.3.21
threading_layer: pthreads
architecture: Prescott
num_threads: 20

user_api: blas
internal_api: openblas
prefix: libopenblas
filepath: /home/pierrick/PycharmProjects/venv/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-41284840.3.18.so
version: 0.3.18
threading_layer: pthreads
architecture: Prescott
num_threads: 20

sksurv: 0.21.0
numexpr: 2.8.4
osqp: 0.6.3

sebp (Owner) commented Jun 27, 2023

I can observe this to some extent.

My initial investigation revealed that most of the time is spent copying data (the survival function for each terminal node) rather than traversing the tree.

[Screenshot from 2023-06-27 18-07-28]

The copying happens at https://github.com/scikit-learn/scikit-learn/blob/9aaed498795f68e5956ea762fef9c440ca9eb239/sklearn/tree/_tree.pyx#L779-L780

Therefore, I assume parallelization is not as effective for this method.
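
A minimal NumPy sketch of the two steps, to illustrate why the copy could dominate. This is not the scikit-learn code itself: the array sizes are made up, and `node_values` and `leaf_ids` are stand-ins for a tree's per-node values and the result of the traversal. It assumes the linked lines gather one row of values per sample.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: nodes in one tree, time points per survival curve, test samples.
n_nodes, n_times, n_samples = 64, 100, 150

# Stand-in for one tree's stored values: one survival curve per node.
node_values = rng.random((n_nodes, n_times))

# Stand-in for the traversal result: one terminal node index per sample.
leaf_ids = rng.integers(0, n_nodes, size=n_samples)

# Gather step: copies a full curve of n_times values for every sample.
per_sample = node_values.take(leaf_ids, axis=0, mode="clip")

# The traversal yields one integer per sample, the gather n_times floats per sample.
bytes_from_traversal = leaf_ids.nbytes
bytes_copied = per_sample.nbytes
print(per_sample.shape, bytes_from_traversal, bytes_copied)
```

With many unique event times, the gathered array grows with n_samples times n_times for every tree, so the work is mostly memory traffic, which tends to benefit less from additional threads than the traversal does.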
