Skip to content

Commit

Permalink
Add support argument to Delta distribution.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Jan 3, 2025
1 parent 6ae76ea commit afecd82
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
23 changes: 15 additions & 8 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,18 +1213,25 @@ class Delta(Distribution):
"log_density": constraints.real,
}
reparametrized_params = ["v", "log_density"]

def __init__(self, v=0.0, log_density=0.0, event_dim=0, *, validate_args=None):
if event_dim > jnp.ndim(v):
pytree_aux_fields = ("v", "log_density", "_support")

def __init__(
self, v=0.0, log_density=0.0, event_dim=0, support=None, *, validate_args=None
):
if support is None:
support = constraints.real
if event_dim:
support = constraints.independent(support, event_dim)
if support.event_dim > jnp.ndim(v):
raise ValueError(
"Expected event_dim <= v.dim(), actual {} vs {}".format(
event_dim, jnp.ndim(v)
)
"Expected event_dim + support.event_dim <= v.dim(), actual "
f"{support.event_dim} vs {jnp.ndim(v)}."
)
batch_dim = jnp.ndim(v) - event_dim
batch_dim = jnp.ndim(v) - support.event_dim
batch_shape = jnp.shape(v)[:batch_dim]
event_shape = jnp.shape(v)[batch_dim:]
self.v = v
self._support = support
# NB: following Pyro implementation, log_density should be broadcasted to batch_shape
self.log_density = promote_shapes(log_density, shape=batch_shape)[0]
super(Delta, self).__init__(
Expand All @@ -1233,7 +1240,7 @@ def __init__(self, v=0.0, log_density=0.0, event_dim=0, *, validate_args=None):

@constraints.dependent_property
def support(self):
return constraints.independent(constraints.real, self.event_dim)
return self._support

def sample(self, key, sample_shape=()):
if not sample_shape:
Expand Down
22 changes: 20 additions & 2 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,20 @@ def get_sp_dist(jax_dist):
T(dist.Delta, 1),
T(dist.Delta, np.array([0.0, 2.0])),
T(dist.Delta, np.array([0.0, 2.0]), np.array([-2.0, -4.0])),
T(
dist.Delta,
np.array([0.0, 2.0]),
-3.0,
0,
constraints.real_vector,
),
T(
dist.Delta,
np.array([[1.0, 1.5], [2.3, 1.2]])[..., None, None] * np.eye(3),
np.array([-2.0, -4.0]),
1,
constraints.positive_definite,
),
T(dist.DirichletMultinomial, np.array([1.0, 2.0, 3.9]), 10),
T(dist.DirichletMultinomial, np.array([0.2, 0.7, 1.1]), np.array([5, 5])),
T(dist.GammaPoisson, 2.0, 2.0),
Expand Down Expand Up @@ -1446,7 +1460,7 @@ def test_jit_log_likelihood(jax_dist, sp_dist, params):
"LKJCholesky",
"_SparseCAR",
"ZeroSumNormal",
):
) or (jax_dist.__name__ == "Delta" and len(params) > 2):
pytest.xfail(reason="non-jittable params")

rng_key = random.PRNGKey(0)
Expand Down Expand Up @@ -1882,7 +1896,11 @@ def fn(*args):
continue
if jax_dist is dist.DoublyTruncatedPowerLaw and i != 0:
continue
if params[i] is None or jnp.result_type(params[i]) in (jnp.int32, jnp.int64):
if (
params[i] is None
or isinstance(params[i], constraints.Constraint)
or jnp.result_type(params[i]) in (jnp.int32, jnp.int64)
):
continue
actual_grad = jax.grad(fn, i)(*params)
args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
Expand Down

0 comments on commit afecd82

Please sign in to comment.