diff --git a/src/jnotype/logistic/__init__.py b/src/jnotype/logistic/__init__.py index 6225f6c..7c9cb94 100644 --- a/src/jnotype/logistic/__init__.py +++ b/src/jnotype/logistic/__init__.py @@ -2,11 +2,16 @@ from jnotype.logistic._binary_latent import sample_binary_codes from jnotype.logistic._polyagamma import sample_intercepts_and_coefficients -from jnotype.logistic._structure import sample_structure, sample_gamma +from jnotype.logistic._structure import ( + sample_structure, + sample_gamma, + sample_gamma_individual, +) __all__ = [ "sample_structure", "sample_gamma", + "sample_gamma_individual", "sample_binary_codes", "sample_intercepts_and_coefficients", ] diff --git a/src/jnotype/logistic/_structure.py b/src/jnotype/logistic/_structure.py index e3c4190..a28470b 100644 --- a/src/jnotype/logistic/_structure.py +++ b/src/jnotype/logistic/_structure.py @@ -173,3 +173,22 @@ def sample_gamma( posterior_b = prior_b + (n_all - n_successes) return random.beta(key, posterior_a, posterior_b) + + +@jax.jit +def sample_gamma_individual( + key: random.PRNGKeyArray, + structure: Int[Array, "G K"], + prior_a: Float[Array, " K"], + prior_b: Float[Array, " K"], +) -> Float[Array, " K"]: + """Samples the sparsity based on the structure matrix, + but for each covariate separately. 
+ """ + n_successes = jnp.sum(structure, axis=0) + n_all = structure.shape[0] + + posterior_a = prior_a + n_successes + posterior_b = prior_b + (n_all - n_successes) + + return random.beta(key, posterior_a, posterior_b) diff --git a/src/jnotype/pyramids/_sampler_fixed.py b/src/jnotype/pyramids/_sampler_fixed.py index 2d8eff6..0516666 100644 --- a/src/jnotype/pyramids/_sampler_fixed.py +++ b/src/jnotype/pyramids/_sampler_fixed.py @@ -1,6 +1,6 @@ """Sampler for two-layer Bayesian pyramids with fixed number of latent binary codes.""" -from typing import Optional, Sequence, NewType +from typing import Optional, Sequence, Union, NewType import jax import jax.numpy as jnp @@ -14,6 +14,7 @@ from jnotype.bmm import sample_bmm from jnotype.logistic import ( sample_gamma, + sample_gamma_individual, sample_structure, sample_binary_codes, sample_intercepts_and_coefficients, @@ -39,19 +40,19 @@ def _single_sampling_step( covariates: Float[Array, "points covariates"], variances: Float[Array, " covariates"], gamma: Float[Array, ""], - nu: Float[Array, ""], + nu: Float[Array, " observed_covariates"], cluster_labels: Int[Array, " points"], mixing: Float[Array, "n_binary_codes n_clusters"], proportions: Float[Array, " n_clusters"], # Priors dirichlet_prior: Float[Array, " n_clusters"], + nu_prior_a: Union[float, Float[Array, " observed_covariates"]] = 1.0, + nu_prior_b: Union[float, Float[Array, " observed_covariates"]] = 1.0, pseudoprior_variance: float = 0.01, intercept_prior_mean: float = 0.0, intercept_prior_variance: float = 1.0, gamma_prior_a: float = 1.0, gamma_prior_b: float = 1.0, - nu_prior_a: float = 1.0, - nu_prior_b: float = 1.0, variances_prior_shape: float = 2.0, variances_prior_scale: float = 1.0, mixing_beta_prior: tuple[float, float] = (1.0, 1.0), @@ -84,11 +85,10 @@ def _single_sampling_step( # Sample structure and the sparsity key, subkey_structure, subkey_gamma, subkey_nu = jax.random.split(key, 4) - n_observed_features = len(variances) - n_binary_codes 
sparsity_vector = jnp.concatenate( ( jnp.full(shape=(n_binary_codes,), fill_value=gamma), - jnp.full(shape=(n_observed_features,), fill_value=nu), + nu, ) ) structure = sample_structure( @@ -108,7 +108,7 @@ def _single_sampling_step( prior_a=gamma_prior_a, prior_b=gamma_prior_b, ) - nu = sample_gamma( + nu = sample_gamma_individual( key=subkey_nu, structure=structure[..., n_binary_codes:], prior_a=nu_prior_a, @@ -286,7 +286,7 @@ def dimensions(cls) -> _SplitSample: "structure_latent": ["features", "latents"], "structure_observed": ["features", "observed_covariates"], "gamma": [], # Float, no named dimensions - "nu": [], # Float, no named dimensions + "nu": ["observed_covariates"], "latent_variances": ["latents"], "observed_variances": ["observed_covariates"], "latent_traits": ["points", "latents"], @@ -344,10 +344,15 @@ def _initialise_gamma(self) -> Float[Array, ""]: ) def _initialise_nu(self) -> Float[Array, ""]: - return jax.random.beta(self._jax_rng.key, self._nu_prior[0], self._nu_prior[1]) + return jax.random.beta( + self._jax_rng.key, + self._nu_prior[0], + self._nu_prior[1], + shape=(self._n_observed_covariates,), + ) def _initialise_structure( - self, gamma: Float[Array, ""], nu: Float[Array, ""] + self, gamma: Float[Array, ""], nu: Float[Array, " observed_covariates"] ) -> Int[Array, ""]: """Initialises the structure.""" n_outputs = self._observed_data.shape[1] diff --git a/src/jnotype/sampling/_chunker.py b/src/jnotype/sampling/_chunker.py index 4c3506f..119d1cc 100644 --- a/src/jnotype/sampling/_chunker.py +++ b/src/jnotype/sampling/_chunker.py @@ -83,6 +83,9 @@ def dataset(self) -> xr.Dataset: "sample": np.arange(len(self.samples), dtype=int), } | self._coords + if not len(self.samples): + return xr.Dataset(coords=coords, attrs=attrs) + variables = { label: ( self._coords_for_label(label),