Prototype of the factor analysis with CSP prior (#20)
* Add calculation of the number of active traits.

* Simulation of factor analysis

* Add Gibbs sampling steps.

* High-level inference utilities
pawel-czyz authored Aug 19, 2023
1 parent a5248bf commit f90b967
Showing 5 changed files with 293 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/jnotype/_csp.py
@@ -251,6 +251,12 @@ def _select_variances_active(
    return jnp.where(inactive, variances_inactive, variances_active)


def compute_active_traits(indicators: Int[Array, " codes"]) -> Int[Array, " codes"]:
    """Marks with 1 the traits which are active."""
    active = jnp.greater(indicators, jnp.arange(indicators.shape[0]))
    return jnp.asarray(active, dtype=int)
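For intuition, trait h is marked active precisely when its indicator exceeds h. A minimal worked example (the indicator values below are illustrative):

import jax.numpy as jnp

from jnotype._csp import compute_active_traits

indicators = jnp.array([3, 3, 1, 1])  # hypothetical indicator draws
# Compared elementwise against jnp.arange(4) == [0, 1, 2, 3]:
# [3 > 0, 3 > 1, 1 > 2, 1 > 3] -> [1, 1, 0, 0]
print(compute_active_traits(indicators))  # [1 1 0 0]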


def _sample_variances_conditioned_on_indicators(
    key,
    indicators: Int[Array, " codes"],
@@ -307,6 +313,8 @@ def sample_csp_prior(
"nu": nus,
"omega": omega,
"indicators": indicators,
"active_traits": compute_active_traits(indicators),
"n_active": jnp.sum(compute_active_traits(indicators)),
}


@@ -351,4 +359,6 @@ def sample_csp_gibbs(
"nu": nus,
"omega": omega,
"indicators": indicators,
"active_traits": compute_active_traits(indicators),
"n_active": jnp.sum(compute_active_traits(indicators)),
}
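A minimal sketch of drawing from the CSP prior and reading off the new fields; the hyperparameter values are illustrative, and the keyword names follow the call made in `_inference.py` below:

import jax

from jnotype._csp import sample_csp_prior

prior_sample = sample_csp_prior(
    key=jax.random.PRNGKey(0),
    k=10,                    # maximum number of traits
    expected_occupied=5.0,   # a priori expected number of occupied traits
    prior_shape=2.0,
    prior_scale=2.0,
)
print(prior_sample["active_traits"])  # 0/1 vector of length k
print(prior_sample["n_active"])       # number of active traits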
4 changes: 4 additions & 0 deletions src/jnotype/_factor_analysis/__init__.py
@@ -0,0 +1,4 @@
"""Factor analysis of Legremanti et al. (2020).
It is a prototype, without enough test coverage!
"""
78 changes: 78 additions & 0 deletions src/jnotype/_factor_analysis/_gibbs_backend.py
@@ -0,0 +1,78 @@
"""Sampling steps for all variables, apart from variances
attributed to latent traits, which are sampled with CSP module."""
from typing import Callable

import jax
import jax.numpy as jnp
from jax import random


from jaxtyping import Float, Array


def _v_j_generator(
    theta: Float[Array, " traits"],
    eta: Float[Array, "points traits"],
) -> Callable:
    """Generates a function computing the posterior covariance matrix
    used to sample a row of the mixing matrix."""
    # Both arrays have shape (traits, traits)
    d_inv = jnp.diag(jnp.reciprocal(theta))
    eta_eta = eta.T @ eta

    def v_j(sigma2: float) -> Float[Array, "traits1 traits2"]:
        return jnp.linalg.inv(d_inv + eta_eta / sigma2)

    return v_j


def gibbs_sample_mixing(
    key,
    theta: Float[Array, " traits"],
    sigma2: Float[Array, " observed"],
    eta: Float[Array, "points traits"],
    Y: Float[Array, "points observed"],
) -> Float[Array, "observed traits"]:
    """Samples the mixing matrix.

    Args:
        key: JAX PRNG key
        theta: variances attributed to each latent trait
        sigma2: noise variance for each observed dimension
        eta: latent traits
        Y: observed data

    Returns:
        mixing matrix, shape (observed, traits)
    """
    # Shape (observed, traits, traits)
    V = jax.vmap(_v_j_generator(theta, eta))(sigma2)

    # temp1[p, n, h] = (V_p @ eta.T)[h, n]
    temp1 = jnp.einsum("phk,nk->pnh", V, eta)
    # temp2[p, h] = (V_p @ eta.T @ Y[:, p])[h]
    temp2 = jnp.einsum("pnh,np->ph", temp1, Y)
    mu = temp2 / sigma2[:, None]  # Shape (observed, traits)

    subkeys = random.split(key, V.shape[0])
    return jax.vmap(random.multivariate_normal, in_axes=(0, 0, 0))(subkeys, mu, V)
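For reference, this is the standard conjugate update for each row $\lambda_j$ of the mixing matrix under the prior $\lambda_j \sim \mathcal{N}(0, D)$ with $D = \operatorname{diag}(\theta)$; in the notation of `_v_j_generator` (a reference sketch, not taken from the source):

\lambda_j \mid \eta, \sigma_j^2, Y \sim \mathcal{N}\!\left(\sigma_j^{-2}\, V_j\, \eta^\top y_{\cdot j},\; V_j\right),
\qquad
V_j = \left(D^{-1} + \sigma_j^{-2}\, \eta^\top \eta\right)^{-1}.

Closing over `d_inv` and `eta_eta` computes the shared terms once, so `jax.vmap` only varies $\sigma_j^2$ across the observed dimensions.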


def gibbs_sample_traits(
    key,
    lambd: Float[Array, "observed traits"],
    sigma2: Float[Array, " observed"],
    Y: Float[Array, "points observed"],
) -> Float[Array, "points traits"]:
    """Samples the latent traits."""
    # Sigma_inv has shape (observed, observed); V has shape (traits, traits)
    Sigma_inv = jnp.diag(jnp.reciprocal(sigma2))
    Id = jnp.eye(lambd.shape[1])
    V = jnp.linalg.inv(Id + lambd.T @ Sigma_inv @ lambd)

    # Shape (points, traits)
    mu = (V @ lambd.T @ Sigma_inv @ Y.T).T

    n_points = mu.shape[0]
    subkeys = random.split(key, n_points)

    return jax.vmap(random.multivariate_normal, in_axes=(0, 0, None))(subkeys, mu, V)
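Analogously, with the standard normal prior $\eta_i \sim \mathcal{N}(0, I_H)$, the conditional being sampled is (same notation, again a reference sketch):

\eta_i \mid \Lambda, \Sigma, y_i \sim \mathcal{N}\!\left(V \Lambda^\top \Sigma^{-1} y_i,\; V\right),
\qquad
V = \left(I_H + \Lambda^\top \Sigma^{-1} \Lambda\right)^{-1},
\quad \Sigma = \operatorname{diag}(\sigma^2).

Since $V$ is shared by all points, `random.multivariate_normal` is vmapped over the means only, with `in_axes=(0, 0, None)`.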
136 changes: 136 additions & 0 deletions src/jnotype/_factor_analysis/_inference.py
@@ -0,0 +1,136 @@
"""Sampling from the posterior distribution."""
from typing import Callable

from jaxtyping import Float, Array
import jax
import jax.numpy as jnp
from jax import random

from jnotype._csp import sample_csp_prior, sample_csp_gibbs
from jnotype._factor_analysis._gibbs_backend import (
    gibbs_sample_mixing,
    gibbs_sample_traits,
)

Sample = dict


def initial_sample(
    key,
    Y: Float[Array, "points observed"],
    sigma2: Float[Array, " observed"],
    max_traits: int,
    expected_occupied: float,
    prior_shape: float,
    prior_scale: float,
) -> Sample:
    """Initializes the first sample.

    Args:
        key: JAX PRNG key
        Y: observed data
        sigma2: noise variance for each observed dimension,
            which is assumed to be known
        max_traits: maximum number of traits
        expected_occupied: expected number of occupied traits
        prior_shape: shape parameter of the trait variances prior
        prior_scale: scale parameter of the trait variances prior

    Note:
        Currently `sigma2` is not sampled, but just copied from the input argument.
    """
    n_points, n_observed = Y.shape[0], Y.shape[1]
    key, *subkeys = random.split(key, 10)

    csp_sample = sample_csp_prior(
        key=subkeys[0],
        k=max_traits,
        expected_occupied=expected_occupied,
        prior_shape=prior_shape,
        prior_scale=prior_scale,
    )

    eta = random.normal(subkeys[1], shape=(n_points, max_traits))
    # Make the columns of lambda smaller and smaller in the initial sample
    variances_initial = jnp.exp(-jnp.arange(0, max_traits))
    lambd = (
        random.normal(subkeys[2], shape=(n_observed, max_traits))
        * jnp.sqrt(variances_initial)[None, :]
    )

    return {
        "eta": eta,
        "lambda": lambd,
        "sigma2": sigma2,
        "csp": csp_sample,
    }


def generate_sampling_step(
    Y,
    csp_shape: float = 2.0,
    csp_scale: float = 2.0,
    csp_theta_inf: float = 0.01,
    csp_expected: float = 5.0,
    jit_it: bool = True,
) -> Callable:
    """Creates a Gibbs Markov kernel with signature

        (key, sample) -> sample

    Args:
        Y: observed data
        csp_shape: shape parameter of the trait variances prior
        csp_scale: scale parameter of the trait variances prior
        csp_theta_inf: trait variance for inactive (shrunk) traits
        csp_expected: expected number of occupied traits
        jit_it: whether to JIT-compile the kernel
    """

    def _sample_gibbs(
        key,
        sample,
    ) -> dict:
        # subkeys[1] is reserved for sigma2, which is currently not resampled
        subkeys = random.split(key, 4)
        # Sample lambda
        lambd = gibbs_sample_mixing(
            subkeys[0],
            theta=sample["csp"]["variance"],
            sigma2=sample["sigma2"],
            eta=sample["eta"],
            Y=Y,
        )
        # We don't sample sigma2, just copy it
        sigma2 = sample["sigma2"]

        # Sample eta
        eta = gibbs_sample_traits(
            subkeys[2],
            lambd=lambd,
            sigma2=sigma2,
            Y=Y,
        )

        # Sample CSP parameters
        csp = sample_csp_gibbs(
            subkeys[3],
            coefficients=lambd,
            structure=jnp.ones_like(lambd),
            omega=sample["csp"]["omega"],
            expected_occupied=csp_expected,
            prior_shape=csp_shape,
            prior_scale=csp_scale,
            theta_inf=csp_theta_inf,
        )

        return {
            "eta": eta,
            "lambda": lambd,
            "sigma2": sigma2,
            "csp": csp,
        }

    if jit_it:
        return jax.jit(_sample_gibbs)
    else:
        return _sample_gibbs
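Putting the pieces together, a minimal end-to-end sketch (data sizes, chain length, and hyperparameters are all illustrative; the simulation helpers come from `_simulate.py` below):

import jax.numpy as jnp
from jax import random

from jnotype._factor_analysis._inference import initial_sample, generate_sampling_step
from jnotype._factor_analysis._simulate import (
    sample_latent,
    sample_mixing,
    sample_observations,
)

key = random.PRNGKey(42)
k_lat, k_mix, k_obs, k_init, k_chain = random.split(key, 5)

# Simulate a small data set with 3 true traits.
theta_true = jnp.array([2.0, 1.0, 0.5])
eta_true = sample_latent(k_lat, points=100, latent=3)
lambd_true = sample_mixing(k_mix, observed=15, theta=theta_true)
sigma2 = 0.1 * jnp.ones(15)
Y = sample_observations(k_obs, lambd=lambd_true, eta=eta_true, sigma2=sigma2)

# Initialize and build the Gibbs kernel.
sample = initial_sample(
    k_init,
    Y=Y,
    sigma2=sigma2,
    max_traits=8,
    expected_occupied=3.0,
    prior_shape=2.0,
    prior_scale=2.0,
)
step = generate_sampling_step(Y)

# Run the chain, tracking the number of active traits.
n_active_trace = []
for _ in range(500):
    k_chain, subkey = random.split(k_chain)
    sample = step(subkey, sample)
    n_active_trace.append(int(sample["csp"]["n_active"]))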
65 changes: 65 additions & 0 deletions src/jnotype/_factor_analysis/_simulate.py
@@ -0,0 +1,65 @@
"""Simulate data sets."""
from jaxtyping import Float, Array

import jax.numpy as jnp
from jax import random


def sample_observations(
    key,
    lambd: Float[Array, "observed latent"],
    eta: Float[Array, "points latent"],
    sigma2: Float[Array, " observed"],
) -> Float[Array, "points observed"]:
    """Samples observations.

    Args:
        key: JAX PRNG key
        lambd: mixing matrix
        eta: latent traits
        sigma2: noise variance for each observed dimension

    Returns:
        Y matrix, shape (n_points, n_observed)
    """
    N = eta.shape[0]
    P = lambd.shape[0]
    # Noise has shape (N, P)
    noise = random.normal(key, shape=(N, P)) * jnp.sqrt(sigma2)[None, :]
    return jnp.einsum("PH,NH -> NP", lambd, eta) + noise


def sample_latent(key, points: int, latent: int) -> Float[Array, "points latent"]:
    """Samples latent traits.

    Args:
        key: JAX PRNG key
        points: number of points
        latent: number of latent traits
    """
    return random.normal(key, shape=(points, latent))


def sample_mixing(
    key, observed: int, theta: Float[Array, " latent"]
) -> Float[Array, "observed latent"]:
    """Samples the mixing matrix.

    Args:
        key: JAX PRNG key
        observed: number of observed dimensions
        theta: variance attributed to each latent trait
    """
    latent = theta.shape[0]
    lambd = random.normal(key, shape=(observed, latent))
    return lambd * jnp.sqrt(theta)[None, :]


def covariance_from_mixing(
    lambd: Float[Array, "observed latent"],
    sigma2: Float[Array, " observed"],
) -> Float[Array, "observed observed"]:
    """Computes the covariance matrix of the observed variables
    using the mixing matrix
    and additional variance for each observed dimension."""
    return jnp.einsum("ph,qh -> pq", lambd, lambd) + jnp.diag(sigma2)
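As a quick sanity check (sizes illustrative), the empirical covariance of simulated observations should approach the analytic covariance $\Lambda \Lambda^\top + \operatorname{diag}(\sigma^2)$ as the number of points grows:

import jax.numpy as jnp
from jax import random

from jnotype._factor_analysis._simulate import (
    covariance_from_mixing,
    sample_latent,
    sample_mixing,
    sample_observations,
)

k1, k2, k3 = random.split(random.PRNGKey(0), 3)

theta = jnp.array([1.0, 0.5])  # two latent traits
eta = sample_latent(k1, points=50_000, latent=2)
lambd = sample_mixing(k2, observed=4, theta=theta)
sigma2 = 0.2 * jnp.ones(4)
Y = sample_observations(k3, lambd=lambd, eta=eta, sigma2=sigma2)

analytic = covariance_from_mixing(lambd, sigma2)
empirical = jnp.cov(Y, rowvar=False)
print(jnp.max(jnp.abs(analytic - empirical)))  # small for large `points`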
