Skip to content

Commit

Permalink
Make nu sparsity conditionally independent.
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Sep 4, 2023
1 parent 870608b commit e3fd812
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
7 changes: 6 additions & 1 deletion src/jnotype/logistic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

from jnotype.logistic._binary_latent import sample_binary_codes
from jnotype.logistic._polyagamma import sample_intercepts_and_coefficients
from jnotype.logistic._structure import sample_structure, sample_gamma
from jnotype.logistic._structure import (
sample_structure,
sample_gamma,
sample_gamma_individual,
)

__all__ = [
"sample_structure",
"sample_gamma",
"sample_gamma_individual",
"sample_binary_codes",
"sample_intercepts_and_coefficients",
]
19 changes: 19 additions & 0 deletions src/jnotype/logistic/_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,22 @@ def sample_gamma(
posterior_b = prior_b + (n_all - n_successes)

return random.beta(key, posterior_a, posterior_b)


@jax.jit
def sample_gamma_individual(
    key: random.PRNGKeyArray,
    structure: Int[Array, "G K"],
    prior_a: Float[Array, " K"],
    prior_b: Float[Array, " K"],
) -> Float[Array, " K"]:
    """Samples one sparsity level per covariate from the structure matrix.

    Each column of ``structure`` is treated on its own: the entries of the
    column are summed to obtain the number of "successes", the remaining
    rows count as "failures", and a Beta sample is drawn with the prior
    parameters incremented by these two counts.

    Args:
        key: JAX random key
        structure: structure matrix, shape (G, K); its entries are summed
            along the first axis
        prior_a: first Beta parameter of the prior, shape (K,)
        prior_b: second Beta parameter of the prior, shape (K,)

    Returns:
        sampled sparsity levels, one per column of ``structure``, shape (K,)
    """
    # Number of rows (G) is the number of trials behind every column.
    n_rows = structure.shape[0]

    # Column-wise counts of entries that are switched on / off.
    n_active = jnp.sum(structure, axis=0)
    n_inactive = n_rows - n_active

    return random.beta(key, prior_a + n_active, prior_b + n_inactive)
25 changes: 15 additions & 10 deletions src/jnotype/pyramids/_sampler_fixed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Sampler for two-layer Bayesian pyramids
with fixed number of latent binary codes."""
from typing import Optional, Sequence, NewType
from typing import Optional, Sequence, Union, NewType

import jax
import jax.numpy as jnp
Expand All @@ -14,6 +14,7 @@
from jnotype.bmm import sample_bmm
from jnotype.logistic import (
sample_gamma,
sample_gamma_individual,
sample_structure,
sample_binary_codes,
sample_intercepts_and_coefficients,
Expand All @@ -39,19 +40,19 @@ def _single_sampling_step(
covariates: Float[Array, "points covariates"],
variances: Float[Array, " covariates"],
gamma: Float[Array, ""],
nu: Float[Array, ""],
nu: Float[Array, " observed_covariates"],
cluster_labels: Int[Array, " points"],
mixing: Float[Array, "n_binary_codes n_clusters"],
proportions: Float[Array, " n_clusters"],
# Priors
dirichlet_prior: Float[Array, " n_clusters"],
nu_prior_a: Union[float, Float[Array, " observed_covariates"]] = 1.0,
nu_prior_b: Union[float, Float[Array, " observed_covariates"]] = 1.0,
pseudoprior_variance: float = 0.01,
intercept_prior_mean: float = 0.0,
intercept_prior_variance: float = 1.0,
gamma_prior_a: float = 1.0,
gamma_prior_b: float = 1.0,
nu_prior_a: float = 1.0,
nu_prior_b: float = 1.0,
variances_prior_shape: float = 2.0,
variances_prior_scale: float = 1.0,
mixing_beta_prior: tuple[float, float] = (1.0, 1.0),
Expand Down Expand Up @@ -84,11 +85,10 @@ def _single_sampling_step(
# Sample structure and the sparsity
key, subkey_structure, subkey_gamma, subkey_nu = jax.random.split(key, 4)

n_observed_features = len(variances) - n_binary_codes
sparsity_vector = jnp.concatenate(
(
jnp.full(shape=(n_binary_codes,), fill_value=gamma),
jnp.full(shape=(n_observed_features,), fill_value=nu),
nu,
)
)
structure = sample_structure(
Expand All @@ -108,7 +108,7 @@ def _single_sampling_step(
prior_a=gamma_prior_a,
prior_b=gamma_prior_b,
)
nu = sample_gamma(
nu = sample_gamma_individual(
key=subkey_nu,
structure=structure[..., n_binary_codes:],
prior_a=nu_prior_a,
Expand Down Expand Up @@ -286,7 +286,7 @@ def dimensions(cls) -> _SplitSample:
"structure_latent": ["features", "latents"],
"structure_observed": ["features", "observed_covariates"],
"gamma": [], # Float, no named dimensions
"nu": [], # Float, no named dimensions
"nu": ["observed_covariates"],
"latent_variances": ["latents"],
"observed_variances": ["observed_covariates"],
"latent_traits": ["points", "latents"],
Expand Down Expand Up @@ -344,10 +344,15 @@ def _initialise_gamma(self) -> Float[Array, ""]:
)

def _initialise_nu(self) -> Float[Array, ""]:
return jax.random.beta(self._jax_rng.key, self._nu_prior[0], self._nu_prior[1])
return jax.random.beta(
self._jax_rng.key,
self._nu_prior[0],
self._nu_prior[1],
shape=(self._n_observed_covariates,),
)

def _initialise_structure(
self, gamma: Float[Array, ""], nu: Float[Array, ""]
self, gamma: Float[Array, ""], nu: Float[Array, " observed_covariates"]
) -> Int[Array, ""]:
"""Initialises the structure."""
n_outputs = self._observed_data.shape[1]
Expand Down
3 changes: 3 additions & 0 deletions src/jnotype/sampling/_chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def dataset(self) -> xr.Dataset:
"sample": np.arange(len(self.samples), dtype=int),
} | self._coords

if not len(self.samples):
return xr.Dataset(coords=coords, attrs=attrs)

variables = {
label: (
self._coords_for_label(label),
Expand Down

0 comments on commit e3fd812

Please sign in to comment.