
Add type hints to the primitives module #1940

Merged
19 commits merged on Dec 23, 2024
16 changes: 6 additions & 10 deletions numpyro/contrib/hsgp/approximation.py
@@ -21,26 +21,22 @@
import numpyro.distributions as dist


def _non_centered_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
) -> Array:
def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=1.0))

return phi @ (spd * beta)


def _centered_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int]
) -> Array:
def _centered_approximation(phi: Array, spd: Array, m: int) -> Array:
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

return phi @ beta


def linear_approximation(
phi: ArrayLike, spd: ArrayLike, m: int | list[int], non_centered: bool = True
phi: Array, spd: Array, m: int, non_centered: bool = True
) -> Array:
"""
Linear approximation formula of the Hilbert space Gaussian process.
@@ -52,10 +48,10 @@ def linear_approximation(
1. Riutort-Mayol, G., Bürkner, PC., Andersen, M.R. et al. Practical Hilbert space
approximate Bayesian Gaussian processes for probabilistic programming. Stat Comput 33, 17 (2023).

:param ArrayLike phi: laplacian eigenfunctions
:param ArrayLike spd: square root of the diagonal of the spectral density evaluated at square
:param Array phi: laplacian eigenfunctions
:param Array spd: square root of the diagonal of the spectral density evaluated at square
root of the first `m` eigenvalues.
:param int | list[int] m: number of eigenfunctions in the approximation
:param int m: number of eigenfunctions in the approximation
:param bool non_centered: whether to use a non-centered parameterization
:return: The low-rank approximation linear model
:rtype: Array
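For context, a minimal sketch of how `linear_approximation` is typically called inside a model with the new `Array` annotations. The model below and its priors are illustrative assumptions, not part of this PR; `phi` and `spd` are assumed to be precomputed (e.g. with the Laplacian-eigenfunction and spectral-density helpers elsewhere in `numpyro.contrib.hsgp`).

```python
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.hsgp.approximation import linear_approximation


def model(phi, spd, m, y=None):
    # phi: (n, m) Laplacian eigenfunctions evaluated at the inputs
    # spd: (m,) square root of the spectral density at the first m eigenvalues
    f = numpyro.deterministic("f", linear_approximation(phi, spd, m, non_centered=True))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("y", dist.Normal(f, sigma), obs=y)
```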
19 changes: 10 additions & 9 deletions numpyro/diagnostics.py
@@ -10,6 +10,7 @@
from typing import Union

import numpy as np
from numpy.typing import NDArray

import jax
from jax import device_get
@@ -25,7 +26,7 @@
]


def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def _compute_chain_variance_stats(x: NDArray) -> tuple[NDArray, NDArray]:
# compute within-chain variance and variance estimator
# input has shape C x N x sample_shape
C, N = x.shape[:2]
@@ -41,7 +42,7 @@ def _compute_chain_variance_stats(x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return var_within, var_estimator


def gelman_rubin(x: np.ndarray) -> np.ndarray:
def gelman_rubin(x: NDArray) -> NDArray:
"""
Computes R-hat over chains of samples ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
@@ -60,7 +61,7 @@ def gelman_rubin(x: np.ndarray) -> np.ndarray:
return rhat


def split_gelman_rubin(x: np.ndarray) -> np.ndarray:
def split_gelman_rubin(x: NDArray) -> NDArray:
"""
Computes split R-hat over chains of samples ``x``, where the first dimension
of ``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
@@ -97,7 +98,7 @@ def _fft_next_fast_len(target: int) -> int:
target += 1


def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
def autocorrelation(x: NDArray, axis: int = 0, bias: bool = True) -> NDArray:
"""
Computes the autocorrelation of samples at dimension ``axis``.

@@ -137,11 +138,11 @@ def autocorrelation(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
autocorr = autocorr / np.arange(N, 0.0, -1)

with np.errstate(invalid="ignore", divide="ignore"):
autocorr = autocorr / autocorr[..., :1]
autocorr = (autocorr / autocorr[..., :1]).astype(np.float64)
return np.swapaxes(autocorr, axis, -1)


def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
def autocovariance(x: NDArray, axis: int = 0, bias: bool = True) -> NDArray:
"""
Computes the autocovariance of samples at dimension ``axis``.

@@ -154,7 +155,7 @@ def autocovariance(x: np.ndarray, axis: int = 0, bias: bool = True) -> np.ndarray:
return autocorrelation(x, axis, bias) * x.var(axis=axis, keepdims=True)


def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
def effective_sample_size(x: NDArray, bias: bool = True) -> NDArray:
"""
Computes effective sample size of input ``x``, where the first dimension of
``x`` is chain dimension and the second dimension of ``x`` is draw dimension.
@@ -202,7 +203,7 @@ def effective_sample_size(x: np.ndarray, bias: bool = True) -> np.ndarray:
return n_eff


def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:
def hpdi(x: NDArray, prob: float = 0.90, axis: int = 0) -> NDArray:
"""
Computes "highest posterior density interval" (HPDI) which is the narrowest
interval with probability mass ``prob``.
@@ -285,7 +286,7 @@ def summary(


def print_summary(
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, NDArray], prob: float = 0.90, group_by_chain: bool = True
) -> None:
"""
Prints a summary table displaying diagnostics of ``samples`` from the
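As a quick sanity check of the annotated signatures (a sketch, not part of the diff), the diagnostics take a chains-by-draws NumPy array and return plain NumPy arrays, matching the `NDArray` annotations above:

```python
import numpy as np

from numpyro.diagnostics import autocorrelation, effective_sample_size, gelman_rubin, hpdi

samples = np.random.default_rng(0).normal(size=(4, 1000))  # 4 chains, 1000 draws each

rhat = gelman_rubin(samples)             # ~1.0 for independent draws
n_eff = effective_sample_size(samples)   # effective sample size pooled across chains
acf = autocorrelation(samples, axis=1)   # autocorrelation along the draw dimension
lo, hi = hpdi(samples.reshape(-1), prob=0.90)  # 90% highest posterior density interval
```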
62 changes: 50 additions & 12 deletions numpyro/distributions/distribution.py
@@ -24,11 +24,11 @@
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

from collections import OrderedDict
from contextlib import contextmanager
import functools
import inspect
from typing import Any, Protocol, runtime_checkable
import warnings

import numpy as np
@@ -37,6 +37,7 @@
from jax import lax, tree_util
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.typing import ArrayLike

from numpyro.distributions.transforms import AbsTransform, ComposeTransform, Transform
from numpyro.distributions.util import (
@@ -270,7 +271,7 @@ def validate_args(self, strict: bool = True) -> None:
raise RuntimeError("Cannot validate arguments inside jitted code.")

@property
def batch_shape(self):
def batch_shape(self) -> tuple[int, ...]:
"""
Returns the shape over which the distribution parameters are batched.

@@ -280,7 +281,7 @@ def event_shape(self):
return self._batch_shape

@property
def event_shape(self):
def event_shape(self) -> tuple[int, ...]:
"""
Returns the shape of a single sample from the distribution without
batching.
@@ -291,24 +292,24 @@ def event_shape(self):
return self._event_shape

@property
def event_dim(self):
def event_dim(self) -> int:
"""
:return: Number of dimensions of individual events.
:rtype: int
"""
return len(self.event_shape)

@property
def has_rsample(self):
def has_rsample(self) -> bool:
return set(self.reparametrized_params) == set(self.arg_constraints)

def rsample(self, key, sample_shape=()):
def rsample(self, key, sample_shape=()) -> ArrayLike:
if self.has_rsample:
return self.sample(key, sample_shape=sample_shape)

raise NotImplementedError

def shape(self, sample_shape=()):
def shape(self, sample_shape=()) -> tuple[int, ...]:
"""
The tensor shape of samples from this distribution.

@@ -323,7 +324,7 @@ def shape(self, sample_shape=()):
"""
return sample_shape + self.batch_shape + self.event_shape

def sample(self, key, sample_shape=()):
def sample(self, key: ArrayLike, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
"""
Returns a sample from the distribution having shape given by
`sample_shape + batch_shape + event_shape`. Note that when `sample_shape` is non-empty,
@@ -361,14 +362,14 @@ def log_prob(self, value):
raise NotImplementedError

@property
def mean(self):
def mean(self) -> ArrayLike:
"""
Mean of the distribution.
"""
raise NotImplementedError

@property
def variance(self):
def variance(self) -> ArrayLike:
"""
Variance of the distribution.
"""
@@ -540,7 +541,7 @@ def infer_shapes(cls, *args, **kwargs):
event_shape = ()
return batch_shape, event_shape

def cdf(self, value):
def cdf(self, value: ArrayLike) -> ArrayLike:
"""
The cumulative distribution function of this distribution.

@@ -549,7 +550,7 @@ def cdf(self, value):
"""
raise NotImplementedError

def icdf(self, q):
def icdf(self, q: ArrayLike) -> ArrayLike:
"""
The inverse cumulative distribution function of this distribution.

@@ -563,6 +564,43 @@ def is_discrete(self):
return self.support.is_discrete


@runtime_checkable
class DistributionLike(Protocol):
"""A protocol for typing distributions.

Used to type objects of type numpyro.distributions.Distribution, funsor.Funsor,
or tensorflow_probability.distributions.Distribution.
"""

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return super().__call__(*args, **kwargs)

@property
def batch_shape(self) -> tuple[int, ...]: ...

@property
def event_shape(self) -> tuple[int, ...]: ...

@property
def event_dim(self) -> int: ...

def sample(
self, key: ArrayLike, sample_shape: tuple[int, ...] = ()
) -> ArrayLike: ...

def log_prob(self, value: ArrayLike) -> ArrayLike: ...

@property
def mean(self) -> ArrayLike: ...

@property
def variance(self) -> ArrayLike: ...

def cdf(self, value: ArrayLike) -> ArrayLike: ...

def icdf(self, q: ArrayLike) -> ArrayLike: ...


class ExpandedDistribution(Distribution):
arg_constraints = {}
pytree_data_fields = ("base_dist",)
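To illustrate the intent of the `DistributionLike` protocol above (a hedged sketch, not code from this PR; the helper name `sample_and_score` is invented for the example), downstream code can annotate against the protocol rather than the concrete `Distribution` class, so any object exposing the same surface, whether from NumPyro, funsor, or TFP, is accepted:

```python
from jax import random
from jax.typing import ArrayLike

import numpyro.distributions as dist
from numpyro.distributions.distribution import DistributionLike


def sample_and_score(d: DistributionLike, key: ArrayLike) -> tuple[ArrayLike, ArrayLike]:
    """Draw a single sample from ``d`` and return it together with its log density."""
    x = d.sample(key, sample_shape=())
    return x, d.log_prob(x)


x, lp = sample_and_score(dist.Normal(0.0, 1.0), random.PRNGKey(0))
# The protocol is @runtime_checkable, so structural isinstance checks should also work:
assert isinstance(dist.Normal(0.0, 1.0), DistributionLike)
```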