RandomSurvivalForest is not consistent making predictions of survival functions at certain times #249

Closed
alonsosilvaallende opened this issue Feb 16, 2022 · 8 comments · Fixed by #375

alonsosilvaallende commented Feb 16, 2022

Describe the bug

RandomSurvivalForest is inconsistent when predicting survival functions (predict_survival_function) at certain times. I would expect RandomSurvivalForest to be able to evaluate survival functions at any time within the range of the training and validation times. However, although both the training and validation times span [0, 182], RandomSurvivalForest raises an error when the survival function is evaluated on times in [0, 175]. Note that this doesn't happen with CoxPHSurvivalAnalysis or GradientBoostingSurvivalAnalysis.

Code Sample to Reproduce the Bug

import numpy as np
import pandas as pd
import statsmodels.api as sm
pharmacoSmoking = sm.datasets.get_rdataset("pharmacoSmoking", "asaur")
data = pharmacoSmoking.data
data = data.drop(columns=["id","ageGroup2","ageGroup4"]) # Drop redundant information and ids
from sksurv.datasets import get_x_y
X, y = get_x_y(data, attr_labels=["relapse", "ttr"], pos_label=True)
for c in X.columns:
    if X[c].dtype.kind not in ['i', 'f']:
        X[c] = X[c].astype("category")
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sksurv.preprocessing import OneHotEncoder
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import integrated_brier_score
def get_ibs(seed, times):
  X_trn, X_val, y_trn, y_val = train_test_split(X, y, random_state=seed)
  print(f"Minimum training time: {y_trn['ttr'].min()}")
  print(f"Maximum training time: {y_trn['ttr'].max()}")
  print(f"Minimum validation time: {y_val['ttr'].min()}")
  print(f"Maximum validation time: {y_val['ttr'].max()}")
  enc = OneHotEncoder()
  scaler = StandardScaler()
  X_trn = enc.fit_transform(X_trn)
  X_trn = pd.DataFrame(scaler.fit_transform(X_trn), columns=X_trn.columns)
  X_val = enc.transform(X_val)
  X_val = pd.DataFrame(scaler.transform(X_val), columns=X_val.columns)
  rsf = RandomSurvivalForest(random_state=42)
  rsf.fit(X_trn, y_trn)
  survs = rsf.predict_survival_function(X_val)
  preds = np.asarray([[fn(t) for t in times] for fn in survs])
  return integrated_brier_score(y_trn, y_val, preds, times)
print(f"Integrated Brier Score: {get_ibs(0, np.arange(0,170))}")
print(f"Integrated Brier Score: {get_ibs(0, np.arange(0,175))}")

Expected Results

Minimum training time: 0.0
Maximum training time: 182.0
Minimum validation time: 0.0
Maximum validation time: 182.0
Integrated Brier Score: 0.21497683508486282
Minimum training time: 0.0
Maximum training time: 182.0
Minimum validation time: 0.0
Maximum validation time: 182.0
Integrated Brier Score: 0.21497683508486282

Actual Results

Minimum training time: 0.0
Maximum training time: 182.0
Minimum validation time: 0.0
Maximum validation time: 182.0
Integrated Brier Score: 0.21497683508486282
Minimum training time: 0.0
Maximum training time: 182.0
Minimum validation time: 0.0
Maximum validation time: 182.0

ValueError                                Traceback (most recent call last)
<ipython-input-3-fbfef721f54b> in <module>()
     33   return integrated_brier_score(y_trn, y_val, preds, times)
     34 print(f"Integrated Brier Score: {get_ibs(0, np.arange(0,170))}")
---> 35 print(f"Integrated Brier Score: {get_ibs(0, np.arange(0,175))}")

3 frames
/usr/local/lib/python3.7/dist-packages/sksurv/functions.py in __call__(self, x)
     65         if numpy.min(x) < self.x[0] or numpy.max(x) > self.x[-1]:
     66             raise ValueError(
---> 67                 "x must be within [%f; %f]" % (self.x[0], self.x[-1]))
     68         i = numpy.searchsorted(self.x, x, side='left')
     69         not_exact = self.x[i] != x

ValueError: x must be within [0.000000; 170.000000]

Versions
Please execute the following snippet and paste the output below.

import sklearn; sklearn.show_versions()
import sksurv; print("sksurv:", sksurv.__version__)
import cvxopt; print("cvxopt:", cvxopt.__version__)
import cvxpy; print("cvxpy:", cvxpy.__version__)
import numexpr; print("numexpr:", numexpr.__version__)
import osqp; print("osqp:", osqp.OSQP().version())

System:
python: 3.7.12 (default, Jan 15 2022, 18:48:18) [GCC 7.5.0]
executable: /usr/bin/python3
machine: Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic

Python dependencies:
pip: 21.1.3
setuptools: 57.4.0
sklearn: 1.0.2
numpy: 1.21.5
scipy: 1.4.1
Cython: 0.29.27
pandas: 1.3.5
matplotlib: 3.2.2
joblib: 1.1.0
threadpoolctl: 3.1.0

Built with OpenMP: True
sksurv: 0.17.0
cvxopt: 1.2.7
cvxpy: 1.0.31
numexpr: 2.8.1
osqp: 0.6.2

@sebp
Owner

sebp commented Sep 7, 2022

There's indeed an inconsistency. For RandomSurvivalForest and SurvivalTree, the domain of the survival function is limited to the range of event times only, whereas for CoxPHSurvivalAnalysis it covers the range of all time points, whether an event has been observed or not.
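
To make the distinction concrete, here is a minimal sketch in plain Python (it does not use scikit-survival). The helper `time_ranges` and the toy data are hypothetical. The data are chosen so that the last observed event falls at 170 while a censored subject is followed until 182, which would be consistent with the `[0; 170]` bound in the error message above, although the issue does not confirm where the last event in the actual data lies.

```python
def time_ranges(times, events):
    """Return ((min, max) over all times, (min, max) over event times only)."""
    event_times = [t for t, observed in zip(times, events) if observed]
    all_range = (min(times), max(times))
    event_range = (min(event_times), max(event_times))
    return all_range, event_range


# Hypothetical toy data: True marks an observed event, False a censored time.
times = [0, 5, 60, 170, 182]
events = [True, True, False, True, False]

all_range, event_range = time_ranges(times, events)
print(all_range)    # (0, 182)
print(event_range)  # (0, 170)
```

A model whose survival function is defined only on the event-time range would reject a query at 175 here, while one defined on the all-times range would accept it.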

@cpoerschke
Contributor

#363 adds a test describing (my understanding of) the current behaviour described in this issue.

Wondering what a fix might look like?

  • All models to support all-times range?
  • All models to only support event-times range?
  • Some models supporting all-times range and other models only supporting event-times range with some way to tell which is supported before making the predict_survival_function or fn(t) call?
  • Something else?

@alonsosilvaallende
Author

@cpoerschke I guess the question goes to @sebp. For me, however, all models should support the all-times range. If they give wrong predictions outside the training range, so be it.

@sebp
Owner

sebp commented May 16, 2023

If by "all-times range" you mean the range of the training data (disregarding censoring status), I would agree.

It would also be worthwhile to check how the random survival forest or gradient boosting implementations in R handle this.

@alonsosilvaallende
Author

Thank you very much, @sebp. I don't know how those implementations in R handle it. However, scikit-learn's random forest model doesn't put constraints on users who want to go outside the range of the training data (even though the results are obviously not good outside the training region). Here is a gist of that behavior that you can run in Colab as well: https://gist.github.com/alonsosilvaallende/ef813de35a8b0f4328b451aea46b9c48

@cpoerschke
Contributor

414354a explores clipping inputs on an opt-in basis.
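
For readers unfamiliar with the idea, opt-in clipping could look roughly like the sketch below. This is an illustration only, not the code from the commit mentioned above: the function `evaluate_clipped`, its `clip` flag, and the toy survival curve are all hypothetical names made up for this example.

```python
def evaluate_clipped(fn, t, lower, upper, clip=False):
    """Evaluate fn at t; with clip=True, first move t into [lower, upper]."""
    if clip:
        # Opt-in: silently replace out-of-range times by the nearest bound.
        t = min(max(t, lower), upper)
    elif not lower <= t <= upper:
        # Default: refuse to evaluate outside the domain.
        raise ValueError(f"t must be within [{lower}; {upper}]")
    return fn(t)


def toy_survival(t):
    """Hypothetical survival curve, assumed to be defined on [0, 170]."""
    return 1.0 - t / 340.0


# With clipping, a query at 175 is answered with the value at 170.
print(evaluate_clipped(toy_survival, 175, 0, 170, clip=True))
```

Keeping the strict check as the default means that out-of-range queries still fail loudly unless the caller explicitly asks for clipping.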

@sebp
Owner

sebp commented Jun 5, 2023

@cpoerschke Thanks for adding the clipping option. This might be a good option for advanced users that are aware of the risks. For time points that are smaller than the smallest time point in the training data, it would actually be okay to predict a survival probability of 1. For time points beyond the far end of the time axis, predictions would be very speculative.

As for RandomSurvivalForest currently limiting the interval to event times, I'm probably going to fix that together with #343.

sebp added a commit that referenced this issue Jun 11, 2023
Values outside the interval specified by `self.domain`
will raise an exception.
Values in `x` that are in the interval `[self.domain[0]; self.x[0]]`
get mapped to `self.y[0]`.
By default, the lower bound is set to 0.

Fixes #249
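
The behaviour described in this commit message can be sketched with a small stand-alone class. This is a toy model written for illustration, not the scikit-survival `StepFunction` itself: the class name `ToyStepFunction`, the `domain_lower` parameter, and the assumption that the function is right-continuous (each value holds from its time point up to the next one) are all choices made for this example.

```python
from bisect import bisect_right


class ToyStepFunction:
    """Right-continuous step function with an explicit domain (illustrative only)."""

    def __init__(self, x, y, domain_lower=0.0):
        if not x or len(x) != len(y):
            raise ValueError("x and y must be non-empty and of equal length")
        self.x = list(x)  # sorted time points
        self.y = list(y)  # value taken from each time point onwards
        # Lower bound defaults to 0; upper bound is the last time point.
        self.domain = (domain_lower, self.x[-1])

    def __call__(self, t):
        lower, upper = self.domain
        if t < lower or t > upper:
            raise ValueError(f"t must be within [{lower}; {upper}]")
        # Index of the last time point <= t; -1 if t lies before x[0].
        i = bisect_right(self.x, t) - 1
        # Values in [domain[0]; x[0]] get mapped to y[0].
        return self.y[max(i, 0)]


fn = ToyStepFunction([5, 60, 170], [0.9, 0.6, 0.3])
print(fn(0))    # 0.9, because 0 lies in [domain[0]; x[0]]
print(fn(60))   # 0.6
print(fn(170))  # 0.3
```

Calling `fn(171)` or `fn(-1)` in this sketch raises a `ValueError`, since both lie outside `fn.domain`.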
@sebp closed this as completed in #375 Jun 11, 2023
@sebp
Owner

sebp commented Jun 11, 2023

#375 ensures that all estimators behave the same. It does not allow for evaluating the survival function (or CHF) beyond the maximum time point in the training data. Evaluating below the minimum time point is allowed, if the value is non-negative.
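
Under the rule stated above, a caller can check a grid of evaluation times before passing it to the survival function. The sketch below is one way to do that in plain Python. The helper `split_times` is a hypothetical name, and `max_train_time` is assumed to be the largest time point in the training data (182 in the example from this issue). It does not check any further requirements a particular metric may place on the time grid.

```python
def split_times(times, max_train_time):
    """Split times into those inside [0, max_train_time] and those outside."""
    valid, rejected = [], []
    for t in times:
        # Non-negative and not beyond the last training time point.
        if 0 <= t <= max_train_time:
            valid.append(t)
        else:
            rejected.append(t)
    return valid, rejected


valid, rejected = split_times(range(-1, 186), max_train_time=182)
print(valid[0], valid[-1])  # 0 182
print(rejected)             # [-1, 183, 184, 185]
```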

@cpoerschke Feel free to open a new issue/PR to discuss the clipping approach you proposed.
