diff --git a/neuralprocesses/dist/beta.py b/neuralprocesses/dist/beta.py index 8a830c75..31850831 100644 --- a/neuralprocesses/dist/beta.py +++ b/neuralprocesses/dist/beta.py @@ -90,7 +90,11 @@ def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: Masked): def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: B.Numeric, *, mask=1): logz = B.logbeta(self.alpha, self.beta) logpdf = (self.alpha - 1) * B.log(x) + (self.beta - 1) * B.log(1 - x) - logz - return B.sum(mask * logpdf, axis=tuple(range(B.rank(logpdf)))[-self.d :]) + logpdf = logpdf * mask + if self.d == 0: + return logpdf + else: + return B.sum(logpdf, axis=tuple(range(B.rank(logpdf)))[-self.d :]) def __str__(self): return f"Beta({self.alpha}, {self.beta})" diff --git a/neuralprocesses/dist/dist.py b/neuralprocesses/dist/dist.py index 3af4c933..1ef6f47c 100644 --- a/neuralprocesses/dist/dist.py +++ b/neuralprocesses/dist/dist.py @@ -31,6 +31,7 @@ def sample(self, state: B.RandomState, dtype: B.DType, *shape): state (random state, optional): Random state. tensor: Samples of shape `(*shape, *d)` where typically `d = (*b, c, n)`. """ + print(type(self), type(state), type(dtype), *shape) raise NotImplementedError(f"{self} cannot be sampled.") @_dispatch diff --git a/neuralprocesses/dist/gamma.py b/neuralprocesses/dist/gamma.py new file mode 100644 index 00000000..52127026 --- /dev/null +++ b/neuralprocesses/dist/gamma.py @@ -0,0 +1,117 @@ +import lab as B +from matrix.shape import broadcast +from plum import parametric + +from .. import _dispatch +from ..aggregate import Aggregate +from ..mask import Masked +from .dist import AbstractDistribution, shape_batch + +__all__ = ["Gamma"] + + +@parametric +class Gamma(AbstractDistribution): + """Gamma distribution. + + Args: + k (tensor): Shape parameter. + scale (tensor): Scale parameter. + d (int): Dimensionality of the data. + + Attributes: + k (tensor): Shape parameter. + scale (tensor): Scale parameter. + d (int): Dimensionality of the data. + """ + + def __init__(self, k, scale, d): + self.k = k + self.scale = scale + self.d = d + + @property + def mean(self): + return B.multiply(self.k, self.scale) + + @property + def var(self): + return B.multiply(B.multiply(self.k, self.scale), self.scale) + + @_dispatch + def sample( + self: "Gamma[Aggregate, Aggregate, Aggregate]", + state: B.RandomState, + dtype: B.DType, + *shape, + ): + samples = [] + for ki, si, di in zip(self.k, self.scale, self.d): + state, sample = Gamma(ki, si, di).sample(state, dtype, *shape) + samples.append(sample) + return state, Aggregate(*samples) + + @_dispatch + def sample( + self: "Gamma[B.Numeric, B.Numeric, B.Int]", + state: B.RandomState, + dtype: B.DType, + *shape, + ): + return B.randgamma(state, dtype, *shape, alpha=self.k, scale=self.scale) + + @_dispatch + def logpdf(self: "Gamma[Aggregate, Aggregate, Aggregate]", x: Aggregate): + return sum( + [ + Gamma(ki, si, di).logpdf(xi) + for ki, si, di, xi in zip(self.k, self.scale, self.d, x) + ], + 0, + ) + + @_dispatch + def logpdf(self: "Gamma[B.Numeric, B.Numeric, B.Int]", x: Masked): + x, mask = x.y, x.mask + with B.on_device(self.k): + safe = B.to_active_device(B.one(B.dtype(self))) + # Make inputs safe. + x = mask * x + (1 - mask) * safe + # Run with safe inputs, and filter out the right logpdfs. + return self.logpdf(x, mask=mask) + + @_dispatch + def logpdf(self: "Gamma[B.Numeric, B.Numeric, B.Int]", x: B.Numeric, *, mask=1): + logz = B.loggamma(self.k) + self.k * B.log(self.scale) + logpdf = (self.k - 1) * B.log(x) - x / self.scale - logz + logpdf = logpdf * mask + if self.d == 0: + return logpdf + else: + return B.sum(logpdf, axis=tuple(range(B.rank(logpdf)))[-self.d :]) + + def __str__(self): + return f"Gamma({self.k}, {self.scale})" + + def __repr__(self): + return f"Gamma({self.k!r}, {self.scale!r})" + + +@B.dtype.dispatch +def dtype(dist: Gamma): + return B.dtype(dist.k, dist.scale) + + +@shape_batch.dispatch +def shape_batch(dist: "Gamma[B.Numeric, B.Numeric, B.Int]"): + return B.shape_broadcast(dist.k, dist.scale)[: -dist.d] + + +@shape_batch.dispatch +def shape_batch(dist: "Gamma[Aggregate, Aggregate, Aggregate]"): + return broadcast( + *( + shape_batch(Gamma(ki, si, di)) + for ki, si, di in zip(dist.k, dist.scale, dist.d) + ) + ) diff --git a/tests/test_distribution.py b/tests/test_distribution.py index e29e145a..0b9e1fba 100644 --- a/tests/test_distribution.py +++ b/tests/test_distribution.py @@ -1,7 +1,12 @@ import lab as B +import scipy.stats as stats + +import torch +from neuralprocesses.dist.beta import Beta +from neuralprocesses.dist.gamma import Gamma from .test_architectures import check_prediction, generate_data -from .util import nps # noqa +from .util import approx, nps # noqa def test_transform_positive(nps): @@ -42,3 +47,35 @@ def test_transform_bounded(nps): # Check that predictions and samples satisfy the constraint. assert B.all(pred.mean > 10) and B.all(pred.mean < 11) assert B.all(pred.sample() > 10) and B.all(pred.sample() < 11) + + +def test_beta_correctness(): + """Test the correctness of the beta distribution.""" + beta = Beta(B.cast(torch.float64, 0.2), B.cast(torch.float64, 0.8), 0) + beta_ref = stats.beta(0.2, 0.8) + + sample = beta.sample() + approx(beta.logpdf(sample), beta_ref.logpdf(sample)) + approx(beta.mean, beta_ref.mean()) + approx(beta.var, beta_ref.var()) + + # Test dimensionality argument. + for d in range(4): + beta = Beta(beta.alpha, beta.beta, d) + assert beta.logpdf(beta.sample(1, 2, 3)).shape == (1, 2, 3)[: 3 - d] + + +def test_gamma(): + """Test the correctness of the gamma distribution.""" + gamma = Gamma(B.cast(torch.float64, 2), B.cast(torch.float64, 0.8), 0) + gamma_ref = stats.gamma(2, scale=0.8) + + sample = gamma.sample() + approx(gamma.logpdf(sample), gamma_ref.logpdf(sample)) + approx(gamma.mean, gamma_ref.mean()) + approx(gamma.var, gamma_ref.var()) + + # Test dimensionality argument. + for d in range(4): + gamma = Gamma(gamma.k, gamma.scale, d) + assert gamma.logpdf(gamma.sample(1, 2, 3)).shape == (1, 2, 3)[: 3 - d]