Skip to content

Commit

Permalink
Nonparametric variant of labeled Bayesian pyramids (#22)
Browse files Browse the repository at this point in the history
* Prototype of the nonparametric pyramids

* Add to public API

* Add JIT

* WIP: Add workflows for the figures.

* Add visualisation

* Remove cancer type annotations

* Fix wrong argument

* Remove unnecessary import

* Run formatting

* Run formatting

* Run formatting

* Run formatting

* Update random key annotation

* Fix minor error

* Fix minor error
  • Loading branch information
pawel-czyz authored Feb 23, 2024
1 parent 8c0ad8b commit 86c3bfc
Show file tree
Hide file tree
Showing 37 changed files with 2,293 additions and 40 deletions.
15 changes: 8 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
repos:
- repo: https://github.com/python-poetry/poetry
rev: '1.3.2'
hooks:
- id: poetry-check
- id: poetry-lock
- id: poetry-export
args: ["-f", "requirements.txt", "-o", "requirements.txt"]
#- repo: https://github.com/python-poetry/poetry
#rev: '1.3.2'
#hooks:
#- id: poetry-check
#- id: poetry-lock
#- id: poetry-export
#args: ["-f", "requirements.txt", "-o", "requirements.txt"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.245'
hooks:
Expand All @@ -21,6 +21,7 @@ repos:
- id: interrogate
exclude: tests
args: [--config=pyproject.toml, --fail-under=95]
pass_filenames: false
#- repo: https://github.com/RobertCraigie/pyright-python
#rev: v1.1.296
#hooks:
Expand Down
5 changes: 1 addition & 4 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
"""PyTest unit tests configuration file."""

import dataclasses

import pytest

import matplotlib

matplotlib.use("Agg")


@dataclasses.dataclass
class TurnOnTestSuiteArgument:
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Exploratory analysis of binary data."""

import jnotype.bmm as bmm
import jnotype.datasets as datasets
import jnotype.sampling as sampling
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_csp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Cumulative shrinkage prior."""

import jax
import jax.numpy as jnp
import jax.scipy as jsp
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_factor_analysis/_gibbs_backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Sampling steps for all variables, apart from variances
attributed to latent traits, which are sampled with the CSP module."""

from typing import Callable

import jax
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_factor_analysis/_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sampling from the posterior distribution."""

from typing import Callable

from jaxtyping import Float, Array
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/_factor_analysis/_simulate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Simulate data sets."""

from jaxtyping import Float, Array

import jax.numpy as jnp
Expand Down
5 changes: 3 additions & 2 deletions src/jnotype/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This file should be as small as possible.
Themes that emerge here should be refactored and placed
into separate modules."""

import jax


Expand All @@ -16,15 +17,15 @@ class JAXRNG:
b = jax.random.bernoulli(rng.key, shape=(10,))
"""

def __init__(self, key: jax.random.PRNGKeyArray) -> None:
def __init__(self, key: jax.Array) -> None:
"""
Args:
key: initialization key
"""
self._key = key

@property
def key(self) -> jax.random.PRNGKeyArray:
def key(self) -> jax.Array:
"""Generates a new key."""
key, subkey = jax.random.split(self._key)
self._key = key
Expand Down
5 changes: 3 additions & 2 deletions src/jnotype/_variance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for sampling variances."""

from jax import random
import jax
import jax.numpy as jnp
Expand All @@ -7,7 +8,7 @@


def _sample_precisions(
key: random.PRNGKeyArray,
key: jax.Array,
values: Float[Array, "G F"],
mask: Int[Array, "G F"],
prior_shape: float,
Expand All @@ -34,7 +35,7 @@ def _sample_precisions(

@jax.jit
def sample_variances(
key: random.PRNGKeyArray,
key: jax.Array,
values: Float[Array, "G F"],
mask: Int[Array, "G F"],
prior_shape: float = 2.0,
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/bmm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Bernoulli Mixture Model."""

from jnotype.bmm._em import expectation_maximization
from jnotype.bmm._gibbs import (
sample_mixing,
Expand Down
11 changes: 8 additions & 3 deletions src/jnotype/bmm/_em.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The Expectation-Maximization algorithm for Bernoulli Mixture Model."""

import dataclasses
import time
from typing import Optional
Expand Down Expand Up @@ -68,7 +69,11 @@ def em_step(
observed: Int[Array, "N K"],
mixing: Float[Array, "K B"],
proportions: Float[Array, " B"],
) -> tuple[Float[Array, "N B"], Float[Array, "K B"], Float[Array, " B"],]:
) -> tuple[
Float[Array, "N B"],
Float[Array, "K B"],
Float[Array, " B"],
]:
"""The E and M step combined, for better JIT compiler optimisation.
Args:
Expand Down Expand Up @@ -160,7 +165,7 @@ def _init(
*,
n_features: int,
n_clusters: Optional[int], # type: ignore
key: Optional[jax.random.PRNGKeyArray],
key: Optional[jax.Array],
mixing_init: Optional[Float[Array, "K B"]],
proportions_init: Optional[Float[Array, " B"]],
) -> tuple[Float[Array, "K B"], Float[Array, " B"]]:
Expand Down Expand Up @@ -221,7 +226,7 @@ def expectation_maximization(
observed: Int[Array, "K B"],
*,
n_clusters: Optional[int] = None,
key: Optional[jax.random.PRNGKeyArray] = None,
key: Optional[jax.Array] = None,
mixing_init: Optional[Float[Array, "K B"]] = None,
proportions_init: Optional[Float[Array, " B"]] = None,
max_n_steps: int = 10_000,
Expand Down
11 changes: 6 additions & 5 deletions src/jnotype/bmm/_gibbs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sampling cluster labels and proportions."""

from typing import Optional, Sequence

import jax
Expand Down Expand Up @@ -41,7 +42,7 @@ def _bernoulli_loglikelihood(


def sample_cluster_labels(
key: random.PRNGKeyArray,
key: jax.Array,
*,
cluster_proportions: Float[Array, " B"],
mixing: Float[Array, "K B"],
Expand Down Expand Up @@ -94,7 +95,7 @@ def _calculate_counts(labels: Int[Array, " N"], n_clusters: int) -> Int[Array, "


def sample_cluster_proportions(
key: random.PRNGKeyArray,
key: jax.Array,
*,
labels: Int[Array, " N"],
dirichlet_prior: Float[Array, " B"],
Expand All @@ -117,7 +118,7 @@ def sample_cluster_proportions(


def sample_mixing(
key: random.PRNGKeyArray,
key: jax.Array,
*,
observations: Int[Array, "N K"],
labels: Int[Array, " N"],
Expand Down Expand Up @@ -148,7 +149,7 @@ def sample_mixing(
@jax.jit
def single_sampling_step(
*,
key: random.PRNGKeyArray,
key: jax.Array,
observed_data: Int[Array, "N K"],
proportions: Float[Array, " B"],
mixing: Float[Array, "K B"],
Expand Down Expand Up @@ -210,7 +211,7 @@ def __init__(
dirichlet_prior: Float[Array, " cluster"],
beta_prior: tuple[float, float] = (1.0, 1.0),
*,
jax_rng_key: Optional[jax.random.PRNGKeyArray] = None,
jax_rng_key: Optional[jax.Array] = None,
warmup: int = 2000,
steps: int = 3000,
verbose: bool = False,
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/checks/_histograms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Plotting histograms of data."""

from typing import Sequence, Union, Literal

import matplotlib.pyplot as plt
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Data sets."""

from jnotype.datasets._simulation import BlockImagesSampler

__all__ = [
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/datasets/_simulation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Simulated data sets."""

from jnotype.datasets._simulation._block_images import BlockImagesSampler

__all__ = [
Expand Down
4 changes: 3 additions & 1 deletion src/jnotype/datasets/_simulation/_block_images.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Simulation of binary images using Bernoulli mixture model."""

from typing import Optional

from jaxtyping import Array, Float, Int
import jax.numpy as jnp
import jax
from jax import random


Expand Down Expand Up @@ -139,7 +141,7 @@ def mixing(self) -> Float[Array, "n_classes 6 6"]:

def sample_dataset(
self,
key: random.PRNGKeyArray,
key: jax.Array,
n_samples: int,
*,
probs: Optional[Float[Array, " n_classes"]] = None,
Expand Down
3 changes: 2 additions & 1 deletion src/jnotype/logistic/_binary_latent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sample binary latent variables."""

from functools import partial

import jax
Expand All @@ -13,7 +14,7 @@
@partial(jax.jit, static_argnames="n_binary_codes")
def sample_binary_codes(
*,
key: random.PRNGKeyArray,
key: jax.Array,
intercepts: Float[Array, " features"],
coefficients: Float[Array, "features covs"],
structure: Int[Array, "features covs"],
Expand Down
13 changes: 7 additions & 6 deletions src/jnotype/logistic/_polyagamma.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Logistic regression sampling utilities using Pólya-Gamma augmentation."""

from jax import random
import jax
import jax.numpy as jnp
Expand All @@ -22,7 +23,7 @@ def _calculate_logits(

def _sample_coefficients(
*,
key: random.PRNGKeyArray,
key: jax.Array,
omega: Float[Array, "points features"],
covariates: Float[Array, "points covariates"],
structure: Int[Array, "features covariates"],
Expand Down Expand Up @@ -53,9 +54,9 @@ def _sample_coefficients(
precision_matrices: Float[Array, "features covariates covariates"] = jax.vmap(
jnp.diag
)(jnp.reciprocal(prior_variance))
posterior_covariances: Float[
Array, "features covariates covariates"
] = jnp.linalg.inv(x_omega_x + precision_matrices)
posterior_covariances: Float[Array, "features covariates covariates"] = (
jnp.linalg.inv(x_omega_x + precision_matrices)
)

kappa: Float[Array, "points features"] = jnp.asarray(observed, dtype=float) - 0.5

Expand All @@ -79,7 +80,7 @@ def _sample_coefficients(

def sample_coefficients(
*,
jax_key: random.PRNGKeyArray,
jax_key: jax.Array,
numpy_rng: np.random.Generator,
observed: Int[Array, "points features"],
design_matrix: Float[Array, "points covariates"],
Expand Down Expand Up @@ -214,7 +215,7 @@ def _augment_matrices(

def sample_intercepts_and_coefficients(
*,
jax_key: random.PRNGKeyArray,
jax_key: jax.Array,
numpy_rng: np.random.Generator,
observed: Int[Array, "points features"],
intercepts: Float[Array, " features"],
Expand Down
7 changes: 4 additions & 3 deletions src/jnotype/logistic/_structure.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Sample structure (spike/slab distinction) variables."""

from typing import Union

import jax
Expand Down Expand Up @@ -41,7 +42,7 @@ def _logpdf_gaussian(
@jax.jit
def sample_structure(
*,
key: random.PRNGKeyArray,
key: jax.Array,
intercepts: Float[Array, " features"],
coefficients: Float[Array, "features covs"],
structure: Int[Array, "features covs"],
Expand Down Expand Up @@ -158,7 +159,7 @@ def body_fun(

@jax.jit
def sample_gamma(
key: random.PRNGKeyArray,
key: jax.Array,
structure: Int[Array, "G K"],
prior_a: float = 1.0,
prior_b: float = 1.0,
Expand All @@ -177,7 +178,7 @@ def sample_gamma(

@jax.jit
def sample_gamma_individual(
key: random.PRNGKeyArray,
key: jax.Array,
structure: Int[Array, "G K"],
prior_a: Float[Array, " K"],
prior_b: Float[Array, " K"],
Expand Down
1 change: 1 addition & 0 deletions src/jnotype/logistic/logreg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Logistic regression utilities."""

import jax
import jax.numpy as jnp
from jaxtyping import Int, Float, Array
Expand Down
2 changes: 2 additions & 0 deletions src/jnotype/pyramids/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Gibbs samplers for Bayesian pyramids."""

from jnotype.pyramids._sampler_fixed import TwoLayerPyramidSampler
from jnotype.pyramids._sampler_csp import TwoLayerPyramidSamplerNonparametric

__all__ = [
"TwoLayerPyramidSampler",
"TwoLayerPyramidSamplerNonparametric",
]
Loading

0 comments on commit 86c3bfc

Please sign in to comment.