-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Prototype of the factor analysis with CSP prior (#20)
* Add calculation of the number of active traits. * Simulation of factor analysis * Add Gibbs sampling steps. * High-level inference utilities
- Loading branch information
1 parent
a5248bf
commit f90b967
Showing
5 changed files
with
293 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
"""Factor analysis of Legramanti et al. (2020). | ||
It is a prototype, without enough test coverage! | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
"""Sampling steps for all variables, apart from the variances | ||
attributed to latent traits, which are sampled with the CSP module.""" | ||
from typing import Callable | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import random | ||
|
||
|
||
from jaxtyping import Float, Array | ||
|
||
|
||
def _v_j_generator(
    theta: Float[Array, " traits"],
    eta: Float[Array, "points traits"],
) -> Callable:
    """Builds the map from a noise variance to a covariance matrix.

    The returned function is evaluated once per observed dimension
    (``gibbs_sample_mixing`` maps it over ``sigma2`` with ``jax.vmap``)
    to obtain the covariance used when sampling a row of the mixing matrix.

    Args:
        theta: variances attributed to each latent trait
        eta: latent traits

    Returns:
        function mapping a scalar noise variance ``sigma2`` to
        ``(diag(1 / theta) + eta^T eta / sigma2)^{-1}``,
        a matrix of shape (traits, traits)
    """
    # Everything independent of `sigma2` is computed once and captured
    # by the closure. Both matrices have shape (traits, traits).
    prior_precision = jnp.diag(jnp.reciprocal(theta))
    gram = jnp.matmul(eta.T, eta)

    def v_j(sigma2: float) -> Float[Array, "traits1 traits2"]:
        precision = prior_precision + gram / sigma2
        return jnp.linalg.inv(precision)

    return v_j
|
||
|
||
def gibbs_sample_mixing(
    key,
    theta: Float[Array, " traits"],
    sigma2: Float[Array, " observed"],
    eta: Float[Array, "points traits"],
    Y: Float[Array, "points observed"],
) -> Float[Array, "observed traits"]:
    """Samples the mixing matrix.

    Each row of the mixing matrix (one per observed dimension ``p``) is
    drawn independently from a multivariate normal distribution with

        covariance  V_p  = (diag(1 / theta) + eta^T eta / sigma2_p)^{-1}
        mean        mu_p = V_p eta^T Y[:, p] / sigma2_p

    Args:
        key: JAX PRNG key
        theta: variances attributed to each latent trait
        sigma2: noise variance for each observed dimension
        eta: Latent traits
        Y: observed data

    Returns:
        mixing matrix, shape (observed, traits)
    """
    # Shape (observed, traits, traits)
    V = jax.vmap(_v_j_generator(theta, eta))(sigma2)

    # Shape (traits, observed). Contracting over the points first avoids
    # materialising an (observed, points, traits) intermediate array.
    eta_y = eta.T @ Y

    # Shape (observed, traits): mu[p] = V[p] @ eta_y[:, p] / sigma2[p]
    mu = jnp.einsum("phk,kp->ph", V, eta_y) / sigma2[:, None]

    # One independent key, mean and covariance per observed dimension
    subkeys = random.split(key, V.shape[0])
    return jax.vmap(random.multivariate_normal, in_axes=(0, 0, 0))(subkeys, mu, V)
|
||
|
||
def gibbs_sample_traits(
    key,
    lambd: Float[Array, "observed traits"],
    sigma2: Float[Array, " observed"],
    Y: Float[Array, "points observed"],
) -> Float[Array, "points traits"]:
    """Samples the latent traits.

    The rows of the latent traits matrix (one per data point ``n``) are
    drawn independently from multivariate normal distributions which
    share the covariance and differ in the mean:

        V    = (I + lambd^T diag(1 / sigma2) lambd)^{-1}
        mu_n = V lambd^T diag(1 / sigma2) Y[n, :]

    Args:
        key: JAX PRNG key
        lambd: mixing matrix
        sigma2: noise variance for each observed dimension
        Y: observed data

    Returns:
        latent traits, shape (points, traits)
    """
    # diag(1 / sigma2) @ lambd, shape (observed, traits). Rescaling the rows
    # by broadcasting avoids building a dense (observed, observed) matrix.
    lambd_scaled = lambd / sigma2[:, None]

    # Shape (traits, traits)
    n_traits = lambd.shape[1]
    V = jnp.linalg.inv(jnp.eye(n_traits) + lambd.T @ lambd_scaled)

    # Shape (points, traits): row n equals V @ lambd_scaled.T @ Y[n, :]
    mu = Y @ lambd_scaled @ V.T

    n_points = mu.shape[0]
    subkeys = random.split(key, n_points)

    # One key and one mean per point; the covariance is shared
    return jax.vmap(random.multivariate_normal, in_axes=(0, 0, None))(subkeys, mu, V)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
"""Sampling from the posterior distribution.""" | ||
from typing import Callable | ||
|
||
from jaxtyping import Float, Array | ||
import jax | ||
import jax.numpy as jnp | ||
from jax import random | ||
|
||
from jnotype._csp import sample_csp_prior, sample_csp_gibbs | ||
from jnotype._factor_analysis._gibbs_backend import ( | ||
gibbs_sample_mixing, | ||
gibbs_sample_traits, | ||
) | ||
|
||
Sample = dict | ||
|
||
|
||
def initial_sample(
    key,
    Y: Float[Array, "points observed"],
    sigma2: Float[Array, " observed"],
    max_traits: int,
    expected_occupied: float,
    prior_shape: float,
    prior_scale: float,
) -> Sample:
    """Draws the starting point of the Markov chain.

    Args:
        key: JAX PRNG key
        Y: observed data (only its shape is used here)
        sigma2: noise variance for each observed dimension,
            which is assumed to be known
        max_traits: maximum number of traits
        expected_occupied: expected number of traits
        prior_shape: shape parameter of the trait variances prior
        prior_scale: scale parameter of the trait variances prior

    Returns:
        dictionary with the entries "eta", "lambda", "sigma2" and "csp"

    Note:
        Currently `sigma2` is not sampled, but just copied from the input argument.
    """
    # The key is split into 10 parts, of which only rows 1-3 are consumed.
    keys = random.split(key, 10)
    key_csp, key_eta, key_lambda = keys[1], keys[2], keys[3]

    n_points = Y.shape[0]
    n_observed = Y.shape[1]

    # Latent traits start as standard normal noise
    eta = random.normal(key_eta, shape=(n_points, max_traits))

    # Columns of the mixing matrix are shrunk more and more:
    # the h-th column has the initial variance exp(-h)
    variances_initial = jnp.exp(-jnp.arange(0, max_traits))
    column_scale = jnp.sqrt(variances_initial)[None, :]
    lambd = random.normal(key_lambda, shape=(n_observed, max_traits)) * column_scale

    # The shrinkage-related variables come from the CSP prior
    csp_sample = sample_csp_prior(
        key=key_csp,
        k=max_traits,
        expected_occupied=expected_occupied,
        prior_shape=prior_shape,
        prior_scale=prior_scale,
    )

    return {
        "eta": eta,
        "lambda": lambd,
        "sigma2": sigma2,
        "csp": csp_sample,
    }
|
||
|
||
def generate_sampling_step(
    Y,
    csp_shape: float = 2.0,
    csp_scale: float = 2.0,
    csp_theta_inf: float = 0.01,
    csp_expected: float = 5.0,
    jit_it: bool = True,
) -> Callable:
    """Builds a Gibbs Markov kernel for the observed data `Y`.

    The kernel has the signature

        (key, sample) -> sample

    where a sample is a dictionary with the entries
    "eta", "lambda", "sigma2" and "csp".

    Args:
        Y: observed data
        csp_shape: shape parameter of the trait variances prior
        csp_scale: scale parameter of the trait variances prior
        csp_theta_inf: trait variance for inactive (shrunk) traits
        csp_expected: expected number of traits
        jit_it: whether to JIT-compile the kernel

    Returns:
        the kernel, wrapped in `jax.jit` if `jit_it` is true
    """

    def _sample_gibbs(
        key,
        sample,
    ) -> dict:
        keys = random.split(key, 4)
        # Note: keys[1] is not consumed by any of the steps below
        key_mixing, key_traits, key_csp = keys[0], keys[2], keys[3]

        previous_csp = sample["csp"]
        # The noise variances are not sampled, only carried over
        sigma2 = sample["sigma2"]

        # Mixing matrix, given the previous traits and trait variances
        lambd = gibbs_sample_mixing(
            key_mixing,
            theta=previous_csp["variance"],
            sigma2=sigma2,
            eta=sample["eta"],
            Y=Y,
        )

        # Latent traits, given the freshly sampled mixing matrix
        eta = gibbs_sample_traits(
            key_traits,
            lambd=lambd,
            sigma2=sigma2,
            Y=Y,
        )

        # CSP parameters, given the freshly sampled mixing matrix.
        # The structure is all ones, i.e., every coefficient is included.
        csp = sample_csp_gibbs(
            key_csp,
            coefficients=lambd,
            structure=jnp.ones_like(lambd),
            omega=previous_csp["omega"],
            expected_occupied=csp_expected,
            prior_shape=csp_shape,
            prior_scale=csp_scale,
            theta_inf=csp_theta_inf,
        )

        return {
            "eta": eta,
            "lambda": lambd,
            "sigma2": sigma2,
            "csp": csp,
        }

    return jax.jit(_sample_gibbs) if jit_it else _sample_gibbs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Simulate data sets.""" | ||
from jaxtyping import Float, Array | ||
|
||
import jax.numpy as jnp | ||
from jax import random | ||
|
||
|
||
def sample_observations(
    key,
    lambd: Float[Array, "observed latent"],
    eta: Float[Array, "points latent"],
    sigma2: Float[Array, " observed"],
) -> Float[Array, "points observed"]:
    """Simulates the observed data from the factor analysis model.

    Every observation is the latent trait vector mapped through the mixing
    matrix, plus independent Gaussian noise with variance `sigma2`
    specific to each observed dimension.

    Args:
        key: JAX PRNG key
        lambd: Mixing matrix
        eta: Latent traits
        sigma2: noise variance for each observed dimension

    Returns:
        Y matrix, shape (n_points, n_observed)
    """
    n_points, n_observed = eta.shape[0], lambd.shape[0]

    # Noise-free part, shape (n_points, n_observed)
    signal = jnp.einsum("ph,nh->np", lambd, eta)

    # Standard normal draws, with each column rescaled to variance sigma2
    noise_std = jnp.sqrt(sigma2)[None, :]
    noise = noise_std * random.normal(key, shape=(n_points, n_observed))

    return signal + noise
|
||
|
||
def sample_latent(key, points: int, latent: int) -> Float[Array, "points latent"]:
    """Draws latent traits as independent standard normal variables.

    Args:
        key: JAX PRNG key
        points: Number of points
        latent: Number of latent traits

    Returns:
        latent traits, shape (points, latent)
    """
    shape = (points, latent)
    return random.normal(key, shape=shape)
|
||
|
||
def sample_mixing(
    key, observed: int, theta: Float[Array, " latent"]
) -> Float[Array, "observed latent"]:
    """Draws a mixing matrix with independent zero-mean normal entries.

    All entries in the column corresponding to a latent trait
    have the variance given by the matching entry of `theta`.

    Args:
        key: JAX PRNG key
        observed: Number of observed dimensions
        theta: variance of factors attributed to each latent trait

    Returns:
        mixing matrix, shape (observed, latent)
    """
    n_latent = theta.shape[0]
    # Standard deviations, broadcast over the rows (observed dimensions)
    column_std = jnp.sqrt(theta)[None, ...]
    standard_normal = random.normal(key, shape=(observed, n_latent))
    return standard_normal * column_std
|
||
|
||
def covariance_from_mixing(
    lambd: Float[Array, "observed latent"],
    sigma2: Float[Array, " observed"],
) -> Float[Array, "observed observed"]:
    """Computes the covariance matrix of the observed variables.

    Args:
        lambd: mixing matrix
        sigma2: additional variance for each observed dimension

    Returns:
        the matrix ``lambd lambd^T + diag(sigma2)``,
        shape (observed, observed)
    """
    # Part of the covariance explained by the mixing matrix
    low_rank = jnp.einsum("ik,jk->ij", lambd, lambd)
    # Dimension-specific variances enter only on the diagonal
    diagonal = jnp.diag(sigma2)
    return low_rank + diagonal