From 6ceabbd96050b49e15fb5f0c7c05bee3360a823e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?= Date: Sun, 11 Jun 2023 10:06:59 +0200 Subject: [PATCH] Allow specifying domain of StepFunction Values outside the interval specified by `self.domain` will raise an exception. Values in `x` that are in the interval `[self.domain[0]; self.x[0]]` get mapped to `self.y[0]`. By default, the lower bound is set to 0. Fixes #249 --- sksurv/functions.py | 38 +++++++++++- tests/test_functions.py | 124 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 147 insertions(+), 15 deletions(-) diff --git a/sksurv/functions.py b/sksurv/functions.py index 5df79df1..ee978cc8 100644 --- a/sksurv/functions.py +++ b/sksurv/functions.py @@ -38,18 +38,46 @@ class StepFunction: b : float, optional, default: 0.0 Constant offset term. + + domain : tuple, optional + A tuple with two entries that sets the limits of the + domain of the step function. + If entry is `None`, use the first/last value of `x` as limit. """ - def __init__(self, x, y, *, a=1.0, b=0.0): + def __init__(self, x, y, *, a=1.0, b=0.0, domain=(0, None)): check_consistent_length(x, y) self.x = x self.y = y self.a = a self.b = b + domain_lower = self.x[0] if domain[0] is None else domain[0] + domain_upper = self.x[-1] if domain[1] is None else domain[1] + self._domain = (float(domain_lower), float(domain_upper)) + + @property + def domain(self): + """Returns the domain of the function, that means + the range of values that the function accepts. + + Returns + ------- + lower_limit : float + Lower limit of domain. + + upper_limit : float + Upper limit of domain. + """ + return self._domain def __call__(self, x): """Evaluate step function. + Values outside the interval specified by `self.domain` + will raise an exception. + Values in `x` that are in the interval `[self.domain[0]; self.x[0]]` + get mapped to `self.y[0]`. + Parameters ---------- x : float|array-like, shape=(n_values,) @@ -63,8 +91,12 @@ def __call__(self, x): x = np.atleast_1d(x) if not np.isfinite(x).all(): raise ValueError("x must be finite") - if np.min(x) < self.x[0] or np.max(x) > self.x[-1]: - raise ValueError(f"x must be within [{self.x[0]:f}; {self.x[-1]:f}]") + if np.min(x) < self._domain[0] or np.max(x) > self.domain[1]: + raise ValueError(f"x must be within [{self.domain[0]:f}; {self.domain[1]:f}]") + + # x is within the domain, but we need to account for self.domain[0] <= x < self.x[0] + x = np.clip(x, a_min=self.x[0], a_max=None) + i = np.searchsorted(self.x, x, side="left") not_exact = self.x[i] != x i[not_exact] -= 1 diff --git a/tests/test_functions.py b/tests/test_functions.py index 1467d094..7414ae14 100644 --- a/tests/test_functions.py +++ b/tests/test_functions.py @@ -3,6 +3,7 @@ import pytest from sksurv.functions import StepFunction +from sksurv.testing import all_survival_estimators @pytest.fixture() @@ -13,6 +14,25 @@ def a_step_function(): return f +@pytest.fixture() +def toy_data_exponential(): + rnd = np.random.RandomState(2) + n_samples = 100 + x = rnd.randn(n_samples, 2) + y = np.empty(n_samples, dtype=[("event", bool), ("time", float)]) + y["time"] = rnd.exponential(scale=np.exp(x[:, 0]), size=n_samples) + y["event"] = rnd.binomial(1, 0.5, size=n_samples) == 1 + + # ensure at least 2 uncensored events exist + y["event"][:2] = True + + # mark entry with largest time as censored + # see https://github.com/sebp/scikit-survival/issues/249 + idxmax = np.argmax(y["time"]) + y["event"][idxmax] = False + return x, y + + class TestStepFunction: @staticmethod def test_exact(a_step_function): @@ -26,18 +46,14 @@ def test_not_exact(a_step_function): assert_array_equal(actual, a_step_function.y[:-1]) @staticmethod - def test_out_of_bounds(a_step_function): - eps = np.finfo(float).eps * 8 - values = [ - a_step_function.x[0] - 100, - a_step_function.x[-1] + 100, - a_step_function.x[0] - eps, - a_step_function.x[-1] + eps, - ] - - for v in values: - with pytest.raises(ValueError, match=r"x must be within \[0.0+; 9.0+\]"): - a_step_function(v) + @pytest.mark.parametrize("value", [-100, 100, -np.finfo(float).eps * 8, np.finfo(float).eps * 8]) + def test_out_of_bounds(a_step_function, value): + v = value + if v > 0: + v += a_step_function.domain[1] + + with pytest.raises(ValueError, match=r"x must be within \[0\.0+; 9\.0+\]"): + a_step_function(v) @staticmethod def test_not_finite(a_step_function, non_finite_value): @@ -56,3 +72,87 @@ def test_equal(a_step_function): assert a_step_function != different_step_function assert a_step_function != x + + +@pytest.mark.parametrize( + "estimator_cls", [est for est in all_survival_estimators() if hasattr(est, "predict_cumulative_hazard_function")] +) +def test_predict_cumulative_hazard_function_range(estimator_cls, toy_data_exponential): + x, y = toy_data_exponential + + estimator = estimator_cls() + if "fit_baseline_model" in estimator.get_params(): + estimator.set_params(fit_baseline_model=True) + estimator.fit(x, y) + + t_min = y["time"].min() + t_max = y["time"].max() + t_mid = (t_max - t_min) / 2.0 + + for fn in estimator.predict_cumulative_hazard_function(x): + v = fn(t_min) + assert np.isfinite(v) + assert v == 0 + + for fn in estimator.predict_cumulative_hazard_function(x): + v = fn(t_mid) + assert np.isfinite(v) + assert v >= 0 + + t_smaller_min = t_min / 2 + for fn in estimator.predict_cumulative_hazard_function(x): + v = fn(t_smaller_min) + assert np.isfinite(v) + assert v == 0 + + for fn in estimator.predict_cumulative_hazard_function(x): + v = fn(t_max) + assert np.isfinite(v) + assert v >= 0 + + t_bigger_max = t_max + 1 + for fn in estimator.predict_cumulative_hazard_function(x): + with pytest.raises(ValueError, match=r"x must be within \[[0-9.]+; [0-9.]+\]"): + fn(t_bigger_max) + + +@pytest.mark.parametrize( + "estimator_cls", [est for est in all_survival_estimators() if hasattr(est, "predict_survival_function")] +) +def test_predict_survival_function_range(estimator_cls, toy_data_exponential): + x, y = toy_data_exponential + + estimator = estimator_cls() + if "fit_baseline_model" in estimator.get_params(): + estimator.set_params(fit_baseline_model=True) + estimator.fit(x, y) + + t_min = y["time"].min() + t_max = y["time"].max() + t_mid = (t_max - t_min) / 2.0 + + for fn in estimator.predict_survival_function(x): + v = fn(t_min) + assert np.isfinite(v) + assert v == 1 + + for fn in estimator.predict_survival_function(x): + v = fn(t_mid) + assert np.isfinite(v) + assert 0 <= v <= 1 + + t_smaller_min = t_min / 2 + for fn in estimator.predict_survival_function(x): + v = fn(t_smaller_min) + assert np.isfinite(v) + assert v == 1 + + for fn in estimator.predict_survival_function(x): + v = fn(t_max) + assert np.isfinite(v) + assert 0 <= v <= 1 + + t_bigger_max = t_max + 1 + for fn in estimator.predict_survival_function(x): + with pytest.raises(ValueError, match=r"x must be within \[[0-9.]+; [0-9.]+\]"): + fn(t_bigger_max)