diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cf2d363..f98315e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ repos: - - repo: https://github.com/python-poetry/poetry - rev: '1.3.2' - hooks: - - id: poetry-check - - id: poetry-lock - - id: poetry-export - args: ["-f", "requirements.txt", "-o", "requirements.txt"] + #- repo: https://github.com/python-poetry/poetry + #rev: '1.3.2' + #hooks: + #- id: poetry-check + #- id: poetry-lock + #- id: poetry-export + #args: ["-f", "requirements.txt", "-o", "requirements.txt"] - repo: https://github.com/charliermarsh/ruff-pre-commit rev: 'v0.0.245' hooks: @@ -21,6 +21,7 @@ repos: - id: interrogate exclude: tests args: [--config=pyproject.toml, --fail-under=95] + pass_filenames: false #- repo: https://github.com/RobertCraigie/pyright-python #rev: v1.1.296 #hooks: diff --git a/conftest.py b/conftest.py index 6ca754b..bb7ac2f 100644 --- a/conftest.py +++ b/conftest.py @@ -1,12 +1,9 @@ """PyTest unit tests configuration file.""" + import dataclasses import pytest -import matplotlib - -matplotlib.use("Agg") - @dataclasses.dataclass class TurnOnTestSuiteArgument: diff --git a/src/jnotype/__init__.py b/src/jnotype/__init__.py index 7ae9bfc..098de06 100644 --- a/src/jnotype/__init__.py +++ b/src/jnotype/__init__.py @@ -1,4 +1,5 @@ """Exploratory analysis of binary data.""" + import jnotype.bmm as bmm import jnotype.datasets as datasets import jnotype.sampling as sampling diff --git a/src/jnotype/_csp.py b/src/jnotype/_csp.py index c446758..ada045f 100644 --- a/src/jnotype/_csp.py +++ b/src/jnotype/_csp.py @@ -1,4 +1,5 @@ """Cumulative shrinkage prior.""" + import jax import jax.numpy as jnp import jax.scipy as jsp diff --git a/src/jnotype/_factor_analysis/_gibbs_backend.py b/src/jnotype/_factor_analysis/_gibbs_backend.py index 8ed08c5..afe7b52 100644 --- a/src/jnotype/_factor_analysis/_gibbs_backend.py +++ b/src/jnotype/_factor_analysis/_gibbs_backend.py @@ -1,5 +1,6 @@ """Sampling steps for all variables, apart from variances attributed to latent traits, which are sampled with CSP module.""" + from typing import Callable import jax diff --git a/src/jnotype/_factor_analysis/_inference.py b/src/jnotype/_factor_analysis/_inference.py index 01a2f6c..e34102c 100644 --- a/src/jnotype/_factor_analysis/_inference.py +++ b/src/jnotype/_factor_analysis/_inference.py @@ -1,4 +1,5 @@ """Sampling from the posterior distribution.""" + from typing import Callable from jaxtyping import Float, Array diff --git a/src/jnotype/_factor_analysis/_simulate.py b/src/jnotype/_factor_analysis/_simulate.py index 7cfd6ff..70fd1a7 100644 --- a/src/jnotype/_factor_analysis/_simulate.py +++ b/src/jnotype/_factor_analysis/_simulate.py @@ -1,4 +1,5 @@ """Simulate data sets.""" + from jaxtyping import Float, Array import jax.numpy as jnp diff --git a/src/jnotype/_utils.py b/src/jnotype/_utils.py index d2f3173..330d24d 100644 --- a/src/jnotype/_utils.py +++ b/src/jnotype/_utils.py @@ -3,6 +3,7 @@ This file should be as small as possible. 
Appearing themes should be refactored and placed into separate modules.""" + import jax @@ -16,7 +17,7 @@ class JAXRNG: b = jax.random.bernoulli(rng.key, shape=(10,)) """ - def __init__(self, key: jax.random.PRNGKeyArray) -> None: + def __init__(self, key: jax.Array) -> None: """ Args: key: initialization key @@ -24,7 +25,7 @@ def __init__(self, key: jax.random.PRNGKeyArray) -> None: self._key = key @property - def key(self) -> jax.random.PRNGKeyArray: + def key(self) -> jax.Array: """Generates a new key.""" key, subkey = jax.random.split(self._key) self._key = key diff --git a/src/jnotype/_variance.py b/src/jnotype/_variance.py index e3abbf8..a74b95a 100644 --- a/src/jnotype/_variance.py +++ b/src/jnotype/_variance.py @@ -1,4 +1,5 @@ """Utilities for sampling variances.""" + from jax import random import jax import jax.numpy as jnp @@ -7,7 +8,7 @@ def _sample_precisions( - key: random.PRNGKeyArray, + key: jax.Array, values: Float[Array, "G F"], mask: Int[Array, "G F"], prior_shape: float, @@ -34,7 +35,7 @@ def _sample_precisions( @jax.jit def sample_variances( - key: random.PRNGKeyArray, + key: jax.Array, values: Float[Array, "G F"], mask: Int[Array, "G F"], prior_shape: float = 2.0, diff --git a/src/jnotype/bmm/__init__.py b/src/jnotype/bmm/__init__.py index 19f4573..8691fd1 100644 --- a/src/jnotype/bmm/__init__.py +++ b/src/jnotype/bmm/__init__.py @@ -1,4 +1,5 @@ """Bernoulli Mixture Model.""" + from jnotype.bmm._em import expectation_maximization from jnotype.bmm._gibbs import ( sample_mixing, diff --git a/src/jnotype/bmm/_em.py b/src/jnotype/bmm/_em.py index a9db452..e41e0b4 100644 --- a/src/jnotype/bmm/_em.py +++ b/src/jnotype/bmm/_em.py @@ -1,4 +1,5 @@ """The Expectation-Maximization algorithm for Bernoulli Mixture Model.""" + import dataclasses import time from typing import Optional @@ -68,7 +69,11 @@ def em_step( observed: Int[Array, "N K"], mixing: Float[Array, "K B"], proportions: Float[Array, " B"], -) -> tuple[Float[Array, "N B"], Float[Array, "K B"], Float[Array, " B"],]: +) -> tuple[ + Float[Array, "N B"], + Float[Array, "K B"], + Float[Array, " B"], +]: """The E and M step combined, for better JIT compiler optimisation. 
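+    Keeping both updates inside a single traced function lets the JIT
+    compiler fuse them, instead of materialising intermediate results
+    between two separately compiled calls.
+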
Args: @@ -160,7 +165,7 @@ def _init( *, n_features: int, n_clusters: Optional[int], # type: ignore - key: Optional[jax.random.PRNGKeyArray], + key: Optional[jax.Array], mixing_init: Optional[Float[Array, "K B"]], proportions_init: Optional[Float[Array, " B"]], ) -> tuple[Float[Array, "K B"], Float[Array, " B"]]: @@ -221,7 +226,7 @@ def expectation_maximization( observed: Int[Array, "K B"], *, n_clusters: Optional[int] = None, - key: Optional[jax.random.PRNGKeyArray] = None, + key: Optional[jax.Array] = None, mixing_init: Optional[Float[Array, "K B"]] = None, proportions_init: Optional[Float[Array, " B"]] = None, max_n_steps: int = 10_000, diff --git a/src/jnotype/bmm/_gibbs.py b/src/jnotype/bmm/_gibbs.py index 0d6c0f4..cc44d0a 100644 --- a/src/jnotype/bmm/_gibbs.py +++ b/src/jnotype/bmm/_gibbs.py @@ -1,4 +1,5 @@ """Sampling cluster labels and proportions.""" + from typing import Optional, Sequence import jax @@ -41,7 +42,7 @@ def _bernoulli_loglikelihood( def sample_cluster_labels( - key: random.PRNGKeyArray, + key: jax.Array, *, cluster_proportions: Float[Array, " B"], mixing: Float[Array, "K B"], @@ -94,7 +95,7 @@ def _calculate_counts(labels: Int[Array, " N"], n_clusters: int) -> Int[Array, " def sample_cluster_proportions( - key: random.PRNGKeyArray, + key: jax.Array, *, labels: Int[Array, " N"], dirichlet_prior: Float[Array, " B"], @@ -117,7 +118,7 @@ def sample_cluster_proportions( def sample_mixing( - key: random.PRNGKeyArray, + key: jax.Array, *, observations: Int[Array, "N K"], labels: Int[Array, " N"], @@ -148,7 +149,7 @@ def sample_mixing( @jax.jit def single_sampling_step( *, - key: random.PRNGKeyArray, + key: jax.Array, observed_data: Int[Array, "N K"], proportions: Float[Array, " B"], mixing: Float[Array, "K B"], @@ -210,7 +211,7 @@ def __init__( dirichlet_prior: Float[Array, " cluster"], beta_prior: tuple[float, float] = (1.0, 1.0), *, - jax_rng_key: Optional[jax.random.PRNGKeyArray] = None, + jax_rng_key: Optional[jax.Array] = None, warmup: int = 2000, steps: int = 3000, verbose: bool = False, diff --git a/src/jnotype/checks/_histograms.py b/src/jnotype/checks/_histograms.py index 35ba21c..6577570 100644 --- a/src/jnotype/checks/_histograms.py +++ b/src/jnotype/checks/_histograms.py @@ -1,4 +1,5 @@ """Plotting histograms of data.""" + from typing import Sequence, Union, Literal import matplotlib.pyplot as plt diff --git a/src/jnotype/datasets/__init__.py b/src/jnotype/datasets/__init__.py index dd99224..64d6e62 100644 --- a/src/jnotype/datasets/__init__.py +++ b/src/jnotype/datasets/__init__.py @@ -1,4 +1,5 @@ """Data sets.""" + from jnotype.datasets._simulation import BlockImagesSampler __all__ = [ diff --git a/src/jnotype/datasets/_simulation/__init__.py b/src/jnotype/datasets/_simulation/__init__.py index cf48708..46ee3aa 100644 --- a/src/jnotype/datasets/_simulation/__init__.py +++ b/src/jnotype/datasets/_simulation/__init__.py @@ -1,4 +1,5 @@ """Simulated data sets.""" + from jnotype.datasets._simulation._block_images import BlockImagesSampler __all__ = [ diff --git a/src/jnotype/datasets/_simulation/_block_images.py b/src/jnotype/datasets/_simulation/_block_images.py index 2f3b881..c3eeffd 100644 --- a/src/jnotype/datasets/_simulation/_block_images.py +++ b/src/jnotype/datasets/_simulation/_block_images.py @@ -1,8 +1,10 @@ """Simulation of binary images using Bernoulli mixture model.""" + from typing import Optional from jaxtyping import Array, Float, Int import jax.numpy as jnp +import jax from jax import random @@ -139,7 +141,7 @@ def mixing(self) -> Float[Array, 
"n_classes 6 6"]: def sample_dataset( self, - key: random.PRNGKeyArray, + key: jax.Array, n_samples: int, *, probs: Optional[Float[Array, " n_classes"]] = None, diff --git a/src/jnotype/logistic/_binary_latent.py b/src/jnotype/logistic/_binary_latent.py index 6334688..02f0ffe 100644 --- a/src/jnotype/logistic/_binary_latent.py +++ b/src/jnotype/logistic/_binary_latent.py @@ -1,4 +1,5 @@ """Sample binary latent variables.""" + from functools import partial import jax @@ -13,7 +14,7 @@ @partial(jax.jit, static_argnames="n_binary_codes") def sample_binary_codes( *, - key: random.PRNGKeyArray, + key: jax.Array, intercepts: Float[Array, " features"], coefficients: Float[Array, "features covs"], structure: Int[Array, "features covs"], diff --git a/src/jnotype/logistic/_polyagamma.py b/src/jnotype/logistic/_polyagamma.py index 4fb9144..9fb510f 100644 --- a/src/jnotype/logistic/_polyagamma.py +++ b/src/jnotype/logistic/_polyagamma.py @@ -1,4 +1,5 @@ """Logistic regression sampling utilities using Pólya-Gamma augmentation.""" + from jax import random import jax import jax.numpy as jnp @@ -22,7 +23,7 @@ def _calculate_logits( def _sample_coefficients( *, - key: random.PRNGKeyArray, + key: jax.Array, omega: Float[Array, "points features"], covariates: Float[Array, "points covariates"], structure: Int[Array, "features covariates"], @@ -53,9 +54,9 @@ def _sample_coefficients( precision_matrices: Float[Array, "features covariates covariates"] = jax.vmap( jnp.diag )(jnp.reciprocal(prior_variance)) - posterior_covariances: Float[ - Array, "features covariates covariates" - ] = jnp.linalg.inv(x_omega_x + precision_matrices) + posterior_covariances: Float[Array, "features covariates covariates"] = ( + jnp.linalg.inv(x_omega_x + precision_matrices) + ) kappa: Float[Array, "points features"] = jnp.asarray(observed, dtype=float) - 0.5 @@ -79,7 +80,7 @@ def _sample_coefficients( def sample_coefficients( *, - jax_key: random.PRNGKeyArray, + jax_key: jax.Array, numpy_rng: np.random.Generator, observed: Int[Array, "points features"], design_matrix: Float[Array, "points covariates"], @@ -214,7 +215,7 @@ def _augment_matrices( def sample_intercepts_and_coefficients( *, - jax_key: random.PRNGKeyArray, + jax_key: jax.Array, numpy_rng: np.random.Generator, observed: Int[Array, "points features"], intercepts: Float[Array, " features"], diff --git a/src/jnotype/logistic/_structure.py b/src/jnotype/logistic/_structure.py index a28470b..8224fda 100644 --- a/src/jnotype/logistic/_structure.py +++ b/src/jnotype/logistic/_structure.py @@ -1,4 +1,5 @@ """Sample structure (spike/slab distinction) variables.""" + from typing import Union import jax @@ -41,7 +42,7 @@ def _logpdf_gaussian( @jax.jit def sample_structure( *, - key: random.PRNGKeyArray, + key: jax.Array, intercepts: Float[Array, " features"], coefficients: Float[Array, "features covs"], structure: Int[Array, "features covs"], @@ -158,7 +159,7 @@ def body_fun( @jax.jit def sample_gamma( - key: random.PRNGKeyArray, + key: jax.Array, structure: Int[Array, "G K"], prior_a: float = 1.0, prior_b: float = 1.0, @@ -177,7 +178,7 @@ def sample_gamma( @jax.jit def sample_gamma_individual( - key: random.PRNGKeyArray, + key: jax.Array, structure: Int[Array, "G K"], prior_a: Float[Array, " K"], prior_b: Float[Array, " K"], diff --git a/src/jnotype/logistic/logreg.py b/src/jnotype/logistic/logreg.py index 791b9e1..6d140aa 100644 --- a/src/jnotype/logistic/logreg.py +++ b/src/jnotype/logistic/logreg.py @@ -1,4 +1,5 @@ """Logistic regression utilities.""" + import jax import 
jax.numpy as jnp from jaxtyping import Int, Float, Array diff --git a/src/jnotype/pyramids/__init__.py b/src/jnotype/pyramids/__init__.py index b12fb4a..1c6c801 100644 --- a/src/jnotype/pyramids/__init__.py +++ b/src/jnotype/pyramids/__init__.py @@ -1,7 +1,9 @@ """Gibbs samplers for Bayesian pyramids.""" from jnotype.pyramids._sampler_fixed import TwoLayerPyramidSampler +from jnotype.pyramids._sampler_csp import TwoLayerPyramidSamplerNonparametric __all__ = [ "TwoLayerPyramidSampler", + "TwoLayerPyramidSamplerNonparametric", ] diff --git a/src/jnotype/pyramids/_sampler_csp.py b/src/jnotype/pyramids/_sampler_csp.py new file mode 100644 index 0000000..28b684d --- /dev/null +++ b/src/jnotype/pyramids/_sampler_csp.py @@ -0,0 +1,452 @@ +"""Sampler for two-layer Bayesian pyramids +with cumulative shrinkage process (CSP) prior +on latent binary codes.""" + +from typing import Optional, Sequence, Union, NewType + +import jax +import jax.numpy as jnp +import numpy as np + +from jaxtyping import Array, Float, Int + +from jnotype.sampling import AbstractGibbsSampler, DatasetInterface +from jnotype._utils import JAXRNG + +from jnotype.bmm import sample_bmm +from jnotype.logistic import ( + sample_gamma, + sample_gamma_individual, + sample_structure, + sample_binary_codes, + sample_intercepts_and_coefficients, +) +from jnotype._csp import sample_csp_gibbs, sample_csp_prior +from jnotype._variance import sample_variances + +Sample = NewType("Sample", dict) + + +_sample_csp_gibbs_jit = jax.jit(sample_csp_gibbs) + + +def _single_sampling_step( + *, + # Auxiliary: random keys, static specification + jax_key: jax.Array, + numpy_rng: np.random.Generator, + n_binary_codes: int, + # Observed values + observed: Int[Array, "points observed"], + # Sampled variables + intercepts: Float[Array, " observed"], + coefficients: Float[Array, "observed covariates"], + structure: Int[Array, "observed covariates"], + covariates: Float[Array, "points covariates"], + observed_variances: Float[Array, " known_covariates"], + gamma: Float[Array, ""], + nu: Float[Array, " known_covariates"], + cluster_labels: Int[Array, " points"], + mixing: Float[Array, "n_binary_codes n_clusters"], + proportions: Float[Array, " n_clusters"], + csp_omega: Float[Array, " n_binary_codes"], + csp_expected_occupied: float, + # Priors + dirichlet_prior: Float[Array, " n_clusters"], + nu_prior_a: Union[float, Float[Array, " known_covariates"]], + nu_prior_b: Union[float, Float[Array, " known_covariates"]], + pseudoprior_variance: float, + intercept_prior_mean: float, + intercept_prior_variance: float, + gamma_prior_a: float, + gamma_prior_b: float, + mixing_beta_prior: tuple[float, float], + variances_prior_shape: float, + variances_prior_scale: float, + csp_theta_inf: float, +) -> Sample: + """Single sampling step of two-layer Bayesian pyramid. + + Note: + It cannot be JITed and it uses JAX random numbers + as well as NumPy's. This is necessary to use + the Pólya-Gamma sampler (coefficients). 
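+        Only the Pólya-Gamma draw itself happens outside of XLA: the
+        JAX-based substeps remain JIT-compiled individually (for example,
+        via the ``_sample_csp_gibbs_jit`` helper defined above).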
+    """
+    # --- Sample variances for latent variables from the CSP prior ---
+    key, subkey = jax.random.split(jax_key)
+    csp_sample = _sample_csp_gibbs_jit(
+        key=subkey,
+        coefficients=coefficients[:, :n_binary_codes],
+        structure=structure[:, :n_binary_codes],
+        omega=csp_omega,
+        expected_occupied=csp_expected_occupied,
+        prior_shape=variances_prior_shape,
+        prior_scale=variances_prior_scale,
+        theta_inf=csp_theta_inf,
+    )
+    # --- Sample variances for observed variables from the usual prior ---
+    key, subkey = jax.random.split(key)
+    observed_variances = sample_variances(
+        key=subkey,
+        values=coefficients[:, n_binary_codes:],
+        mask=structure[:, n_binary_codes:],
+        prior_shape=variances_prior_shape,
+        prior_scale=variances_prior_scale,
+    )
+    variances = jnp.concatenate((csp_sample["variance"], observed_variances))
+
+    # --- Sample the sparse logistic regression layer ---
+    # Sample intercepts and coefficients
+    # (split the running key, rather than `jax_key` again, so that this step
+    # does not reuse the randomness consumed by the CSP update above)
+    key, subkey = jax.random.split(key)
+
+    intercepts, coefficients = sample_intercepts_and_coefficients(
+        jax_key=subkey,
+        numpy_rng=numpy_rng,
+        intercepts=intercepts,
+        coefficients=coefficients,
+        structure=structure,
+        covariates=covariates,
+        variances=variances,
+        pseudoprior_variance=pseudoprior_variance,
+        intercept_prior_mean=intercept_prior_mean,
+        intercept_prior_variance=intercept_prior_variance,
+        observed=observed,
+    )
+
+    # Sample structure and the sparsity
+    key, subkey_structure, subkey_gamma, subkey_nu = jax.random.split(key, 4)
+
+    sparsity_vector = jnp.concatenate(
+        (
+            jnp.full(shape=(n_binary_codes,), fill_value=gamma),
+            nu,
+        )
+    )
+    structure = sample_structure(
+        key=subkey_structure,
+        intercepts=intercepts,
+        coefficients=coefficients,
+        structure=structure,
+        covariates=covariates,
+        observed=observed,
+        variances=variances,
+        pseudoprior_variance=pseudoprior_variance,
+        gamma=sparsity_vector,
+    )
+    gamma = sample_gamma(
+        key=subkey_gamma,
+        structure=structure[..., :n_binary_codes],
+        prior_a=gamma_prior_a,
+        prior_b=gamma_prior_b,
+    )
+    nu = sample_gamma_individual(
+        key=subkey_nu,
+        structure=structure[..., n_binary_codes:],
+        prior_a=nu_prior_a,
+        prior_b=nu_prior_b,
+    )
+
+    # Sample binary latent variables
+    key, subkey = jax.random.split(key)
+    covariates = sample_binary_codes(
+        key=subkey,
+        intercepts=intercepts,
+        coefficients=coefficients,
+        structure=structure,
+        covariates=covariates,
+        observed=observed,
+        n_binary_codes=n_binary_codes,
+        labels=cluster_labels,
+        labels_to_codes=mixing,
+    )
+
+    # --- Sample the Bernoulli mixture model layer ---
+    key, subkey = jax.random.split(key)
+    cluster_labels, proportions, mixing = sample_bmm(
+        key=subkey,
+        # Bernoulli mixture model sees the binary latent codes
+        observed_data=covariates[:, :n_binary_codes],
+        proportions=proportions,
+        mixing=mixing,
+        dirichlet_prior=dirichlet_prior,
+        beta_prior=mixing_beta_prior,
+    )
+
+    return {
+        # Intercepts and coefficients
+        "intercepts": intercepts,
+        "coefficients_latent": coefficients[:, :n_binary_codes],
+        "coefficients_observed": coefficients[:, n_binary_codes:],
+        # Structure and sparsities
+        "structure_latent": structure[:, :n_binary_codes],
+        "structure_observed": structure[:, n_binary_codes:],
+        "gamma": gamma,
+        "nu": nu,
+        # Latent traits
+        "latent_traits": covariates[:, :n_binary_codes],
+        # Clustering
+        "cluster_labels": cluster_labels,
+        "proportions": proportions,
+        "mixing": mixing,
+        # Variances:
+        # - Variances for observed covariates
+        "observed_variances": observed_variances,
+        # - Variances for latent binary codes, using CSP prior
+        "latent_variances": csp_sample["variance"],
+        "csp_omega": csp_sample["omega"],
+        "csp_nu": csp_sample["nu"],
+        "csp_indicators": csp_sample["indicators"],
+        "csp_active_traits": csp_sample["active_traits"],
+        "csp_n_active_traits": csp_sample["n_active"],
+    }
+
+
+class TwoLayerPyramidSamplerNonparametric(AbstractGibbsSampler):
+    """A prototype of a Gibbs sampler for a two-layer
+    Bayesian pyramid with CSP prior.
+    """
+
+    def __init__(
+        self,
+        datasets: Sequence[DatasetInterface],
+        *,
+        # Observed data and dimension specification
+        observed: Int[Array, "points features"],
+        expected_binary_codes: float = 4,
+        max_binary_codes: int = 8,
+        n_clusters: int = 10,
+        observed_covariates: Optional[
+            Float[Array, "points observed_covariates"]
+        ] = None,
+        # Prior
+        dirichlet_prior: Float[Array, " clusters"],
+        gamma_prior: tuple[float, float] = (1.0, 1.0),
+        nu_prior: tuple[float, float] = (1.0, 1.0),
+        variances_prior_scale: float = 2.0,
+        variances_prior_shape: float = 1.0,
+        pseudoprior_variance: float = 0.1**2,
+        mixing_beta_prior: tuple[float, float] = (1.0, 5.0),
+        intercept_prior_mean: float = -3,
+        intercept_prior_variance: float = 1.0**2,
+        inactive_latent_variance_theta_inf: float = 0.1**2,
+        # Gibbs sampling
+        warmup: int = 5_000,
+        steps: int = 10_000,
+        verbose: bool = False,
+        seed: int = 195,
+    ) -> None:
+        super().__init__(datasets, warmup=warmup, steps=steps, verbose=verbose)
+
+        # Initialize two random number generators: JAX alone is not enough,
+        # because the Pólya-Gamma step requires a NumPy generator
+        self._jax_rng = JAXRNG(jax.random.PRNGKey(seed))
+        self._np_rng = np.random.default_rng(seed + 3)
+
+        self._observed_data = observed
+        self._observed_covariates = (
+            observed_covariates
+            if observed_covariates is not None
+            else jnp.zeros((observed.shape[0], 0))
+        )
+        self._n_observed_covariates = self._observed_covariates.shape[1]
+
+        assert inactive_latent_variance_theta_inf > 0
+        self._csp_theta_inf = inactive_latent_variance_theta_inf
+
+        assert n_clusters >= 1
+        self._n_clusters = n_clusters
+
+        # We distinguish the maximum number of binary codes modelled
+        # from the number of binary codes expected a priori: codes beyond
+        # the expected number will tend to be marked as inactive.
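+        # For instance, the TCGA workflow uses max_binary_codes=8 with
+        # expected_binary_codes=3.0: the CSP is truncated at 8 latent
+        # traits, of which roughly 3 are expected a priori to stay active.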
+        assert max_binary_codes >= 1
+        self._n_binary_codes = max_binary_codes
+        assert expected_binary_codes > 0
+        self._kappa_0 = expected_binary_codes
+
+        self._dirichlet_prior = dirichlet_prior
+        self._gamma_prior = gamma_prior
+        self._nu_prior = nu_prior
+
+        # Variances of coefficients per each covariate in middle layer
+        self._variances_prior_scale = variances_prior_scale
+        self._variances_prior_shape = variances_prior_shape
+        self._pseudoprior_variance = pseudoprior_variance
+
+        self._mixing_beta_prior = mixing_beta_prior
+
+        self._intercept_prior_mean = intercept_prior_mean
+        self._intercept_prior_variance = intercept_prior_variance
+
+    @classmethod
+    def dimensions(cls) -> Sample:
+        """The sites in each sample with annotated dimensions."""
+        return {
+            "intercepts": ["features"],
+            "coefficients_latent": ["features", "latents"],
+            "coefficients_observed": ["features", "observed_covariates"],
+            "structure_latent": ["features", "latents"],
+            "structure_observed": ["features", "observed_covariates"],
+            "gamma": [],  # Float, no named dimensions
+            "nu": ["observed_covariates"],
+            "latent_traits": ["points", "latents"],
+            "cluster_labels": ["points"],
+            "proportions": ["clusters"],
+            "mixing": ["latents", "clusters"],
+            # Observed variances
+            "observed_variances": ["observed_covariates"],
+            # Latent variances are modelled using the CSP prior
+            "latent_variances": ["latents"],
+            "csp_omega": ["latents"],
+            "csp_nu": ["latents"],
+            "csp_indicators": ["latents"],
+            "csp_active_traits": ["latents"],
+            "csp_n_active_traits": [],  # Int, no named dimensions
+        }
+
+    def new_sample(self, sample: Sample) -> Sample:
+        """A new sample."""
+        coefficients = jnp.hstack(
+            (sample["coefficients_latent"], sample["coefficients_observed"])
+        )
+        structure = jnp.hstack(
+            (sample["structure_latent"], sample["structure_observed"])
+        )
+        covariates = jnp.hstack((sample["latent_traits"], self._observed_covariates))
+        return _single_sampling_step(
+            jax_key=self._jax_rng.key,
+            numpy_rng=self._np_rng,
+            n_binary_codes=self._n_binary_codes,
+            observed=self._observed_data,
+            intercepts=sample["intercepts"],
+            coefficients=coefficients,
+            structure=structure,
+            covariates=covariates,
+            observed_variances=sample["observed_variances"],
+            gamma=sample["gamma"],
+            nu=sample["nu"],
+            cluster_labels=sample["cluster_labels"],
+            mixing=sample["mixing"],
+            proportions=sample["proportions"],
+            dirichlet_prior=self._dirichlet_prior,
+            gamma_prior_a=self._gamma_prior[0],
+            gamma_prior_b=self._gamma_prior[1],
+            nu_prior_a=self._nu_prior[0],
+            nu_prior_b=self._nu_prior[1],
+            variances_prior_scale=self._variances_prior_scale,
+            variances_prior_shape=self._variances_prior_shape,
+            mixing_beta_prior=self._mixing_beta_prior,
+            pseudoprior_variance=self._pseudoprior_variance,
+            intercept_prior_mean=self._intercept_prior_mean,
+            intercept_prior_variance=self._intercept_prior_variance,
+            csp_omega=sample["csp_omega"],
+            csp_expected_occupied=self._kappa_0,
+            csp_theta_inf=self._csp_theta_inf,
+        )
+
+    def _initialise_intercepts(self) -> Float[Array, " features"]:
+        """Initializes the intercepts."""
+        n_outputs = self._observed_data.shape[1]
+
+        mean = self._intercept_prior_mean
+        std = np.sqrt(self._intercept_prior_variance)
+        normal_noise = jax.random.normal(self._jax_rng.key, shape=(n_outputs,))
+        return mean + std * normal_noise
+
+    def _initialise_gamma(self) -> Float[Array, ""]:
+        return jax.random.beta(
+            self._jax_rng.key, self._gamma_prior[0], self._gamma_prior[1]
+        )
+
+    def _initialise_nu(self) -> Float[Array, " observed_covariates"]:
+        return jax.random.beta(
+            self._jax_rng.key,
+            self._nu_prior[0],
+            self._nu_prior[1],
+            shape=(self._n_observed_covariates,),
+        )
+
+    def _initialise_structure(
+        self, gamma: Float[Array, ""], nu: Float[Array, " observed_covariates"]
+    ) -> Int[Array, "features covariates"]:
+        """Initialises the structure."""
+        n_outputs = self._observed_data.shape[1]
+        structure_codes = jax.random.bernoulli(
+            self._jax_rng.key, p=gamma, shape=(n_outputs, self._n_binary_codes)
+        )
+        structure_observed_features = jax.random.bernoulli(
+            self._jax_rng.key, p=nu, shape=(n_outputs, self._n_observed_covariates)
+        )
+        return jnp.hstack((structure_codes, structure_observed_features))
+
+    def initialise(self) -> Sample:
+        """Initialises the sample."""
+        # TODO(Pawel): This initialisation can be much improved.
+        # Hopefully it does not matter in the end, but assessment of
+        # chain mixing is very much required.
+
+        n_covariates = self._n_binary_codes + self._n_observed_covariates
+        n_points, n_outputs = self._observed_data.shape
+        n_clusters = self._n_clusters
+
+        gamma = self._initialise_gamma()
+        nu = self._initialise_nu()
+
+        initial_binary_codes = jnp.asarray(
+            # TODO(Pawel): There could be a better way to initialize binary codes
+            jax.random.bernoulli(
+                self._jax_rng.key, p=0.1, shape=(n_points, self._n_binary_codes)
+            ),
+            dtype=float,
+        )
+        initial_covariates = jnp.hstack(
+            (initial_binary_codes, self._observed_covariates)
+        )
+
+        structure = self._initialise_structure(gamma=gamma, nu=nu)
+
+        csp_sample = sample_csp_prior(
+            self._jax_rng.key,
+            k=self._n_binary_codes,
+            expected_occupied=self._kappa_0,
+            theta_inf=self._csp_theta_inf,
+        )
+
+        variances = jnp.concatenate(
+            (csp_sample["variance"], jnp.ones(self._n_observed_covariates))
+        )
+
+        _noise = jax.random.normal(self._jax_rng.key, shape=(n_outputs, n_covariates))
+        _entries_variances = variances[
+            None, :
+        ] * structure + self._pseudoprior_variance * (1 - structure)
+        coefficients = jnp.sqrt(_entries_variances) * _noise
+
+        return {
+            "intercepts": self._initialise_intercepts(),
+            "coefficients_latent": coefficients[:, : self._n_binary_codes],
+            "coefficients_observed": coefficients[:, self._n_binary_codes :],
+            "structure_latent": structure[:, : self._n_binary_codes],
+            "structure_observed": structure[:, self._n_binary_codes :],
+            "gamma": gamma,
+            "nu": nu,
+            "latent_traits": initial_binary_codes,
+            "covariates": initial_covariates,
+            "cluster_labels": jax.random.categorical(
+                self._jax_rng.key, logits=jnp.zeros(n_clusters), shape=(n_points,)
+            ),
+            "proportions": jnp.full(fill_value=1.0 / n_clusters, shape=(n_clusters,)),
+            "mixing": jax.random.beta(
+                self._jax_rng.key,
+                a=self._mixing_beta_prior[0],
+                b=self._mixing_beta_prior[1],
+                shape=(self._n_binary_codes, n_clusters),
+            ),
+            "observed_variances": variances[self._n_binary_codes :],
+            "latent_variances": csp_sample["variance"],
+            "csp_omega": csp_sample["omega"],
+            "csp_nu": csp_sample["nu"],
+            "csp_indicators": csp_sample["indicators"],
+            "csp_active_traits": csp_sample["active_traits"],
+            "csp_n_active_traits": csp_sample["n_active"],
+        }
diff --git a/src/jnotype/pyramids/_sampler_fixed.py b/src/jnotype/pyramids/_sampler_fixed.py
index 0516666..4677a8e 100644
--- a/src/jnotype/pyramids/_sampler_fixed.py
+++ b/src/jnotype/pyramids/_sampler_fixed.py
@@ -1,5 +1,6 @@
 """Sampler for two-layer Bayesian pyramids
 with fixed number of latent binary codes."""
+
 from typing import Optional, Sequence, Union, NewType
 
 import jax
@@ -28,7 +29,7 @@ def _single_sampling_step(
     *,
     # Auxiliary: random keys, static
specification - jax_key: jax.random.PRNGKeyArray, + jax_key: jax.Array, numpy_rng: np.random.Generator, n_binary_codes: int, # Observed values @@ -399,10 +400,10 @@ def _initialise_full_sample(self) -> _JointSample: return { "intercepts": self._initialise_intercepts(), "coefficients": coefficients, - "structure": self._initialise_structure(gamma=gamma, nu=nu), + "structure": structure, "gamma": gamma, "nu": nu, - "variances": jnp.ones(n_covariates), + "variances": variances, "covariates": initial_covariates, "cluster_labels": jax.random.categorical( self._jax_rng.key, logits=jnp.zeros(n_clusters), shape=(n_points,) diff --git a/src/jnotype/sampling/__init__.py b/src/jnotype/sampling/__init__.py index ddef699..04af43d 100644 --- a/src/jnotype/sampling/__init__.py +++ b/src/jnotype/sampling/__init__.py @@ -1,4 +1,5 @@ """Generic utilities for sampling.""" + from jnotype.sampling._chunker import ( DatasetInterface, ListDataset, diff --git a/src/jnotype/sampling/_chunker.py b/src/jnotype/sampling/_chunker.py index 119d1cc..ad22982 100644 --- a/src/jnotype/sampling/_chunker.py +++ b/src/jnotype/sampling/_chunker.py @@ -1,4 +1,5 @@ """Utilities for saving samples in chunks, to limit RAM usage.""" + import abc from datetime import datetime diff --git a/src/jnotype/sampling/_sampler.py b/src/jnotype/sampling/_sampler.py index 7191c15..5783533 100644 --- a/src/jnotype/sampling/_sampler.py +++ b/src/jnotype/sampling/_sampler.py @@ -1,4 +1,5 @@ """Generic Gibbs sampler.""" + import abc import logging import time diff --git a/tests/logistic/test_logreg.py b/tests/logistic/test_logreg.py index 2149db2..3466830 100644 --- a/tests/logistic/test_logreg.py +++ b/tests/logistic/test_logreg.py @@ -1,4 +1,5 @@ """Tests for the logistic regression.""" + import jax.numpy as jnp import jnotype.logistic.logreg as lr diff --git a/tests/logistic/test_polyagamma.py b/tests/logistic/test_polyagamma.py index 0019506..dab6764 100644 --- a/tests/logistic/test_polyagamma.py +++ b/tests/logistic/test_polyagamma.py @@ -153,9 +153,7 @@ def test_sample_coefficients_nontrivial_structure( ] ) - prior_variance = jnp.asarray( - [[0.6, 0.02**2], [0.03**2, 0.5], [0.04**2, 0.05**2]] - ) + prior_variance = jnp.asarray([[0.6, 0.02**2], [0.03**2, 0.5], [0.04**2, 0.05**2]]) current_coefficients = jnp.zeros_like(true_coefficients) diff --git a/tests/test_csp.py b/tests/test_csp.py index fce493b..2b2f443 100644 --- a/tests/test_csp.py +++ b/tests/test_csp.py @@ -1,4 +1,5 @@ """Tests of the cumulative shrinkage prior (CSP) module.""" + import jax import jax.numpy as jnp from jax import random diff --git a/tests/test_install.py b/tests/test_install.py index e5b3a42..fe42cdd 100644 --- a/tests/test_install.py +++ b/tests/test_install.py @@ -1,4 +1,5 @@ """Simplest installation test.""" + from types import ModuleType diff --git a/workflows/TCGA_analysis.smk b/workflows/TCGA_analysis.smk new file mode 100644 index 0000000..538f205 --- /dev/null +++ b/workflows/TCGA_analysis.smk @@ -0,0 +1,648 @@ +import dataclasses +import matplotlib +import matplotlib.pyplot as plt +import matplotlib.patches as mpatches +import numpy as np +import pandas as pd +import xarray as xr +import sys +import io +import json +import seaborn as sns +from scipy.stats import chi2 + +from sklearn.preprocessing import StandardScaler, OneHotEncoder + +import formulaic as fm +from lifelines import CoxPHFitter +from lifelines.calibration import survival_probability_calibration + +from jnotype.pyramids import TwoLayerPyramidSampler, TwoLayerPyramidSamplerNonparametric +from 
jnotype.sampling import ListDataset + + +matplotlib.use("agg") + + +@dataclasses.dataclass +class AnalysisParams: + # --- Data specification --- + allowed_types: list[str] + + # --- Bayesian pyramids --- + # Model specification + design_formula: str = "std_age + gender + type + stage" + n_clusters: int = 4 + max_binary_codes: int = 8 + expected_binary_codes: float = 3.0 + # MCMC parameters + n_warmup: int = 12_000 + n_steps: int = 3000 + thinning: int = 5 + + # --- Survival analysis --- + penalizer: float = 0.05 + +pancancer_list = ["BRCA", "LGG", "HNSC", "PRAD", "THCA", "OV", "LUAD"] + +ANALYSES = { + # "AML": AnalysisParams(allowed_types=["LAML"]), + "BRCA": AnalysisParams(allowed_types=["BRCA"]), + "BRCA-Minimal": AnalysisParams(allowed_types=["BRCA"], design_formula="std_age + gender"), + "Brain": AnalysisParams(allowed_types=["LGG", "GBM"]), + "Brain-Minimal": AnalysisParams(allowed_types=["LGG", "GBM"], design_formula="std_age + gender"), + "Pancancer": AnalysisParams(allowed_types=pancancer_list), + "Pancancer-Minimal": AnalysisParams(allowed_types=pancancer_list, design_formula="std_age + gender"), +} + +N_BOOTSTRAP: int = 20 +BOOTSTRAP_INDICES = list(range(1, N_BOOTSTRAP + 1)) + +rule all: + input: + survival_plots = expand("generated/TCGA/{analysis}/summary/survival/plot.pdf", analysis=ANALYSES.keys()), + latent_traits_plots = expand("generated/TCGA/{analysis}/summary/latent_traits.pdf", analysis=ANALYSES.keys()), + effect_sizes_plot = expand("generated/TCGA/{analysis}/summary/effect_sizes.pdf", analysis=ANALYSES.keys()), + learned_mutations_and_traits_plots = expand("generated/TCGA/{analysis}/bootstraps/{bootstrap}/plot_mutations_and_traits.pdf", analysis=ANALYSES.keys(), bootstrap=BOOTSTRAP_INDICES) + +rule download_clinical: + output: "data/TCGA/raw/clinical-information.tsv" + shell: "wget https://raw.githubusercontent.com/cbg-ethz/graphClust_NeurIPS/main/tcga_analysis/data/tcga-clinical-information.txt -O {output}" + + +rule download_mutation: + output: "data/TCGA/raw/mutation-matrix.tsv" + shell: "wget https://raw.githubusercontent.com/cbg-ethz/graphClust_NeurIPS/main/tcga_analysis/data/binary-mutationCovariate-matrix.txt -O {output}" + + +rule preprocess_data: + input: + clinical = "data/TCGA/raw/clinical-information.tsv", + mutation = "data/TCGA/raw/mutation-matrix.tsv" + output: + clinical = "generated/TCGA/{analysis}/preprocessed/clinical-information.csv", + mutation = "generated/TCGA/{analysis}/preprocessed/mutation-matrix.csv" + run: + # Read mutation matrix and remove binarised covariates + mutations = pd.read_csv(input.mutation, sep="\t", index_col=0) + mutations = mutations.drop( + ["Age", "Gender", "Stage"] + ["BLCA", "BRCA", "CESC", "COAD", "READ", "ESCA", "GBM", "HNSC", "KIRC", "KIRP", "LAML", "LGG", "LIHC", "LUAD", "LUSC", "OV", "PAAD", "PCPG", "PRAD", "SARC", "STAD", "THCA", "UCEC"], + axis="columns", + ) + + # Read clinical information and remove binarised (derived) information + clinical = pd.read_csv(input.clinical, sep="\t", index_col=0) + clinical = clinical.drop(["age.bin", "gender.bin", "stage.bin"], axis="columns") + + # Select cancers corresponding to selected tissue types + allowed_tissue_types = ANALYSES[wildcards.analysis].allowed_types + clinical = clinical[clinical["type"].isin(allowed_tissue_types)] + + # Align data frames + mutations, clinical = mutations.align(clinical, join="inner", axis=0) + + # Remove mutations which are constant for all patients + mutations = mutations.loc[:, mutations.nunique() > 1] + + # Add standardized age + 
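+        # (z-score: subtract the mean, divide by the standard deviation)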
clinical['std_age'] = (clinical['age'] - clinical['age'].mean()) / clinical['age'].std() + + mutations.to_csv(output.mutation, index=True) + clinical.to_csv(output.clinical, index=True) + + +rule bootstrap_data: + input: + clinical = "generated/TCGA/{analysis}/preprocessed/clinical-information.csv", + mutation = "generated/TCGA/{analysis}/preprocessed/mutation-matrix.csv" + output: + clinical = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/clinical-information.csv", + mutation = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/mutation-matrix.csv" + run: + clinical = pd.read_csv(input.clinical, index_col=0) + mutation = pd.read_csv(input.mutation, index_col=0) + assert len(clinical) == len(mutation), "Data frames are of different length" + + # Bootstrap samples from both data frames + rng = np.random.default_rng(int(wildcards.bootstrap)) + idx = rng.choice(np.arange(len(clinical)), size=len(clinical), replace=True) + bootstrap_clinical = clinical.iloc[idx] + bootstrap_mutation = mutation.iloc[idx] + + bootstrap_clinical.to_csv(output.clinical, index=True) + bootstrap_mutation.to_csv(output.mutation, index=True) + + +rule fit_pyramid: + input: + clinical = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/clinical-information.csv", + mutation = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/mutation-matrix.csv" + output: + posterior_samples = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/posterior_samples.nc", + observed_covariates = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/observed_covariates.csv" + run: + clinical = pd.read_csv(input.clinical, index_col=0) + mutations = pd.read_csv(input.mutation, index_col=0) + assert len(clinical) == len(mutations), "Data frames are of different length" + + specs = ANALYSES[wildcards.analysis] + + # Create design matrix, encoding observed covariates + observed_covariates = fm.model_matrix(specs.design_formula, clinical) + observed_covariates = observed_covariates.drop("Intercept", axis=1) + observed_covariates.to_csv(output.observed_covariates, index=True) + + design_matrix = observed_covariates.values + + dataset = ListDataset(thinning=specs.thinning, dimensions=TwoLayerPyramidSamplerNonparametric.dimensions()) + + sampler = TwoLayerPyramidSamplerNonparametric( + datasets=[dataset], + observed=mutations.values, + observed_covariates=design_matrix, + dirichlet_prior=np.ones(specs.n_clusters) / specs.n_clusters, + n_clusters=specs.n_clusters, + max_binary_codes=specs.max_binary_codes, + expected_binary_codes=specs.expected_binary_codes, + verbose=True, + warmup=specs.n_warmup, + steps=specs.n_steps, + inactive_latent_variance_theta_inf = 0.1**2, + intercept_prior_variance=1.0**2, + pseudoprior_variance=0.1**2, + seed=int(wildcards.bootstrap), + ) + sampler.run() + dataset.dataset.to_netcdf(output.posterior_samples) + + +rule extract_latents: + input: + posterior_samples = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/posterior_samples.nc", + output: + latent_traits = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/latent_traits.npz" + run: + dataset = xr.open_dataset(input.posterior_samples) + # Get means of latent traits. + latent_traits_means = dataset["latent_traits"].mean(axis=0).values + # Get variances of coefficients attributed to latent traits + variances = dataset["latent_variances"].mean(axis=0).values + + # Select the baseline category for each trait (i.e., what is 0 and what is 1). 
+ # We define the baseline category as the more prevalent one, so that the mean is smaller + latent_traits_ordered = np.zeros_like(latent_traits_means) + for k in range(latent_traits_means.shape[1]): + if latent_traits_means[:, k].mean() > 0.5: + latent_traits_ordered[:, k] = 1 - latent_traits_means[:, k] + else: + latent_traits_ordered[:, k] = latent_traits_means[:, k] + + # Now we need to remove "wrong" latent traits. + # By "wrong" we will understand the following: + # - It appears in too few patients. + # - It has very small variance (i.e., uncertainty of it for all patients is almost identical) + # - The variance of associated coefficients is too small. (I.e., it's inactive) + is_too_rare = np.mean(latent_traits_ordered, axis=0) < 0.05 + is_constant = np.std(latent_traits_ordered, axis=0) < 0.05 + has_zero_variance = variances < 0.1 + is_wrong = is_too_rare | is_constant | has_zero_variance + + latent_traits_ordered = latent_traits_ordered[:, ~is_wrong] + + # We have pre-selected some traits. + # However, their order may be arbitrary and inconsistent between bootstraps. + # Hence, let's order the traits from most prevalent to the least. + + order = latent_traits_ordered.mean(axis=0).argsort()[::-1] + latent_traits_ordered = latent_traits_ordered[:, order] + + np.savez( + output.latent_traits, + latent_traits=latent_traits_ordered, + is_too_rare=is_too_rare, + is_constant=is_constant, + order=order, + ) + +rule plot_structure: + input: + posterior_samples = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/posterior_samples.nc", + latent_traits = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/latent_traits.npz", + mutations = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/mutation-matrix.csv" + output: + inferred_coefficients = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/plot_inferred_coefficients.pdf", + mutations_and_traits = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/plot_mutations_and_traits.pdf" + run: + mutations = pd.read_csv(input.mutations, index_col=0) + samples = xr.open_dataset(input.posterior_samples) + storage = np.load(input.latent_traits) + + traits = storage["latent_traits"] + binarised = np.asarray(traits > 0.5, dtype=int) + + structure = (samples["coefficients_latent"] * samples["structure_latent"]).mean(axis=0).values + structure = structure[:, ~(storage["is_too_rare"] | storage["is_constant"])] + + def magic_index(arr): + idx = np.arange(len(arr)) + idx = sorted(idx, key=lambda i: "".join(map(str, arr[i]))) + return idx + + genes_order = np.argsort(mutations.mean(axis=0)) + patients_sorted = np.argsort(-mutations.mean(axis=1)) + + index = magic_index(binarised) + fig, axs = plt.subplots(3, 1, figsize=(8, 3), dpi=300, sharex=True) + + ax = axs[0] + sns.heatmap(mutations.values[patients_sorted, :].T[genes_order[-50:], :], square=False, ax=ax, cmap="Greys", cbar=False) + ax.set_yticks([]) + ax.set_xticks([]) + ax.set_ylabel("Genes") + ax.set_xlabel("Patients (ordered by number of mutations)") + + ax = axs[1] + sns.heatmap(mutations.values[index, :].T[genes_order[-50:], :], square=False, ax=ax, cmap="Greys", cbar=False) + ax.set_yticks([]) + ax.set_xticks([]) + ax.set_ylabel("Genes") + ax.set_xlabel("Patients (ordered by latent traits)") + + ax = axs[2] + if binarised.shape[1] > 0: + sns.heatmap(binarised[index, :].T, square=False, ax=ax, cmap="Blues", cbar=False) + ax.set_ylabel("Traits") + ax.set_yticks([]) + ax.set_xticks([]) + ax.set_xlabel("Patients (ordered by latent traits)") + + fig.tight_layout() + fig.savefig(output.mutations_and_traits) + 
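+
+        # Second figure: posterior-mean coefficients (already masked by the
+        # sampled structure above) for the 50 most mutated genes, with one
+        # row per retained latent trait.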
+ + fig, ax = plt.subplots(figsize=(8, 1.8), dpi=300) + if structure.shape[1] > 0: + sns.heatmap(structure[genes_order[-50:], :].T, cmap="bwr", center=0, vmin=-10, vmax=10, square=True, ax=ax, cbar=False, + xticklabels=genes_order[-50:].index) + ax.set_yticks([]) + ax.set_ylabel("Traits") + + fig.subplots_adjust(top=1.00, bottom=0.5) + fig.savefig(output.inferred_coefficients) + + +class CaptureStdout: + def __enter__(self): + self._original_stdout = sys.stdout + self._new_stdout = io.StringIO() + sys.stdout = self._new_stdout + return self._new_stdout + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout = self._original_stdout + + +def fit_cph_model( + design_matrix: pd.DataFrame, + summary_file: str, + coefficients_file: str, + penalizer: float, +) -> CoxPHFitter: + # Fit the model + cph = CoxPHFitter(penalizer=penalizer, l1_ratio=0.0) + cph.fit(design_matrix, duration_col='time', event_col='event') + + # Save dataframe with coefficients + cph.summary.to_csv(coefficients_file, index=True) + + # Save human-readable summary + with CaptureStdout() as captured: + cph.print_summary(decimals=3, style="ascii") + with open(summary_file, "w") as f: + f.write(captured.getvalue()) + # Return fitted model + return cph + +DESIGN_FORMULA = "event + time + gender + std_age + stage + type" + +def normalize_time(design_matrix): + time_original = design_matrix["time"].values.copy() + # design_matrix["time_original"] = time_original + design_matrix["time"] = 1e-6 + time_original / time_original.max() + +def get_simple_design_matrix(clinical: pd.DataFrame, filename: str) -> pd.DataFrame: + design_formula = DESIGN_FORMULA + design_matrix = fm.model_matrix(design_formula, clinical) + design_matrix = design_matrix.drop("Intercept", axis=1) + normalize_time(design_matrix) + + design_matrix.index = np.arange(len(design_matrix)) + design_matrix.to_csv(filename, index=True) + + return design_matrix + + +def get_extended_design_matrix(clinical: pd.DataFrame, latent_traits: np.ndarray, filename: str) -> pd.DataFrame: + assert len(clinical) == len(latent_traits), "Data sets are of different length" + + # We want to use the following values: + design_formula = DESIGN_FORMULA + + # If we have latent traits, we add them to the design matrix + traits_names = [f"Trait_{i}" for i in range(1, 1+latent_traits.shape[1])] + traits_df = pd.DataFrame(latent_traits, columns=traits_names, index=clinical.index) + + if len(traits_df.columns) > 0: + design_formula = DESIGN_FORMULA + " + " + "+".join(traits_df.columns) + + design_matrix = fm.model_matrix(design_formula, pd.concat([clinical, traits_df], axis=1)) + design_matrix = design_matrix.drop("Intercept", axis=1) + normalize_time(design_matrix) + + design_matrix.index = np.arange(len(design_matrix)) + design_matrix.to_csv(filename, index=True) + return design_matrix + + +def compute_p_value(extended: CoxPHFitter, restricted: CoxPHFitter) -> dict: + diff_df = len(extended.params_) - len(restricted.params_) + LR = 2 * (extended.log_likelihood_ - restricted.log_likelihood_) + if diff_df > 0: + p_value = chi2.sf(LR, diff_df) + else: + p_value = None + return { + "degrees_of_freedom_difference": diff_df, + "likelihood_ratio_statistic": LR, + "likelihood_ratio_p_value": p_value, + "p_value_is_numeric": p_value is not None, + } + + +def compute_c_index(extended: CoxPHFitter, restricted: CoxPHFitter) -> dict: + conc_restricted = restricted.concordance_index_ + conc_extended = extended.concordance_index_ + conc_improvement = conc_extended - conc_restricted + return { + 
"concordance_restricted": conc_restricted, + "concordance_extended": conc_extended, + "concordance_improvement": conc_improvement, + } + + +def compute_in_sample_calibration( + extended_model: CoxPHFitter, + extended_design: pd.DataFrame, + restricted_model: CoxPHFitter, + restricted_design: pd.DataFrame, + axs = None, +) -> dict: + if axs is None: + fig, axs = plt.subplots(1, 2) + + t0 = restricted_design["time"].max() + _ = extended_design["time"].max() + assert abs(t0 - _) < 0.01, "Time points are not the same" + + # Note that ICI and E50 are errors, i.e., lower is better + _, ici_restricted, e50_restricted = survival_probability_calibration(restricted_model, restricted_design.reset_index(), t0=t0, ax=axs[0]) + _, ici_extended, e50_extended = survival_probability_calibration(extended_model, extended_design.reset_index(), t0=t0, ax=axs[1]) + + ici_improvement = ici_restricted - ici_extended + e50_improvment = e50_restricted - e50_extended + return { + "ici_restricted": ici_restricted, + "ici_extended": ici_extended, + "ici_improvement": ici_improvement, + "e50_restricted": e50_restricted, + "e50_extended": e50_extended, + "e50_improvment": e50_improvment, + } + + +rule fit_survival: + input: + clinical = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/clinical-information.csv", + latent_traits = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/latent_traits.npz" + output: + extended_design_matrix = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/extended_design_matrix.csv", + extended_ascii_file = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/extended_summary.txt", + extended_survival_coef = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/extended_coef.csv", + restricted_design_matrix = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/restricted_design_matrix.csv", + restricted_ascii_file = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/restricted_summary.txt", + restricted_survival_coef = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/restricted_coef.csv", + difference_summary = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/difference_summary.json", + # residuals_arrays = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/residuals.npz", + residuals_plot = "generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/residuals.pdf" + run: + clinical = pd.read_csv(input.clinical, index_col=0) + latent_traits = np.load(input.latent_traits)["latent_traits"] + + restricted_design_matrix = get_simple_design_matrix(clinical, filename=output.restricted_design_matrix) + extended_design_matrix = get_extended_design_matrix(clinical=clinical, latent_traits=latent_traits, filename=output.extended_design_matrix) + + spec = ANALYSES[wildcards.analysis] + + cph_restricted = fit_cph_model( + design_matrix=restricted_design_matrix, + summary_file=output.restricted_ascii_file, + coefficients_file=output.restricted_survival_coef, + penalizer=spec.penalizer, + ) + + cph_extended = fit_cph_model( + design_matrix=extended_design_matrix, + summary_file=output.extended_ascii_file, + coefficients_file=output.extended_survival_coef, + penalizer=spec.penalizer, + ) + + basic_info = {"n_latent_traits": latent_traits.shape[1], "n_points": len(clinical)} + + fig, axs = plt.subplots(1, 2, figsize=(8, 4)) + + difference_dict = { + **basic_info, + **compute_p_value(extended=cph_extended, restricted=cph_restricted), + **compute_c_index(extended=cph_extended, restricted=cph_restricted), + **compute_in_sample_calibration( + 
extended_model=cph_extended, + extended_design=extended_design_matrix, + restricted_model=cph_restricted, + restricted_design=restricted_design_matrix, + axs=(axs[0], axs[1]), + ), + } + + fig.tight_layout() + fig.savefig(output.residuals_plot) + + with open(output.difference_summary, "w") as f: + json.dump(difference_dict, f) + + +rule plot_survival_difference: + input: + lambda wildcards: expand("generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/difference_summary.json", bootstrap=BOOTSTRAP_INDICES, analysis=wildcards.analysis) + output: + survival_plot = "generated/TCGA/{analysis}/summary/survival/plot.pdf", + assembled_csv = "generated/TCGA/{analysis}/summary/survival/assembled.csv" + run: + _tmp_lst = [] + for sample_path in input: + with open(sample_path) as f: + d = json.load(f) + _tmp_lst.append(d) + + assembled_df = pd.DataFrame(_tmp_lst) + assembled_df.to_csv(output.assembled_csv, index=False) + + original_len = len(assembled_df) + + # Remove samples with no latent traits + assembled_df = assembled_df.loc[assembled_df["n_latent_traits"] > 0] + new_len = len(assembled_df) + + if new_len < original_len: + print(f"Removed {original_len - new_len} samples with no latent traits") + + if new_len == 0: + print("No samples with latent traits") + raise Exception("No samples with latent traits") + + fig, axs = plt.subplots(1, 4, figsize=(16, 4), sharey=True) + fig.suptitle(f"Survival analysis summary, $N={len(assembled_df)}$") + + ax = axs[0] + ax.set_ylabel("Counts") + ax.set_xlabel("Log $p$-value") + ax.hist(np.log10(assembled_df["likelihood_ratio_p_value"].values), color="blue", alpha=0.5, bins=[-12, -10, -5, -3, np.log10(0.05), 0]) + ax.axvline(np.log10(0.05), color="black", linestyle=":", linewidth=1, alpha=0.8) + + def _plot_hist(ax, vals): + ax.hist(vals, bins=10, color="blue", alpha=0.5) + ax.axvline(0, color="black", linestyle=":", linewidth=1, alpha=0.8) + ax.axvline(np.median(vals), color="orangered", linestyle="-", linewidth=2) + + ax = axs[1] + ax.set_xlabel("$c_+$") + _plot_hist(ax, assembled_df["concordance_improvement"].values) + + ax = axs[2] + ax.set_xlabel("$\\mathrm{ICI}_+$") + _plot_hist(ax, assembled_df["ici_improvement"].values) + + ax = axs[3] + ax.set_xlabel("$\\mathrm{E50}_+$") + _plot_hist(ax, assembled_df["e50_improvment"].values) + + fig.tight_layout() + fig.savefig(output.survival_plot) + + +rule plot_latent_traits: + input: + lambda wildcards: expand("generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/difference_summary.json", bootstrap=BOOTSTRAP_INDICES, analysis=wildcards.analysis) + output: + latent_traits_plot = "generated/TCGA/{analysis}/summary/latent_traits.pdf" + run: + _tmp_lst = [] + for sample_path in input: + with open(sample_path) as f: + d = json.load(f) + _tmp_lst.append(d) + + assembled_df = pd.DataFrame(_tmp_lst) + + fig, ax = plt.subplots() + ax.set_xlabel("Number of latent traits") + ax.set_ylabel("Counts") + ax.hist(assembled_df["n_latent_traits"].values, bins=np.arange(-0.5, 9.5, 1), color="blue", alpha=0.5) + ax.set_xticks(np.arange(0, 10)) + + fig.tight_layout() + fig.savefig(output.latent_traits_plot) + + +COLORMAP = { + "Age": "darkgreen", + "Type": "maroon", + "Gender": "royalblue", + "Stage": "gold", + "Trait": "indigo", +} + +rule plot_effect_sizes_colormap: + output: "generated/TCGA/{analysis}/summary/coefficient_colormap.pdf" + run: + color_dict = COLORMAP + names = list(color_dict.keys()) + colors = [color_dict[name] for name in names] + + fig, ax = plt.subplots(figsize=(3, 2)) + + # Create a bar for each name 
with its associated color
+        for i, (name, color) in enumerate(color_dict.items()):
+            ax.barh(i, 1, color=color)
+
+        ax.set_xticks([])
+        ax.set_yticks(range(len(names)))
+        ax.set_yticklabels(names)
+
+        # Setting title and adjusting layout
+        ax.set_title("Coefficients color map")
+        ax.set_frame_on(False)
+        fig.tight_layout()
+        fig.savefig(str(output))
+
+
+def _color_of_attribute(attr) -> str:
+    if "gender" in attr:
+        return COLORMAP["Gender"]
+    elif "type" in attr:
+        return COLORMAP["Type"]
+    elif "stage" in attr:
+        return COLORMAP["Stage"]
+    elif "std_age" in attr:
+        return COLORMAP["Age"]
+    elif "Trait" in attr:
+        return COLORMAP["Trait"]
+    else:
+        raise ValueError(f"Attr {attr} not known")
+
+
+rule plot_effect_sizes:
+    input:
+        dataframes = lambda wildcards: expand("generated/TCGA/{analysis}/bootstraps/{bootstrap}/survival/extended_coef.csv", bootstrap=BOOTSTRAP_INDICES, analysis=wildcards.analysis),
+        colormap = "generated/TCGA/{analysis}/summary/coefficient_colormap.pdf"
+    output:
+        latent_traits_plot = "generated/TCGA/{analysis}/summary/effect_sizes.pdf"
+    run:
+        pths = sorted(input.dataframes)
+        n_bootstraps = len(pths)
+
+        fig, ax = plt.subplots(figsize=(3, 12))
+
+        ax.set_yticks([])
+        ax.set_ylim(0, 1)
+        ax.set_ylabel("Bootstrap sample")
+        ax.axvline(0, color="k", alpha=0.5, linestyle="--")
+        ax.set_xlabel("Coefficient value")
+
+        for h in np.linspace(0, 1, n_bootstraps + 1):
+            ax.axhline(h, color="k", alpha=0.1)
+
+        for i, pth in enumerate(pths):
+            df = pd.read_csv(pth)
+            n_attrs = len(df)
+            ys = i/n_bootstraps + np.linspace(0.001/n_bootstraps, 0.999/n_bootstraps, n_attrs + 2)[1:-1][::-1]
+
+            for j, row in df.iterrows():
+                x_mid = row["coef"]
+                x_lower = x_mid - row["coef lower 95%"]
+                x_upper = row["coef upper 95%"] - x_mid
+                color = _color_of_attribute(row["covariate"])
+                ax.errorbar(row["coef"], ys[j], xerr=[[x_lower], [x_upper]], color=color, marker="o", markersize=2)
+
+        fig.tight_layout()
+        fig.savefig(output.latent_traits_plot)
diff --git a/workflows/_benchmark_utils.py b/workflows/_benchmark_utils.py
new file mode 100644
index 0000000..098a7fe
--- /dev/null
+++ b/workflows/_benchmark_utils.py
@@ -0,0 +1,79 @@
+from collections import defaultdict
+
+import numpy as np
+
+
+def generate_dataset(n_points: int, seed: int) -> dict:
+    rng = np.random.default_rng(seed)
+
+    X = rng.normal(size=(n_points, 3))
+    A = rng.binomial(1, p=[0.1, 0.4, 0.1, 0.5], size=(n_points, 4))
+    index = np.argsort(["".join(map(str, a)) for a in A])
+    A = A[index, :]
+
+    states = np.hstack((A, X))
+    n_covariates = states.shape[1]
+
+    n_genes_per_covariate = 5
+    n_additional_genes = 3
+    n_genes = n_genes_per_covariate * n_covariates + n_additional_genes
+
+    coefs = np.zeros((n_genes, n_covariates))
+    effect_size = 4.0
+
+    for i in range(n_covariates):
+        coefs[i * n_genes_per_covariate : (i + 1) * n_genes_per_covariate, i] = (
+            effect_size
+        )
+
+    if n_additional_genes > 0:
+        coefs[-n_additional_genes:, :] = effect_size * rng.binomial(
+            1, 0.5, size=(n_additional_genes, n_covariates)
+        )
+
+    offset = -5
+    logits = offset + np.einsum("nf,gf->ng", states, coefs)
+    ps = 1 / (1 + np.exp(-logits))
+    Y = rng.binomial(1, ps)
+
+    return {
+        "Y": Y,
+        "X": X,
+        "A": A,
+        "coefficients_X": coefs[:, -3:],
+        "mutual_information": np.array(mutual_information(A, A)),
+    }
+
+
+def calculate_probabilities(samples):
+    counts = defaultdict(int)
+    total_samples = len(samples)
+
+    for s in samples:
+        counts[s] += 1
+
+    probabilities = {k: v / total_samples for k, v in counts.items()}
+    return probabilities
+
+
+def 
mutual_information(X_samples, Y_samples): + assert len(X_samples) == len(Y_samples), "Mismatched sample sizes" + + # Joint probabilities P(X, Y) + joint_samples = [(tuple(x), tuple(y)) for x, y in zip(X_samples, Y_samples)] + joint_probabilities = calculate_probabilities(joint_samples) + + # Marginal probabilities P(X) and P(Y) + X_probabilities = calculate_probabilities([tuple(x) for x in X_samples]) + Y_probabilities = calculate_probabilities([tuple(y) for y in Y_samples]) + + MI = 0 + + for (x, y), p_xy in joint_probabilities.items(): + p_x = X_probabilities.get(x, 0) + p_y = Y_probabilities.get(y, 0) + + if p_x > 0 and p_y > 0: + MI += p_xy * np.log2(p_xy / (p_x * p_y)) + + return MI diff --git a/workflows/benchmark.smk b/workflows/benchmark.smk new file mode 100644 index 0000000..e5e6d4a --- /dev/null +++ b/workflows/benchmark.smk @@ -0,0 +1,298 @@ +import nimfa +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns + +from jnotype.bmm import BernoulliMixtureGibbsSampler +from jnotype.pyramids import TwoLayerPyramidSampler +from jnotype.sampling import ListDataset + +import _benchmark_utils as utils + +matplotlib.use("Agg") + +METHODS = [ + "bmf", + "pyramids-unsupervised", + "pyramids-labeled", + "bmm", +] + +DATASETS = { + "250": 250, + "500": 500, + "1000": 1000, + "2000": 2000, +} +N_SEEDS: int = 10 + +COMPONENTS = [2, 4, 8] + +N_WARMUP: int = 3000 +N_STEPS: int = 1500 + + +workdir: "generated/benchmark/" + +rule all: + input: + arrays = expand("{method}-{components}/{dataset}/{seed}.npz", method=METHODS, components=COMPONENTS, dataset=DATASETS, seed=range(N_SEEDS)) + +rule assemble_results_mi: + input: expand("{method}-{components}/{dataset}/{seed}.npz", method=METHODS, components=COMPONENTS, dataset=DATASETS, seed=range(N_SEEDS)) + output: "results_mi.csv" + run: + results = [] + for method in METHODS: + for components in COMPONENTS: + for dataset in DATASETS: + for seed in range(N_SEEDS): + arrays = np.load(f"{method}-{components}/{dataset}/{seed}.npz") + results.append( + { + "method": method, + "components": components, + "dataset": dataset, + "seed": seed, + "mutual_information_gap": arrays["mutual_information_gap"], + } + ) + pd.DataFrame(results).to_csv(str(output), index=False) + +rule assemble_results_mse: + input: expand("pyramids-labeled-{components}/{dataset}/{seed}.npz", components=COMPONENTS, dataset=DATASETS, seed=range(N_SEEDS)) + output: "results_mse.csv" + run: + results = [] + for components in COMPONENTS: + for dataset in DATASETS: + for seed in range(N_SEEDS): + arrays = np.load(f"pyramids-labeled-{components}/{dataset}/{seed}.npz") + results.append( + { + "components": components, + "dataset": dataset, + "seed": seed, + "coefficients_X_mse": arrays["coefficients_X_mse"], + } + ) + pd.DataFrame(results).to_csv(str(output), index=False) + +rule plot_results_mse: + input: "results_mse.csv" + output: "plot_mse.pdf" + run: + df = pd.read_csv(str(input)) + df["Latent traits"] = df["components"] + + palette = { + 2: "teal", + 4: "deepskyblue", + 8: "blue", + } + + fig, ax = plt.subplots(figsize=(2.5, 2.5), dpi=350) + sns.boxplot(data=df, hue="Latent traits", x="dataset", y="coefficients_X_mse", ax=ax, palette=palette, linewidth=0.7, fliersize=0.5) + ax.set_xlabel("Number of patients") + ax.set_ylabel("Mean squared error") + ax.legend(frameon=False) + fig.tight_layout() + fig.savefig(str(output)) + + +rule plot_results_mi: + input: "results_mi.csv" + output: "plot_mi.pdf" + run: + def rename(method: str, 
components: int) -> str: + if method == "bmf": + return f"BMF ({components})" + elif method == "pyramids-unsupervised": + return f"UBP ({components})" + elif method == "pyramids-labeled": + return f"LBP ({components})" + elif method == "bmm": + return f"BMM ({components})" + else: + raise ValueError(f"{method} not recognized") + + df = pd.read_csv(str(input)) + df["algorithm"] = df.apply(lambda row: rename(row["method"], row["components"]), axis=1) + + color_list = [ + "dimgray", + "red", + "orange", + # + "greenyellow", + "green", + "aquamarine", + # + "teal", + "deepskyblue", + "blue", + # + "blueviolet", + "fuchsia", + "thistle", + ] + palette = dict(zip(df["algorithm"].unique(), color_list)) + + fig, ax = plt.subplots(figsize=(6, 2.5), dpi=350) + sns.boxplot(data=df, x="dataset", y="mutual_information_gap", hue="algorithm", palette=palette, ax=ax, linewidth=0.7, fliersize=0.5) + ax.legend(bbox_to_anchor=(1, 1), loc='upper left', frameon=False, ncol=2) + ax.set_ylabel("Mutual information gap") + ax.set_xlabel("Number of patients") + + fig.tight_layout() + fig.savefig(str(output)) + + +rule generate_ground_truth: + output: "ground_truth/{dataset}/{seed}.npz" + run: + n_points = DATASETS[wildcards.dataset] + seed = int(wildcards.seed) + dataset = utils.generate_dataset(n_points=n_points, seed=seed) + np.savez(str(output), **dataset) + + +rule run_bmf: + input: "ground_truth/{dataset}/{seed}.npz" + output: "bmf-{components}/{dataset}/{seed}.npz" + run: + arrays = np.load(str(input)) + Y = arrays["Y"] + n_components = int(wildcards.components) + bmf = nimfa.Bmf( + Y.T, + seed="nndsvd", + rank=n_components, + max_iter=100, + lambda_w=1.1, + lambda_h=1.1, + ) + bmf = bmf() + + features_continuous = np.asarray(bmf.coef()).T + features_discrete = features_continuous > 0.5 + + mi = utils.mutual_information(features_discrete, arrays["A"]) + np.savez( + str(output), + features_continuous=features_continuous, + features_discrete=features_discrete, + n_components=np.array(n_components), + mutual_information=np.array(mi), + mutual_information_gap=arrays["mutual_information"] - mi, + ) + + +rule run_unsupervised_pyramids: + input: "ground_truth/{dataset}/{seed}.npz" + output: "pyramids-unsupervised-{components}/{dataset}/{seed}.npz" + run: + arrays = np.load(str(input)) + Y = arrays["Y"] + n_components = int(wildcards.components) + + dataset = ListDataset(thinning=5, dimensions=TwoLayerPyramidSampler.dimensions()) + sampler = TwoLayerPyramidSampler( + datasets=[dataset], + observed=Y, + n_binary_codes=n_components, + n_clusters=2, + dirichlet_prior=np.ones(2) / 2, + warmup=N_WARMUP, + steps=N_STEPS, + verbose=False, + ) + sampler.run() + + latent_traits = (dataset.dataset["latent_traits"].mean(axis=0) > 0.5).values + mi = utils.mutual_information(latent_traits, arrays["A"]) + + np.savez( + str(output), + latent_traits=latent_traits, + n_components=np.array(n_components), + mutual_information=np.array(mi), + mutual_information_gap=arrays["mutual_information"] - mi, + ) + + +rule run_labeled_pyramids: + input: "ground_truth/{dataset}/{seed}.npz" + output: "pyramids-labeled-{components}/{dataset}/{seed}.npz" + run: + arrays = np.load(str(input)) + Y = arrays["Y"] + X = arrays["X"] + n_components = int(wildcards.components) + + dataset = ListDataset(thinning=5, dimensions=TwoLayerPyramidSampler.dimensions()) + + sampler = TwoLayerPyramidSampler( + datasets=[dataset], + observed=Y, + observed_covariates=X, + n_binary_codes=n_components, + n_clusters=2, + dirichlet_prior=np.ones(2) / 2, + warmup=N_WARMUP, + 
steps=N_STEPS, + verbose=False, + ) + sampler.run() + + latent_traits = (dataset.dataset["latent_traits"].mean(axis=0) > 0.5).values + mi = utils.mutual_information(latent_traits, arrays["A"]) + + coefs_inferred = dataset.dataset["coefficients_observed"].mean(axis=0).values + coefs_true = arrays["coefficients_X"] + coefs_mse = np.mean((coefs_inferred - coefs_true) ** 2) + + np.savez( + str(output), + latent_traits=latent_traits, + n_components=np.array(n_components), + mutual_information=np.array(mi), + mutual_information_gap=arrays["mutual_information"] - mi, + coefficients_X=coefs_inferred, + coefficients_X_mse=coefs_mse, + ) + + +rule run_bernoulli_mixture_model: + input: "ground_truth/{dataset}/{seed}.npz" + output: "bmm-{components}/{dataset}/{seed}.npz" + run: + arrays = np.load(str(input)) + Y = arrays["Y"] + n_components = int(wildcards.components) + + dataset = ListDataset(dimensions=BernoulliMixtureGibbsSampler.dimensions(), thinning=5) + sampler = BernoulliMixtureGibbsSampler( + datasets=[dataset], + observed_data=Y, + dirichlet_prior=np.ones(n_components), + warmup=N_WARMUP, + steps=N_STEPS, + ) + sampler.run() + + posterior_samples = dataset.dataset["labels"].values + labels = np.apply_along_axis(lambda x: np.argmax(np.bincount(x)), axis=0, arr=posterior_samples) + + mi = utils.mutual_information(labels[:, None], arrays["A"]) + + np.savez( + str(output), + posterior_samples=posterior_samples, + labels=labels, + n_components=np.array(n_components), + mutual_information=np.array(mi), + mutual_information_gap=arrays["mutual_information"] - mi, + ) diff --git a/workflows/false_positives.smk b/workflows/false_positives.smk new file mode 100644 index 0000000..5399db2 --- /dev/null +++ b/workflows/false_positives.smk @@ -0,0 +1,217 @@ +"""False positives experiment (no latent structure).""" +import dataclasses +from typing import Literal + +import matplotlib +matplotlib.use('agg') + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import xarray as xr + +import seaborn as sns + +from jnotype.pyramids import TwoLayerPyramidSampler, TwoLayerPyramidSamplerNonparametric +from jnotype.sampling import ListDataset + + +@dataclasses.dataclass +class Config: + # Observed factors + n_x: int + probs_x: np.ndarray + + # Number of samples + n_samples: int + + # Covariate matrix + n_genes_per_covariate: int = 5 + n_additional_genes: int = 8 + effect_size: float = 4.0 + + # MCMC setup + n_warmup: int = 4_000 + n_steps: int = 1_000 + + +CONFIGS = { + "small": Config( + n_x=3, + probs_x=np.array([0.55, 0.4, 0.25]), + n_samples=200, + ), + "medium": Config( + n_x=3, + probs_x=np.array([0.55, 0.4, 0.25]), + n_samples=1_000, + ), + "large": Config( + n_x=3, + probs_x=np.array([0.55, 0.4, 0.25]), + n_samples=5_000, + ) +} + +N_SEEDS: int = 20 + +workdir: "generated/false_positives" + +rule all: + input: expand("{analysis}/variances_summary-{threshold}.npz", analysis=CONFIGS.keys(), threshold=[0, 0.01, 0.03, 0.05]) + + +def construct_required_files(analysis): + config = CONFIGS[analysis] + return [f"{analysis}/pyramids/{seed}.nc" for seed in range(N_SEEDS)] + + +rule fit_all_pyramids: + output: touch("{analysis}/pyramids_fitted.done") + input: lambda wildcards: construct_required_files(wildcards.analysis) + + +def magic_sort(x: np.ndarray) -> np.ndarray: + """Reorders the samples in a binary matrix (n_samples, n_covariates), + sorting by the binary features. 
+ """ + def get_key(a: np.ndarray) -> str: + return "".join(map(str, a)) + idx = np.argsort(list(map(get_key, x))) + return x[idx, :] + + +rule generate_data: + output: + arrays = "{analysis}/data/{seed}.npz", + heatmap = "{analysis}/data/{seed}.pdf", + run: + config = CONFIGS[wildcards.analysis] + rng = np.random.default_rng(int(wildcards.seed)) + + # Generate (discrete) observed covariates + observed_covariates = rng.binomial(1, config.probs_x, size=(config.n_samples, config.n_x)) + # Reorder the samples + observed_covariates = magic_sort(observed_covariates) + + # Merge latent factors and observed covariates + true_characteristics = observed_covariates + + n_genes_per_covariate = config.n_genes_per_covariate + n_additional_genes = config.n_additional_genes + + n_all = config.n_x + n_genes = config.n_genes_per_covariate * n_all + n_additional_genes + coefs = np.zeros((n_genes, n_all)) + effect_size = config.effect_size + + for i in range(n_all): + coefs[i*n_genes_per_covariate:(i+1)*n_genes_per_covariate, i] = effect_size + + if n_additional_genes > 0: + coefs[-n_additional_genes:, :] = effect_size * (-1) ** rng.binomial(1, 0.5, size=(n_additional_genes, n_all)) + + offset = -5 + logits = offset + np.einsum("nf,gf->ng", true_characteristics, coefs) + ps = 1/(1 + np.exp(-logits)) + Y = rng.binomial(1, ps) + + np.savez( + output.arrays, + X=observed_covariates, + Y=Y, + coefs=coefs, + ) + + # Save figures + fig, axs = plt.subplots(1, 3) + + ax = axs[0] + sns.heatmap(true_characteristics, cmap="Blues", vmin=0, vmax=1, ax=ax, cbar=False) + ax.set_title("Patient traits $X^*$") + + ax = axs[1] + sns.heatmap(coefs, cmap="bwr", center=0, ax=ax) + ax.set_title("True coefficients") + + ax = axs[2] + sns.heatmap(Y, cmap="Blues", vmin=0, vmax=1, ax=ax, cbar=False) + ax.set_title("Gene mutations $Y$") + + fig.tight_layout() + fig.savefig(output.heatmap) + + +rule fit_pyramid: + input: "{analysis}/data/{seed}.npz" + output: + pyramid_samples="{analysis}/pyramids/{seed}.nc", + run: + input_arrays = np.load(input[0]) + X = input_arrays["X"] + Y = input_arrays["Y"] + + config = CONFIGS[wildcards.analysis] + + # Now, fit the pyramid + dataset = ListDataset(thinning=5, dimensions=TwoLayerPyramidSamplerNonparametric.dimensions()) + + sampler = TwoLayerPyramidSamplerNonparametric( + datasets=[dataset], + observed=Y, + observed_covariates=X, + dirichlet_prior=np.ones(2) / 2, + max_binary_codes=10, + expected_binary_codes=5.0, + n_clusters=2, + verbose=True, + warmup=config.n_warmup, + steps=config.n_steps, + inactive_latent_variance_theta_inf = 0.1**2, + mixing_beta_prior=(1.0, 5.0), + ) + sampler.run() + + dataset.dataset.to_netcdf(output.pyramid_samples) + + +rule calculate_variances: + input: + lambda wildcards: [f"{wildcards.analysis}/pyramids/{seed}.nc" for seed in range(N_SEEDS)] + output: + variances="{analysis}/variances_summary-{threshold}.npz" + run: + latent_variances = [] + observed_variances = [] + for inp_path in input: + samples = xr.open_dataset(inp_path) + + latent_traits_probs = samples["latent_traits"].mean(axis=0).values + # Now we need to remove "wrong" latent traits. + # By "wrong" we will understand the following: + # - It appears in too few patients. + # - It has very small variance (i.e., uncertainty of it for all patients is almost identical) + # - The variance of associated coefficients is too small. 
(i.e., it is inactive). Note: this rule applies only the first two criteria
+            #   and, instead of dropping traits, zeroes out their variances below.
+            threshold = float(wildcards.threshold)
+
+            is_too_rare = (np.mean(latent_traits_probs, axis=0) < threshold) | (np.mean(latent_traits_probs, axis=0) > 1 - threshold)
+            is_constant = np.std(latent_traits_probs, axis=0) < threshold
+            is_wrong = is_too_rare | is_constant
+
+            # The number of patients; identical for every seed of an analysis.
+            n_points = latent_traits_probs.shape[0]
+
+            lat = samples["latent_variances"].mean(axis=0).values
+            # Zero out the variances of the "wrong" latent traits.
+            lat = lat * (~is_wrong)
+
+            # Shape (n_observed_covariates,)
+            obs = samples["observed_variances"].mean(axis=0).values
+            latent_variances.append(lat)
+            observed_variances.append(obs)
+
+        np.savez(
+            output.variances,
+            latent_variances=np.asarray(latent_variances),
+            observed_variances=np.asarray(observed_variances),
+            threshold=np.asarray(threshold),
+            n_points=np.asarray(n_points),
+        )
diff --git a/workflows/misspecified.smk b/workflows/misspecified.smk
new file mode 100644
index 0000000..ef6492e
--- /dev/null
+++ b/workflows/misspecified.smk
@@ -0,0 +1,254 @@
+"""Misspecified model (continuous latent traits)."""
+import dataclasses
+from typing import Literal
+
+import matplotlib
+matplotlib.use('agg')
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+import seaborn as sns
+
+from jnotype.pyramids import TwoLayerPyramidSamplerNonparametric
+from jnotype.sampling import ListDataset
+
+
+@dataclasses.dataclass
+class Config:
+    # Latent factors
+    n_a: int
+    dist_a: Literal["uniform", "gaussian"]
+
+    # Observed factors
+    n_x: int
+    probs_x: np.ndarray
+
+    # Number of samples
+    n_samples: int
+
+    # Covariate matrix
+    n_genes_per_covariate: int = 5
+    n_additional_genes: int = 8
+    effect_size: float = 4.0
+
+    # MCMC setup
+    n_warmup: int = 4_000
+    n_steps: int = 1_000
+
+
+CONFIGS = {
+    "uniform": Config(
+        n_a=4,
+        dist_a="uniform",
+        n_x=3,
+        probs_x=np.array([0.55, 0.4, 0.25]),
+        n_samples=1_000,
+    ),
+    "gaussian": Config(
+        n_a=4,
+        dist_a="gaussian",
+        n_x=3,
+        probs_x=np.array([0.55, 0.4, 0.25]),
+        n_samples=1_000,
+    ),
+}
+
+N_SEEDS: int = 20
+
+workdir: "generated/misspecified"
+
+rule all:
+    # input: expand("{analysis}/pyramids_fitted.done", analysis=CONFIGS.keys())
+    input: expand("{analysis}/similarities/{n_known}-{seed}.npz", analysis=CONFIGS.keys(), seed=range(N_SEEDS), n_known=range(4))
+
+
+def construct_required_files(analysis):
+    config = CONFIGS[analysis]
+    return [f"{analysis}/pyramids/{n_known}-{seed}.nc" for seed in range(N_SEEDS) for n_known in range(config.n_x + 1)]
+
+
+rule fit_all_pyramids:
+    output: touch("{analysis}/pyramids_fitted.done")
+    input: lambda wildcards: construct_required_files(wildcards.analysis)
+
+
+def magic_sort(x: np.ndarray) -> np.ndarray:
+    """Reorders the samples in a binary matrix (n_samples, n_covariates),
+    sorting by the binary features. 
+ """ + def get_key(a: np.ndarray) -> str: + return "".join(map(str, a)) + idx = np.argsort(list(map(get_key, x))) + return x[idx, :] + + +rule generate_data: + output: + arrays = "{analysis}/data/{seed}.npz", + heatmap = "{analysis}/data/{seed}.pdf", + run: + config = CONFIGS[wildcards.analysis] + rng = np.random.default_rng(int(wildcards.seed)) + + # Generate (continuous) + n_a = config.n_a + if config.dist_a == "uniform": + latent_traits = rng.uniform(0, 1, size=(config.n_samples, n_a)) + elif config.dist_a == "gaussian": + latent_traits = rng.normal(0, 1, size=(config.n_samples, n_a)) + else: + raise ValueError(f"Unknown distribution {config.dist_a}") + + # Generate (discrete) observed covariates + observed_covariates = rng.binomial(1, config.probs_x, size=(config.n_samples, config.n_x)) + # Reorder the samples + observed_covariates = magic_sort(observed_covariates) + + # Merge latent factors and observed covariates + true_characteristics = np.hstack((latent_traits, observed_covariates)) + + n_genes_per_covariate = config.n_genes_per_covariate + n_additional_genes = config.n_additional_genes + + n_all = config.n_a + config.n_x + n_genes = config.n_genes_per_covariate * n_all + n_additional_genes + coefs = np.zeros((n_genes, n_all)) + effect_size = config.effect_size + + for i in range(n_all): + coefs[i*n_genes_per_covariate:(i+1)*n_genes_per_covariate, i] = effect_size + + if n_additional_genes > 0: + coefs[-n_additional_genes:, :] = effect_size * (-1) ** rng.binomial(1, 0.5, size=(n_additional_genes, n_all)) + + offset = -5 + logits = offset + np.einsum("nf,gf->ng", true_characteristics, coefs) + ps = 1/(1 + np.exp(-logits)) + Y = rng.binomial(1, ps) + + np.savez( + output.arrays, + A=latent_traits, + X=observed_covariates, + Y=Y, + coefs=coefs, + ) + + # Save figures + fig, axs = plt.subplots(1, 3) + + ax = axs[0] + sns.heatmap(true_characteristics, cmap="Blues", vmin=0, vmax=1, ax=ax, cbar=False) + ax.set_title("Patient traits $(A^*, X^*)$") + + ax = axs[1] + sns.heatmap(coefs, cmap="bwr", center=0, ax=ax) + ax.set_title("True coefficients") + + ax = axs[2] + sns.heatmap(Y, cmap="Blues", vmin=0, vmax=1, ax=ax, cbar=False) + ax.set_title("Gene mutations Y") + + fig.tight_layout() + fig.savefig(output.heatmap) + + +rule fit_pyramid: + input: "{analysis}/data/{seed}.npz" + output: + pyramid_samples="{analysis}/pyramids/{n_known}-{seed}.nc", + run: + input_arrays = np.load(input[0]) + X = input_arrays["X"] + Y = input_arrays["Y"] + + # We fit the pyramid using only selected covariates + n_known = int(wildcards.n_known) + if n_known == 0: + covariates = np.zeros((X.shape[0], 1)) + else: + covariates = X[:, :n_known].astype(np.float32) + + config = CONFIGS[wildcards.analysis] + + # Now, fit the pyramid + dataset = ListDataset(thinning=5, dimensions=TwoLayerPyramidSamplerNonparametric.dimensions()) + + sampler = TwoLayerPyramidSamplerNonparametric( + datasets=[dataset], + observed=Y, + observed_covariates=covariates, + dirichlet_prior=np.ones(2) / 2, + max_binary_codes=10, + expected_binary_codes=5.0, + n_clusters=2, + verbose=True, + warmup=config.n_warmup, + steps=config.n_steps, + inactive_latent_variance_theta_inf = 0.1**2, + mixing_beta_prior=(1.0, 5.0), + ) + sampler.run() + + dataset.dataset.to_netcdf(output.pyramid_samples) + + +rule calculate_similarities: + input: + pyramid_samples="{analysis}/pyramids/{n_known}-{seed}.nc", + data="{analysis}/data/{seed}.npz" + output: + similarities="{analysis}/similarities/{n_known}-{seed}.npz" + run: + samples = 
xr.open_dataset(input.pyramid_samples) + data = np.load(input.data) + + # Shape (n_latent_traits,) + latent_variances = samples["latent_variances"].mean(axis=0).values + # Shape (n_samples, n_latent_traits) + latent_traits_probs = samples["latent_traits"].mean(axis=0).values + + # Now we need to remove "wrong" latent traits. + # By "wrong" we will understand the following: + # - It appears in too few patients. + # - It has very small variance (i.e., uncertainty of it for all patients is almost identical) + # - The variance of associated coefficients is too small. (I.e., it's inactive) + is_too_rare = (np.mean(latent_traits_probs, axis=0) < 0.01) | (np.mean(latent_traits_probs, axis=0) > 0.99) + is_constant = np.std(latent_traits_probs, axis=0) < 0.01 + has_zero_variance = latent_variances < 0.05 + is_wrong = is_too_rare | is_constant | has_zero_variance + + latent_traits_probs = latent_traits_probs[:, ~is_wrong] + + # Now we have to calculate the correlations between the latent traits and the observed covariates + + def pearson_rho(a, b): + return pd.DataFrame({"a": a, "b": b}).corr(method="pearson").iloc[0, 1] + + def correlation_matrix(X, Q): + assert len(X) == len(Q) + F = X.shape[1] + K = Q.shape[1] + + arr = np.zeros((F, K)) + for f in range(F): + for k in range(K): + arr[f, k] = pearson_rho(X[:, f], Q[:, k]) + return arr + + def sort_abs(a): + return np.sort(np.abs(a), axis=1)[:, ::-1] + + def similarities(X, Q): + return sort_abs(correlation_matrix(X, Q)) + + np.savez( + output.similarities, + similarities_latent=similarities(data["A"], latent_traits_probs), + similarities_observed=similarities(data["X"], latent_traits_probs), + n_active=np.array(latent_traits_probs.shape[1]), + n_known=np.array(int(wildcards.n_known)), + ) diff --git a/workflows/more_observed.smk b/workflows/more_observed.smk new file mode 100644 index 0000000..394a5d4 --- /dev/null +++ b/workflows/more_observed.smk @@ -0,0 +1,245 @@ +"""More observed variables.""" +import dataclasses + +import matplotlib +matplotlib.use('agg') + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import xarray as xr + +import seaborn as sns + +from jnotype.pyramids import TwoLayerPyramidSampler, TwoLayerPyramidSamplerNonparametric +from jnotype.sampling import ListDataset + + +@dataclasses.dataclass +class Config: + # Latent factors + n_a: int + probs_a: np.ndarray + # Observed factors + n_x: int + probs_x: np.ndarray + + # Number of samples + n_samples: int + + # Covariate matrix + n_genes_per_covariate: int = 5 + n_additional_genes: int = 8 + effect_size: float = 4.0 + + # MCMC setup + n_warmup: int = 4_000 + n_steps: int = 1_000 + + +CONFIGS = { + "large_sample": Config( + n_a=4, + probs_a=np.array([0.55, 0.4, 0.25, 0.1]), + n_x=3, + probs_x=np.array([0.55, 0.4, 0.25]), + n_samples=1_000, + ), + "small_sample": Config( + n_a=4, + probs_a=np.array([0.55, 0.4, 0.25, 0.1]), + n_x=3, + probs_x=np.array([0.55, 0.4, 0.25]), + n_samples=200, + ), +} + +N_SEEDS: int = 20 + +workdir: "generated/more_observed/" + +rule all: + # input: expand("{analysis}/pyramids_fitted.done", analysis=CONFIGS.keys()) + input: expand("{analysis}/similarities/{n_known}-{seed}.npz", analysis=CONFIGS.keys(), seed=range(N_SEEDS), n_known=range(4)) + + +def construct_required_files(analysis): + config = CONFIGS[analysis] + return [f"{analysis}/pyramids/{n_known}-{seed}.nc" for seed in range(N_SEEDS) for n_known in range(config.n_x + 1)] + + +rule fit_all_pyramids: + output: touch("{analysis}/pyramids_fitted.done") + input: lambda 
wildcards: construct_required_files(wildcards.analysis) + + +def magic_sort(x: np.ndarray) -> np.ndarray: + """Reorders the samples in a binary matrix (n_samples, n_covariates), + sorting by the binary features. + """ + def get_key(a: np.ndarray) -> str: + return "".join(map(str, a)) + idx = np.argsort(list(map(get_key, x))) + return x[idx, :] + + +rule generate_data: + output: + arrays = "{analysis}/data/{seed}.npz", + heatmap = "{analysis}/data/{seed}.pdf", + run: + config = CONFIGS[wildcards.analysis] + rng = np.random.default_rng(int(wildcards.seed)) + + probs_all = np.concatenate([config.probs_a, config.probs_x]) + n_all = len(probs_all) + + true_characteristics = rng.binomial(1, probs_all, size=(config.n_samples, n_all)) + # Reorder the samples + true_characteristics = magic_sort(true_characteristics) + + # Split into latent traits and observed covariates + latent_traits = true_characteristics[:, :config.n_a] + observed_covariates = true_characteristics[:, config.n_a:] + + n_genes_per_covariate = config.n_genes_per_covariate + n_additional_genes = config.n_additional_genes + + n_genes = config.n_genes_per_covariate * n_all + n_additional_genes + coefs = np.zeros((n_genes, n_all)) + effect_size = config.effect_size + + for i in range(n_all): + coefs[i*n_genes_per_covariate:(i+1)*n_genes_per_covariate, i] = effect_size + + if n_additional_genes > 0: + coefs[-n_additional_genes:, :] = effect_size * (-1) ** rng.binomial(1, 0.5, size=(n_additional_genes, n_all)) + + offset = -5 + logits = offset + np.einsum("nf,gf->ng", true_characteristics, coefs) + ps = 1/(1 + np.exp(-logits)) + Y = rng.binomial(1, ps) + + np.savez( + output.arrays, + A=latent_traits, + X=observed_covariates, + Y=Y, + coefs=coefs, + ) + + # Save figures + fig, axs = plt.subplots(1, 3) + + ax = axs[0] + sns.heatmap(true_characteristics, cmap="Blues", vmin=0, vmax=1, ax=ax, cbar=False) + ax.set_title("Patient traits $(A^*, X^*)$") + + ax = axs[1] + sns.heatmap(coefs, cmap="bwr", center=0, ax=ax) + ax.set_title("True coefficients") + + ax = axs[2] + sns.heatmap(Y, cmap="Blues", vmin=0, vmax=1, ax=ax, cbar=False) + ax.set_title("Gene mutations Y") + + fig.tight_layout() + fig.savefig(output.heatmap) + + +rule fit_pyramid: + input: "{analysis}/data/{seed}.npz" + output: + pyramid_samples="{analysis}/pyramids/{n_known}-{seed}.nc", + run: + input_arrays = np.load(input[0]) + X = input_arrays["X"] + Y = input_arrays["Y"] + + # We fit the pyramid using only selected covariates + n_known = int(wildcards.n_known) + if n_known == 0: + covariates = np.zeros((X.shape[0], 1)) + else: + covariates = X[:, :n_known].astype(np.float32) + + config = CONFIGS[wildcards.analysis] + + # Now, fit the pyramid + dataset = ListDataset(thinning=5, dimensions=TwoLayerPyramidSamplerNonparametric.dimensions()) + + sampler = TwoLayerPyramidSamplerNonparametric( + datasets=[dataset], + observed=Y, + observed_covariates=covariates, + dirichlet_prior=np.ones(2) / 2, + max_binary_codes=10, + expected_binary_codes=5.0, + n_clusters=2, + verbose=True, + warmup=config.n_warmup, + steps=config.n_steps, + inactive_latent_variance_theta_inf = 0.1**2, + mixing_beta_prior=(1.0, 5.0), + ) + sampler.run() + + dataset.dataset.to_netcdf(output.pyramid_samples) + + +rule calculate_similarities: + input: + pyramid_samples="{analysis}/pyramids/{n_known}-{seed}.nc", + data="{analysis}/data/{seed}.npz" + output: + similarities="{analysis}/similarities/{n_known}-{seed}.npz" + run: + samples = xr.open_dataset(input.pyramid_samples) + data = np.load(input.data) + + # Shape 
(n_latent_traits,)
+        latent_variances = samples["latent_variances"].mean(axis=0).values
+        # Shape (n_samples, n_latent_traits)
+        latent_traits_probs = samples["latent_traits"].mean(axis=0).values
+
+        # Now we need to remove "wrong" latent traits, i.e., traits which:
+        # - appear in almost no (or almost all) patients,
+        # - have almost constant probability across patients (i.e., carry no
+        #   patient-specific information), or
+        # - have their associated coefficient variance shrunk to almost zero
+        #   (i.e., the trait is inactive).
+        is_too_rare = (np.mean(latent_traits_probs, axis=0) < 0.01) | (np.mean(latent_traits_probs, axis=0) > 0.99)
+        is_constant = np.std(latent_traits_probs, axis=0) < 0.01
+        has_zero_variance = latent_variances < 0.05
+        is_wrong = is_too_rare | is_constant | has_zero_variance
+
+        latent_traits_probs = latent_traits_probs[:, ~is_wrong]
+
+        # Calculate the correlations between the latent traits and the observed
+        # covariates (a vectorized alternative is sketched after this patch).
+        def pearson_rho(a, b):
+            return pd.DataFrame({"a": a, "b": b}).corr(method="pearson").iloc[0, 1]
+
+        def correlation_matrix(X, Q):
+            assert len(X) == len(Q)
+            F = X.shape[1]
+            K = Q.shape[1]
+
+            arr = np.zeros((F, K))
+            for f in range(F):
+                for k in range(K):
+                    arr[f, k] = pearson_rho(X[:, f], Q[:, k])
+            return arr
+
+        def sort_abs(a):
+            return np.sort(np.abs(a), axis=1)[:, ::-1]
+
+        def similarities(X, Q):
+            return sort_abs(correlation_matrix(X, Q))
+
+        np.savez(
+            output.similarities,
+            similarities_latent=similarities(data["A"], latent_traits_probs),
+            similarities_observed=similarities(data["X"], latent_traits_probs),
+            n_active=np.array(latent_traits_probs.shape[1]),
+            n_known=np.array(int(wildcards.n_known)),
+        )
diff --git a/workflows/simple_tables.smk b/workflows/simple_tables.smk
new file mode 100644
index 0000000..df9ef95
--- /dev/null
+++ b/workflows/simple_tables.smk
@@ -0,0 +1,32 @@
+"""Simple plots and tables for illustrative purposes."""
+import matplotlib
+matplotlib.use('agg')
+
+import numpy as np
+import pandas as pd
+
+
+workdir: "generated/simple_tables/"
+
+rule all:
+    input: "identifiability_Bernoulli_mixture.txt"
+
+
+rule identifiability_bernoulli_mixture:
+    output:
+        text_file = "identifiability_Bernoulli_mixture.txt",
+        latex_file = "identifiability_Bernoulli_mixture.tex"
+    run:
+        b = np.arange(2, 10)
+        # Number of binary features sufficient for (generic) identifiability
+        # of a mixture of b product-Bernoulli components: 2*ceil(log2(b)) + 1.
+        k = 2 * np.ceil(np.log2(b)) + 1
+        k = k.astype(int)
+
+        df = pd.DataFrame({
+            "Mixture components": b,
+            "Required features": k,
+        })
+
+        df.to_latex(output.latex_file, index=False, index_names=False)
+        df.to_csv(output.text_file, sep="\t", index=False)
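
A minimal sanity check for the plug-in mutual-information estimator in workflows/_benchmark_utils.py; this is a sketch rather than part of the patch, and it assumes it is run from the workflows/ directory so that _benchmark_utils is importable.

import numpy as np

import _benchmark_utils as utils

rng = np.random.default_rng(0)
A = rng.binomial(1, 0.3, size=(10_000, 2))

# MI of a sample with itself is the plug-in entropy H(A); the benchmarks'
# "mutual_information_gap" is therefore H(A) minus the MI achieved by a method.
mi_self = utils.mutual_information(A, A)
probs = utils.calculate_probabilities([tuple(a) for a in A])
entropy = -sum(p * np.log2(p) for p in probs.values())
assert np.isclose(mi_self, entropy)

# Independent samples should share (approximately) zero information.
B = rng.binomial(1, 0.3, size=(10_000, 2))
print(utils.mutual_information(A, B))  # close to 0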
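The latent-trait filtering block (is_too_rare / is_constant / has_zero_variance) is duplicated verbatim in misspecified.smk and more_observed.smk, with a thresholded variant in false_positives.smk. Below is a sketch of a shared helper the workflows could import; the function name, module placement, and default thresholds are hypothetical, with the defaults copied from the two calculate_similarities rules.

import numpy as np


def filter_active_traits(
    latent_traits_probs: np.ndarray,  # shape (n_samples, n_latent_traits)
    latent_variances: np.ndarray,     # shape (n_latent_traits,)
    freq_threshold: float = 0.01,
    std_threshold: float = 0.01,
    variance_threshold: float = 0.05,
) -> np.ndarray:
    """Removes 'wrong' latent traits: those active in almost no (or almost
    all) patients, those whose posterior probability is essentially constant
    across patients, and those whose associated coefficient variance has been
    shrunk to near zero (inactive traits)."""
    mean_probs = latent_traits_probs.mean(axis=0)
    is_too_rare = (mean_probs < freq_threshold) | (mean_probs > 1 - freq_threshold)
    is_constant = latent_traits_probs.std(axis=0) < std_threshold
    has_zero_variance = latent_variances < variance_threshold
    is_wrong = is_too_rare | is_constant | has_zero_variance
    return latent_traits_probs[:, ~is_wrong]

With such a helper, both calculate_similarities rules would reduce to a single call, keeping the thresholds defined in one place.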
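The nested Pearson loop in the calculate_similarities rules builds a pandas DataFrame per pair of columns; an equivalent vectorized sketch using np.corrcoef is shown below (same values up to floating-point error, assuming no missing entries; function names are hypothetical).

import numpy as np


def correlation_matrix_fast(X: np.ndarray, Q: np.ndarray) -> np.ndarray:
    """Pearson correlations between the columns of X (n, F) and Q (n, K)."""
    assert len(X) == len(Q)
    F = X.shape[1]
    # With rowvar=False, variables are columns; the full matrix has shape
    # (F + K, F + K), and its top-right block holds the X-vs-Q correlations.
    full = np.corrcoef(X, Q, rowvar=False)
    return full[:F, F:]


def similarities_fast(X: np.ndarray, Q: np.ndarray) -> np.ndarray:
    """Absolute correlations, sorted in decreasing order per covariate."""
    return np.sort(np.abs(correlation_matrix_fast(X, Q)), axis=1)[:, ::-1]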