My laptop has a multi-core CPU. Increasing n_jobs improves the speed of fit and predict. However, it does not improve the speed of predict_survival_function.
Code Sample to Reproduce the Bug
import time

import numpy as np
import pandas as pd

np.random.seed(42)


def create_data(nb_events, nb_features):
    np_X = np.random.rand(nb_events, nb_features)
    np_time = np.random.rand(nb_events, 1)
    np_is_living = np_X[:, 0] < np_time[:, 0]
    # Structured array with a boolean event indicator and a time field
    y = np.empty(nb_events, dtype=[('event', '?'), ('time', '<f16')])
    y['event'] = np_is_living.reshape(-1)
    y['time'] = np_time.reshape(-1)
    X = pd.DataFrame(np_X, columns=['f' + str(i) for i in range(1, nb_features + 1)])
    return X, y


X_train, y_train = create_data(nb_events=150, nb_features=8)
X_test, y_test = create_data(nb_events=150, nb_features=8)

from sksurv.ensemble import RandomSurvivalForest

rsf = RandomSurvivalForest(random_state=42, n_jobs=8)  # <------------- Increasing n_jobs does not improve predict_survival_function speed
print("Fitting ....")
st = time.time()
rsf.fit(X_train, y_train)
print(f"Fit time:{time.time()-st}")

# Time 100 repeated calls to predict_survival_function
st = time.time()
for i in range(100):
    pred = rsf.predict_survival_function(X_test)
print(f"Predict time:{time.time()-st}")
Expected Results
Compared to n_jobs=1, n_jobs=8 should in theory divide the computing time by 8; I would expect it to be at least halved.
My initial investigation revealed that most of the time is spent copying data (the survival function for each terminal node) rather than traversing the trees.
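One way to check where the time goes is to run a single call under the standard-library profiler. The sketch below is only an illustration: `profile_call` is a hypothetical helper name, not part of scikit-survival, and it assumes the call of interest can be wrapped in a zero-argument function.

```python
import cProfile
import io
import pstats


def profile_call(fn, top=10):
    """Run fn() under cProfile and return a text report of the
    `top` entries sorted by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        fn()
    finally:
        profiler.disable()
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats("cumulative").print_stats(top)
    return buffer.getvalue()
```

With the objects from the code sample, this could be used as `print(profile_call(lambda: rsf.predict_survival_function(X_test)))`. cProfile only records the calling thread, so work done in joblib worker threads or processes is not captured; profiling a forest built with n_jobs=1 is likely to give the clearest picture.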
Actual Results
n_jobs=8 is slower than n_jobs=1
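To put numbers on this comparison, a small timing harness along the following lines might help. It is a sketch under assumptions: `best_time`, `speedup` and `benchmark_rsf` are hypothetical helper names introduced here, the data is generated the same way as in the code sample but with a `'<f8'` time field, and the n_jobs values 1 and 8 are simply the ones discussed in this report.

```python
import time


def best_time(fn, repeats=5):
    """Return the smallest wall-clock time (seconds) over `repeats` calls of fn()."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return min(timings)


def speedup(baseline_seconds, candidate_seconds):
    """Ratio above 1 means the candidate is faster than the baseline."""
    return baseline_seconds / candidate_seconds


def benchmark_rsf(n_jobs_values=(1, 8), n_samples=150, n_features=8, repeats=5):
    """Time predict and predict_survival_function for each n_jobs value."""
    # Imported lazily so the helpers above work without scikit-survival installed.
    import numpy as np
    from sksurv.ensemble import RandomSurvivalForest

    rng = np.random.default_rng(42)
    X = rng.random((n_samples, n_features))
    y = np.empty(n_samples, dtype=[("event", "?"), ("time", "<f8")])
    y["time"] = rng.random(n_samples)
    y["event"] = X[:, 0] < y["time"]

    results = {}
    for n_jobs in n_jobs_values:
        rsf = RandomSurvivalForest(random_state=42, n_jobs=n_jobs).fit(X, y)
        results[n_jobs] = {
            "predict": best_time(lambda: rsf.predict(X), repeats),
            "predict_survival_function": best_time(
                lambda: rsf.predict_survival_function(X), repeats
            ),
        }
    return results
```

Calling `results = benchmark_rsf()` and then `speedup(results[1]["predict_survival_function"], results[8]["predict_survival_function"])` gives a ratio below 1 when n_jobs=8 is slower than n_jobs=1. Taking the minimum over several repeats reduces the influence of one-off scheduling noise.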
Versions
System:
    python: 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0]
    executable: /home/pierrick/PycharmProjects/venv/bin/python
    machine: Linux-5.19.0-45-generic-x86_64-with-glibc2.35

Python dependencies:
    sklearn: 1.2.2
    pip: 22.3.1
    setuptools: 65.5.1
    numpy: 1.24.3
    scipy: 1.10.1
    Cython: None
    pandas: 2.0.2
    matplotlib: 3.7.1
    joblib: 1.2.0
    threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
    user_api: openmp
    internal_api: openmp
    prefix: libgomp
    filepath: /home/pierrick/PycharmProjects/venv/lib/python3.10/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
    version: None
    num_threads: 20

    internal_api: openblas
    prefix: libopenblas
    filepath: /home/pierrick/PycharmProjects/venv/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-15028c96.3.21.so
    version: 0.3.21
    threading_layer: pthreads
    architecture: Prescott
    num_threads: 20

    internal_api: openblas
    prefix: libopenblas
    filepath: /home/pierrick/PycharmProjects/venv/lib/python3.10/site-packages/scipy.libs/libopenblasp-r0-41284840.3.18.so
    version: 0.3.18
    threading_layer: pthreads
    architecture: Prescott
    num_threads: 20

sksurv: 0.21.0
numexpr: 2.8.4
osqp: 0.6.3