Skip to content

Commit

Permalink
Allow specifying domain of StepFunction
Browse files Browse the repository at this point in the history
Values outside the interval specified by `self.domain`
will raise an exception.
Values in `x` that are in the interval `[self.domain[0]; self.x[0]]`
get mapped to `self.y[0]`.
By default, the lower bound is set to 0.

Fixes #249
  • Loading branch information
sebp committed Jun 11, 2023
1 parent 4021b27 commit 4a7d6ae
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 15 deletions.
38 changes: 35 additions & 3 deletions sksurv/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,46 @@ class StepFunction:
b : float, optional, default: 0.0
Constant offset term.
domain : tuple, optional
A tuple with two entries that sets the limits of the
domain of the step function.
If entry is `None`, use the first/last value of `x` as limit.
"""

def __init__(self, x, y, *, a=1.0, b=0.0):
def __init__(self, x, y, *, a=1.0, b=0.0, domain=(0, None)):
check_consistent_length(x, y)
self.x = x
self.y = y
self.a = a
self.b = b
domain_lower = self.x[0] if domain[0] is None else domain[0]
domain_upper = self.x[-1] if domain[1] is None else domain[1]
self._domain = (float(domain_lower), float(domain_upper))

@property
def domain(self):
"""Returns the domain of the function, that means
the range of values that the function accepts.
Returns
-------
lower_limit : float
Lower limit of domain.
upper_limit : float
Upper limit of domain.
"""
return self._domain

def __call__(self, x):
"""Evaluate step function.
Values outside the interval specified by `self.domain`
will raise an exception.
Values in `x` that are in the interval `[self.domain[0]; self.x[0]]`
get mapped to `self.y[0]`.
Parameters
----------
x : float|array-like, shape=(n_values,)
Expand All @@ -63,8 +91,12 @@ def __call__(self, x):
x = np.atleast_1d(x)
if not np.isfinite(x).all():
raise ValueError("x must be finite")
if np.min(x) < self.x[0] or np.max(x) > self.x[-1]:
raise ValueError(f"x must be within [{self.x[0]:f}; {self.x[-1]:f}]")
if np.min(x) < self._domain[0] or np.max(x) > self.domain[1]:
raise ValueError(f"x must be within [{self.domain[0]:f}; {self.domain[1]:f}]")

# x is within the domain, but we need to account for self.domain[0] <= x < self.x[0]
x = np.clip(x, a_min=self.x[0], a_max=None)

i = np.searchsorted(self.x, x, side="left")
not_exact = self.x[i] != x
i[not_exact] -= 1
Expand Down
124 changes: 112 additions & 12 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from sksurv.functions import StepFunction
from sksurv.testing import all_survival_estimators


@pytest.fixture()
Expand All @@ -13,6 +14,25 @@ def a_step_function():
return f


@pytest.fixture()
def toy_data_exponential():
rnd = np.random.RandomState(2)
n_samples = 100
x = rnd.randn(n_samples, 2)
y = np.empty(n_samples, dtype=[("event", bool), ("time", float)])
y["time"] = rnd.exponential(scale=np.exp(x[:, 0]), size=n_samples)
y["event"] = rnd.binomial(1, 0.5, size=n_samples) == 1

# ensure at least 2 uncensored events exist
y["event"][:2] = True

# mark entry with largest time as censored
# see https://github.com/sebp/scikit-survival/issues/249
idxmax = np.argmax(y["time"])
y["event"][idxmax] = False
return x, y


class TestStepFunction:
@staticmethod
def test_exact(a_step_function):
Expand All @@ -26,18 +46,14 @@ def test_not_exact(a_step_function):
assert_array_equal(actual, a_step_function.y[:-1])

@staticmethod
def test_out_of_bounds(a_step_function):
eps = np.finfo(float).eps * 8
values = [
a_step_function.x[0] - 100,
a_step_function.x[-1] + 100,
a_step_function.x[0] - eps,
a_step_function.x[-1] + eps,
]

for v in values:
with pytest.raises(ValueError, match=r"x must be within \[0.0+; 9.0+\]"):
a_step_function(v)
@pytest.mark.parametrize("value", [-100, 100, -np.finfo(float).eps * 8, np.finfo(float).eps * 8])
def test_out_of_bounds(a_step_function, value):
v = value
if v > 0:
v += a_step_function.domain[1]

with pytest.raises(ValueError, match=r"x must be within \[0\.0+; 9\.0+\]"):
a_step_function(v)

@staticmethod
def test_not_finite(a_step_function, non_finite_value):
Expand All @@ -56,3 +72,87 @@ def test_equal(a_step_function):
assert a_step_function != different_step_function

assert a_step_function != x


@pytest.mark.parametrize(
"estimator_cls", [est for est in all_survival_estimators() if hasattr(est, "predict_cumulative_hazard_function")]
)
def test_predict_cumulative_hazard_function_range(estimator_cls, toy_data_exponential):
x, y = toy_data_exponential

estimator = estimator_cls()
if "fit_baseline_model" in estimator.get_params():
estimator.set_params(fit_baseline_model=True)
estimator.fit(x, y)

t_min = y["time"].min()
t_max = y["time"].max()
t_mid = (t_max - t_min) / 2.0

for fn in estimator.predict_cumulative_hazard_function(x):
v = fn(t_min)
assert np.isfinite(v)
assert v == 0

for fn in estimator.predict_cumulative_hazard_function(x):
v = fn(t_mid)
assert np.isfinite(v)
assert v >= 0

t_smaller_min = t_min / 2
for fn in estimator.predict_cumulative_hazard_function(x):
v = fn(t_smaller_min)
assert np.isfinite(v)
assert v == 0

for fn in estimator.predict_cumulative_hazard_function(x):
v = fn(t_max)
assert np.isfinite(v)
assert v >= 0

t_bigger_max = t_max + 1
for fn in estimator.predict_cumulative_hazard_function(x):
with pytest.raises(ValueError, match=r"x must be within \[[0-9.]+; [0-9.]+\]"):
fn(t_bigger_max)


@pytest.mark.parametrize(
"estimator_cls", [est for est in all_survival_estimators() if hasattr(est, "predict_survival_function")]
)
def test_predict_survival_function_range(estimator_cls, toy_data_exponential):
x, y = toy_data_exponential

estimator = estimator_cls()
if "fit_baseline_model" in estimator.get_params():
estimator.set_params(fit_baseline_model=True)
estimator.fit(x, y)

t_min = y["time"].min()
t_max = y["time"].max()
t_mid = (t_max - t_min) / 2.0

for fn in estimator.predict_survival_function(x):
v = fn(t_min)
assert np.isfinite(v)
assert v == 1

for fn in estimator.predict_survival_function(x):
v = fn(t_mid)
assert np.isfinite(v)
assert 0 <= v <= 1

t_smaller_min = t_min / 2
for fn in estimator.predict_survival_function(x):
v = fn(t_smaller_min)
assert np.isfinite(v)
assert v == 1

for fn in estimator.predict_survival_function(x):
v = fn(t_max)
assert np.isfinite(v)
assert 0 <= v <= 1

t_bigger_max = t_max + 1
for fn in estimator.predict_survival_function(x):
with pytest.raises(ValueError, match=r"x must be within \[[0-9.]+; [0-9.]+\]"):
fn(t_bigger_max)

0 comments on commit 4a7d6ae

Please sign in to comment.