From 8e1d9b258a3e4be3076ee2ba81d832e45fba792e Mon Sep 17 00:00:00 2001 From: Raphael Erik Hviding Date: Thu, 19 Dec 2024 00:55:40 +0100 Subject: [PATCH] Implement ICDF Methods for Truncated Distributions (#1938) * Add icdf methods to generic truncated distributions * Restricted icdf function to [0,1] --- numpyro/distributions/truncated.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index adcd7f2ad..78db31d6a 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -67,11 +67,15 @@ def sample(self, key, sample_shape=()): finfo = jnp.finfo(dtype) minval = finfo.tiny u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval) + return self.icdf(u) + + def icdf(self, q): loc = self.base_dist.loc sign = jnp.where(loc >= self.low, 1.0, -1.0) - return (1 - sign) * loc + sign * self.base_dist.icdf( - (1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high + ppf = (1 - sign) * loc + sign * self.base_dist.icdf( + (1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high ) + return jnp.where(q < 0, jnp.nan, ppf) @validate_sample def log_prob(self, value): @@ -138,7 +142,11 @@ def sample(self, key, sample_shape=()): finfo = jnp.finfo(dtype) minval = finfo.tiny u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval) - return self.base_dist.icdf(u * self._cdf_at_high) + return self.icdf(u) + + def icdf(self, q): + ppf = self.base_dist.icdf(q * self._cdf_at_high) + return jnp.where(q > 1, jnp.nan, ppf) @validate_sample def log_prob(self, value): @@ -235,19 +243,22 @@ def sample(self, key, sample_shape=()): finfo = jnp.finfo(dtype) minval = finfo.tiny u = random.uniform(key, shape=sample_shape + self.batch_shape, minval=minval) + return self.icdf(u) + def icdf(self, q): # NB: we use a more numerically stable formula for a symmetric base distribution - # A = icdf(cdf(low) + (cdf(high) - cdf(low)) * u) = icdf[(1 - u) * cdf(low) + u * cdf(high)] + # A = icdf(cdf(low) + (cdf(high) - cdf(low)) * q) = icdf[(1 - q) * cdf(low) + q * cdf(high)] # will suffer by precision issues when low is large; # If low < loc: - # A = icdf[(1 - u) * cdf(low) + u * cdf(high)] + # A = icdf[(1 - q) * cdf(low) + q * cdf(high)] # Else - # A = 2 * loc - icdf[(1 - u) * cdf(2*loc-low)) + u * cdf(2*loc - high)] + # A = 2 * loc - icdf[(1 - q) * cdf(2*loc-low)) + q * cdf(2*loc - high)] loc = self.base_dist.loc sign = jnp.where(loc >= self.low, 1.0, -1.0) - return (1 - sign) * loc + sign * self.base_dist.icdf( - clamp_probs((1 - u) * self._tail_prob_at_low + u * self._tail_prob_at_high) + ppf = (1 - sign) * loc + sign * self.base_dist.icdf( + clamp_probs((1 - q) * self._tail_prob_at_low + q * self._tail_prob_at_high) ) + return jnp.where(jnp.logical_or(q < 0, q > 1), jnp.nan, ppf) @validate_sample def log_prob(self, value):