Add gamma distribution and test gamma and beta dists
wesselb committed Jun 10, 2024 · 1 parent 450477e · commit d01e487
Showing 4 changed files with 161 additions and 2 deletions.
6 changes: 5 additions & 1 deletion neuralprocesses/dist/beta.py
@@ -90,7 +90,11 @@ def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: Masked):
     def logpdf(self: "Beta[B.Numeric, B.Numeric, B.Int]", x: B.Numeric, *, mask=1):
         logz = B.logbeta(self.alpha, self.beta)
         logpdf = (self.alpha - 1) * B.log(x) + (self.beta - 1) * B.log(1 - x) - logz
-        return B.sum(mask * logpdf, axis=tuple(range(B.rank(logpdf)))[-self.d :])
+        logpdf = logpdf * mask
+        if self.d == 0:
+            return logpdf
+        else:
+            return B.sum(logpdf, axis=tuple(range(B.rank(logpdf)))[-self.d :])
 
     def __str__(self):
         return f"Beta({self.alpha}, {self.beta})"
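The rewrite fixes an edge case in the old one-liner: for `d = 0`, the slice `tuple(range(B.rank(logpdf)))[-self.d :]` becomes `[-0:]`, which is identical to `[0:]` and selects *all* axes, so the log-density was summed over every dimension instead of none. A minimal sketch of the pitfall in plain Python with NumPy (illustrative only, not repository code):

import numpy as np

logpdf = np.zeros((1, 2, 3))  # stand-in for a batch of elementwise log-densities
rank = logpdf.ndim  # 3

for d in [3, 2, 1, 0]:
    axes = tuple(range(rank))[-d:]
    print(d, axes, np.sum(logpdf, axis=axes).shape)
# d=3 -> axes (0, 1, 2), result shape ()
# d=2 -> axes (1, 2),    result shape (1,)
# d=1 -> axes (2,),      result shape (1, 2)
# d=0 -> axes (0, 1, 2) again: [-0:] selects everything, not nothing!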
1 change: 1 addition & 0 deletions neuralprocesses/dist/dist.py
@@ -31,6 +31,7 @@ def sample(self, state: B.RandomState, dtype: B.DType, *shape):
             state (random state, optional): Random state.
             tensor: Samples of shape `(*shape, *d)` where typically `d = (*b, c, n)`.
         """
+        print(type(self), type(state), type(dtype), *shape)
         raise NotImplementedError(f"{self} cannot be sampled.")
 
     @_dispatch
117 changes: 117 additions & 0 deletions neuralprocesses/dist/gamma.py
@@ -0,0 +1,117 @@
+import lab as B
+from matrix.shape import broadcast
+from plum import parametric
+
+from .. import _dispatch
+from ..aggregate import Aggregate
+from ..mask import Masked
+from .dist import AbstractDistribution, shape_batch
+
+__all__ = ["Gamma"]
+
+
+@parametric
+class Gamma(AbstractDistribution):
+    """Gamma distribution.
+
+    Args:
+        k (tensor): Shape parameter.
+        scale (tensor): Scale parameter.
+        d (int): Dimensionality of the data.
+
+    Attributes:
+        k (tensor): Shape parameter.
+        scale (tensor): Scale parameter.
+        d (int): Dimensionality of the data.
+    """
+
+    def __init__(self, k, scale, d):
+        self.k = k
+        self.scale = scale
+        self.d = d
+
+    @property
+    def mean(self):
+        return B.multiply(self.k, self.scale)
+
+    @property
+    def var(self):
+        return B.multiply(B.multiply(self.k, self.scale), self.scale)
+
+    @_dispatch
+    def sample(
+        self: "Gamma[Aggregate, Aggregate, Aggregate]",
+        state: B.RandomState,
+        dtype: B.DType,
+        *shape,
+    ):
+        samples = []
+        for ki, si, di in zip(self.k, self.scale, self.d):
+            state, sample = Gamma(ki, si, di).sample(state, dtype, *shape)
+            samples.append(sample)
+        return state, Aggregate(*samples)
+
+    @_dispatch
+    def sample(
+        self: "Gamma[B.Numeric, B.Numeric, B.Int]",
+        state: B.RandomState,
+        dtype: B.DType,
+        *shape,
+    ):
+        return B.randgamma(state, dtype, *shape, alpha=self.k, scale=self.scale)
+
+    @_dispatch
+    def logpdf(self: "Gamma[Aggregate, Aggregate, Aggregate]", x: Aggregate):
+        return sum(
+            [
+                Gamma(ki, si, di).logpdf(xi)
+                for ki, si, di, xi in zip(self.k, self.scale, self.d, x)
+            ],
+            0,
+        )
+
+    @_dispatch
+    def logpdf(self: "Gamma[B.Numeric, B.Numeric, B.Int]", x: Masked):
+        x, mask = x.y, x.mask
+        with B.on_device(self.k):
+            safe = B.to_active_device(B.one(B.dtype(self)))
+        # Make inputs safe.
+        x = mask * x + (1 - mask) * safe
+        # Run with safe inputs, and filter out the right logpdfs.
+        return self.logpdf(x, mask=mask)
+
+    @_dispatch
+    def logpdf(self: "Gamma[B.Numeric, B.Numeric, B.Int]", x: B.Numeric, *, mask=1):
+        logz = B.loggamma(self.k) + self.k * B.log(self.scale)
+        logpdf = (self.k - 1) * B.log(x) - x / self.scale - logz
+        logpdf = logpdf * mask
+        if self.d == 0:
+            return logpdf
+        else:
+            return B.sum(logpdf, axis=tuple(range(B.rank(logpdf)))[-self.d :])
+
+    def __str__(self):
+        return f"Gamma({self.k}, {self.scale})"
+
+    def __repr__(self):
+        return f"Gamma({self.k!r}, {self.scale!r})"
+
+
+@B.dtype.dispatch
+def dtype(dist: Gamma):
+    return B.dtype(dist.k, dist.scale)
+
+
+@shape_batch.dispatch
+def shape_batch(dist: "Gamma[B.Numeric, B.Numeric, B.Int]"):
+    return B.shape_broadcast(dist.k, dist.scale)[: -dist.d]
+
+
+@shape_batch.dispatch
+def shape_batch(dist: "Gamma[Aggregate, Aggregate, Aggregate]"):
+    return broadcast(
+        *(
+            shape_batch(Gamma(ki, si, di))
+            for ki, si, di in zip(dist.k, dist.scale, dist.d)
+        )
+    )
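The `Masked` variant of `logpdf` above uses a standard masking trick: masked-out entries are first replaced by a safe dummy value of one, the log-density is evaluated everywhere, and the result is multiplied by the mask so padded positions contribute exactly zero. The substitution matters because a padded entry `x = 0` gives `log(0) = -inf`, and `-inf * 0` is `nan`, which would poison any subsequent sum. A minimal NumPy/SciPy sketch of the idea (illustrative only, not repository code):

import numpy as np
from scipy.stats import gamma

k, scale = 2.0, 0.8
x = np.array([0.5, 1.2, 0.0])     # last entry is padding, not real data
mask = np.array([1.0, 1.0, 0.0])  # 0 marks the padded entry

# Naive masking: log(0) = -inf, and -inf * 0 = nan.
naive = gamma.logpdf(x, k, scale=scale) * mask

# Safe masking: substitute 1 at padded entries before evaluating.
x_safe = mask * x + (1 - mask) * 1.0
safe = gamma.logpdf(x_safe, k, scale=scale) * mask

print(naive)  # [-0.87... -0.87... nan]
print(safe)   # [-0.87... -0.87... 0.]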
39 changes: 38 additions & 1 deletion tests/test_distribution.py
@@ -1,7 +1,12 @@
 import lab as B
+import scipy.stats as stats
+
+import torch
+from neuralprocesses.dist.beta import Beta
+from neuralprocesses.dist.gamma import Gamma
 
 from .test_architectures import check_prediction, generate_data
-from .util import nps  # noqa
+from .util import approx, nps  # noqa
 
 
 def test_transform_positive(nps):
@@ -42,3 +47,35 @@ def test_transform_bounded(nps):
     # Check that predictions and samples satisfy the constraint.
     assert B.all(pred.mean > 10) and B.all(pred.mean < 11)
     assert B.all(pred.sample() > 10) and B.all(pred.sample() < 11)
+
+
+def test_beta_correctness():
+    """Test the correctness of the beta distribution."""
+    beta = Beta(B.cast(torch.float64, 0.2), B.cast(torch.float64, 0.8), 0)
+    beta_ref = stats.beta(0.2, 0.8)
+
+    sample = beta.sample()
+    approx(beta.logpdf(sample), beta_ref.logpdf(sample))
+    approx(beta.mean, beta_ref.mean())
+    approx(beta.var, beta_ref.var())
+
+    # Test dimensionality argument.
+    for d in range(4):
+        beta = Beta(beta.alpha, beta.beta, d)
+        assert beta.logpdf(beta.sample(1, 2, 3)).shape == (1, 2, 3)[: 3 - d]
+
+
+def test_gamma():
+    """Test the correctness of the gamma distribution."""
+    gamma = Gamma(B.cast(torch.float64, 2), B.cast(torch.float64, 0.8), 0)
+    gamma_ref = stats.gamma(2, scale=0.8)
+
+    sample = gamma.sample()
+    approx(gamma.logpdf(sample), gamma_ref.logpdf(sample))
+    approx(gamma.mean, gamma_ref.mean())
+    approx(gamma.var, gamma_ref.var())
+
+    # Test dimensionality argument.
+    for d in range(4):
+        gamma = Gamma(gamma.k, gamma.scale, d)
+        assert gamma.logpdf(gamma.sample(1, 2, 3)).shape == (1, 2, 3)[: 3 - d]
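
The scipy comparisons in these tests pin down both the density and the moments of the new `Gamma`. For reference, the implemented log-density is log p(x) = (k - 1) log x - x / scale - log Γ(k) - k log scale, and the moments are mean = k * scale and var = k * scale**2. A standalone sanity check of both, using the test's parameters k = 2 and scale = 0.8 (illustrative only, not part of the commit):

import math

from scipy.special import gammaln
from scipy.stats import gamma

k, scale, x = 2.0, 0.8, 1.3
# Hand-written log-density, mirroring Gamma.logpdf.
logz = gammaln(k) + k * math.log(scale)
by_hand = (k - 1) * math.log(x) - x / scale - logz
assert math.isclose(by_hand, gamma.logpdf(x, k, scale=scale))
# Closed-form moments, mirroring Gamma.mean and Gamma.var.
ref = gamma(k, scale=scale)
assert math.isclose(ref.mean(), k * scale)      # 1.6
assert math.isclose(ref.var(), k * scale ** 2)  # 1.28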
