Utilities for training energy-based models #32

Merged · 6 commits · Sep 24, 2024
19 changes: 19 additions & 0 deletions src/jnotype/energy/__init__.py
@@ -0,0 +1,19 @@
"""Energy-based models."""

from jnotype.energy._prior import (
create_symmetric_interaction_matrix,
number_of_interactions_quadratic,
)
from jnotype.energy._dfd import (
discrete_fisher_divergence,
besag_pseudolikelihood_sum,
besag_pseudolikelihood_mean,
)

__all__ = [
"create_symmetric_interaction_matrix",
"number_of_interactions_quadratic",
"discrete_fisher_divergence",
"besag_pseudolikelihood_sum",
"besag_pseudolikelihood_mean",
]
102 changes: 102 additions & 0 deletions src/jnotype/energy/_dfd.py
@@ -0,0 +1,102 @@
"""Discrete Fisher divergence and Besag pseudo-log-likelihood losses."""

from functools import partial
from typing import Callable, Union

import jax
import jax.numpy as jnp
from jaxtyping import Float, Int, Array

Integer = Union[int, Int[Array, " "]]
DataPoint = Int[Array, " G"]
DataSet = Int[Array, "N G"]

LogProbFn = Callable[[DataPoint], Float[Array, " "]]


def bitflip(g: Integer, y: DataPoint) -> DataPoint:
    """Flips the bit at site `g`."""
    return y.at[g].set(1 - y[g])


def _dfd_onepoint(log_q: LogProbFn, y: DataPoint) -> Float[Array, " "]:
    """Calculates the discrete Fisher divergence on a single data point.

    Args:
        log_q: unnormalized log-probability function
        y: data point on which it should be evaluated
    """
    log_qy = log_q(y)

    def log_q_flip_fn(g: Integer):
        return log_q(bitflip(g, y))

    log_qflipped = jax.vmap(log_q_flip_fn)(jnp.arange(y.shape[0]))
    log_ratio = log_qflipped - log_qy
    # Per-site summand: (q(flip) / q(y))**2 - 2 * q(y) / q(flip).
    return jnp.sum(jnp.exp(2 * log_ratio) - 2 * jnp.exp(-log_ratio))


def discrete_fisher_divergence(log_q: LogProbFn, ys: DataSet) -> Float[Array, " "]:
    """Evaluates the discrete Fisher divergence between the model
    distribution and the empirical distribution.

    Note:
        When used in the generalised Bayesian inference framework,
        remember that the loss is multiplied by the data set size
        (and the temperature), i.e.,
        $$
        P(\\theta | data) \\propto P(\\theta) \\exp( -\\tau N \\cdot DFD )
        $$
    """
    f = partial(_dfd_onepoint, log_q)
    return jnp.mean(jax.vmap(f)(ys))
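
For illustration, a minimal usage sketch (the linear energy model `log_q` and the parameter values below are hypothetical, not part of this PR):

import jax.numpy as jnp
from jnotype.energy import discrete_fisher_divergence

theta = jnp.array([0.5, -1.0, 2.0])  # hypothetical per-site biases

def log_q(y):
    # Unnormalized log-probability of a linear (independent-bits) model.
    return jnp.dot(theta, y)

ys = jnp.array([[0, 1, 1], [1, 0, 1]])  # data set of shape (N, G) = (2, 3)
loss = discrete_fisher_divergence(log_q, ys)  # scalar DFD estimate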


def _besag_pseudolikelihood_onepoint(
    log_q: LogProbFn, y: DataPoint
) -> Float[Array, " "]:
    """Calculates Julian Besag's pseudo-log-likelihood on a single data point.

    Namely,
    $$
    \\log L = \\sum_g \\log P(Y_g = y_g | Y_{-g} = y_{-g})
            = \\sum_g [ \\log P(Y = y) - \\log( P(Y = y) + P(Y = bitflip(g, y)) ) ]
    $$
    """
    log_qy = log_q(y)

    def log_denominator(g: Integer):
        log_bitflipped = log_q(bitflip(g, y))
        return jnp.logaddexp(log_qy, log_bitflipped)

    log_denominators = jax.vmap(log_denominator)(jnp.arange(y.shape[0]))
    return jnp.sum(log_qy - log_denominators)


def besag_pseudolikelihood_sum(
    log_q: LogProbFn,
    ys: DataSet,
) -> Float[Array, " "]:
    """Besag pseudo-log-likelihood calculated over the whole data set.

    Note that the pseudo-log-likelihood is additive.
    """
    n_points = ys.shape[0]
    return n_points * besag_pseudolikelihood_mean(log_q, ys)


def besag_pseudolikelihood_mean(
    log_q: LogProbFn,
    ys: DataSet,
) -> Float[Array, " "]:
    """Average Besag pseudo-log-likelihood.

    Note:
        As the pseudo-log-likelihood is additive, for generalised
        Bayesian inference one should multiply it by the data set size.

    See Also:
        `discrete_fisher_divergence`, which also requires multiplication
        by the data set size.
    """
    f = partial(_besag_pseudolikelihood_onepoint, log_q)
    return jnp.mean(jax.vmap(f)(ys))
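
Following the note above, a generalised log-posterior might be assembled like this (a sketch only: the Gaussian prior, the linear model, and the name `make_generalised_logpost` are hypothetical, not part of this PR):

import jax
import jax.numpy as jnp
from jnotype.energy import besag_pseudolikelihood_mean

def make_generalised_logpost(ys, tau=1.0):
    n_points = ys.shape[0]  # rescale the mean loss by the data set size

    def logpost(theta):
        def log_q(y):
            return jnp.dot(theta, y)  # hypothetical linear energy model

        log_prior = -0.5 * jnp.sum(theta**2)  # standard Gaussian prior
        return log_prior + tau * n_points * besag_pseudolikelihood_mean(log_q, ys)

    return logpost

ys = jnp.array([[0, 1], [1, 1], [1, 0]])
gradient = jax.grad(make_generalised_logpost(ys))(jnp.zeros(2))  # for gradient-based samplers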
25 changes: 25 additions & 0 deletions src/jnotype/energy/_prior.py
@@ -0,0 +1,25 @@
"""Priors for interpretable energy-based models."""

import jax.numpy as jnp
from jaxtyping import Float, Array


def number_of_interactions_quadratic(G: int) -> int:
"""Number of interactions in a quadratic energy model,
namely G over 2."""
return G * (G - 1) // 2


def create_symmetric_interaction_matrix(
diagonal: Float[Array, " G"],
offdiagonal: Float[Array, " G*(G-1)//2"],
) -> Float[Array, " G G"]:
"""Generates a symmetric matrix out of one-dimensional
diagonal and offdiagonal entries."""
G = diagonal.shape[0]
S = jnp.zeros((G, G), dtype=diagonal.dtype)
i1, i2 = jnp.triu_indices(G, 1)

S = S.at[i1, i2].set(offdiagonal)
S = S + S.T
return jnp.diag(diagonal) + S
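
These two utilities compose naturally; for instance, a quadratic (Ising-like) unnormalized log-probability could be parameterized as follows (a sketch assuming the quadratic form `y @ S @ y` is the intended downstream use, which this PR does not pin down):

import jax.numpy as jnp
from jnotype.energy import (
    create_symmetric_interaction_matrix,
    number_of_interactions_quadratic,
)

G = 4
diagonal = jnp.zeros(G)
offdiagonal = jnp.ones(number_of_interactions_quadratic(G))  # G*(G-1)//2 = 6 entries
S = create_symmetric_interaction_matrix(diagonal, offdiagonal)

def log_q(y):
    return y @ S @ y  # quadratic energy; compatible with the losses above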
166 changes: 166 additions & 0 deletions src/jnotype/energy/_sampling.py
@@ -0,0 +1,166 @@
"""Markov chain Monte Carlo sampling of binary vectors from energy-based models."""

from typing import Callable

import jax
import jax.random as jrandom
import jax.numpy as jnp

from jaxtyping import Array, Int


def generate_all_binary_vectors(G: int) -> Int[Array, "2**G G"]:
"""Generates an array of shape (2**G, G) with all binary vectors of length G."""
return jnp.array(jnp.indices((2,) * G).reshape(G, -1).T)


_DataPoint = Int[Array, " G"]
_UnnormLogProb = Callable[
[_DataPoint], float
] # Unnormalised log-probability function maps a binary vector of shape (G,) to a float


def categorical_exact_sampling(
    key: jax.Array,
    n_samples: int,
    G: int,
    log_prob_fn: _UnnormLogProb,
) -> Int[Array, "n_samples G"]:
    """Samples from an energy-based model by constructing all
    `2**G` atoms of the categorical distribution.

    Note:
        G has to be small (e.g., 10), so that a memory overflow
        does not happen.
    """
    binary_vectors = generate_all_binary_vectors(G)

    log_probs = jax.vmap(log_prob_fn)(binary_vectors)  # Unnormalized log-probabilities

    indices = jrandom.categorical(key, log_probs, shape=(n_samples,))
    return binary_vectors[indices]
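
A quick usage sketch (the toy `log_q` is hypothetical; the function is imported from the private module because the `jnotype.energy` __init__ above does not re-export it):

import jax
import jax.numpy as jnp
from jnotype.energy._sampling import categorical_exact_sampling

def log_q(y):
    return -1.0 * jnp.sum(y)  # toy model favouring sparse vectors

samples = categorical_exact_sampling(
    key=jax.random.PRNGKey(0), n_samples=100, G=5, log_prob_fn=log_q
)
assert samples.shape == (100, 5)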


def _gibbs_bitflip(
    key: jax.Array,
    log_prob_fn: _UnnormLogProb,
    y: _DataPoint,
    idx: int,
) -> _DataPoint:
    """Applies a Gibbs sampling step to the bit at location `idx`."""
    # Calculate the probabilities for both choices of this bit
    y0 = y.at[idx].set(0)
    y1 = y.at[idx].set(1)

    ys = jnp.vstack((y0, y1))
    logits = jax.vmap(log_prob_fn)(ys)

    chosen_index = jrandom.categorical(key, logits=logits)
    return ys[chosen_index]


def _random_site_bitflip(
    key: jax.Array,
    log_prob_fn: _UnnormLogProb,
    y: _DataPoint,
) -> _DataPoint:
    """Resamples a single bit at a randomly chosen site."""
    G = y.shape[0]
    # Pick a random index to update
    key1, key2 = jrandom.split(key)
    idx = jrandom.randint(key1, shape=(), minval=0, maxval=G)

    return _gibbs_bitflip(key=key2, log_prob_fn=log_prob_fn, y=y, idx=idx)


def construct_random_bitflip_kernel(log_prob_fn: _UnnormLogProb):
    """Constructs a kernel resampling a single random site."""

    def kernel(key, y):
        """Kernel flipping a random bit."""
        return _random_site_bitflip(log_prob_fn=log_prob_fn, key=key, y=y)

    return jax.jit(kernel)


def construct_systematic_bitflip_kernel(log_prob_fn: _UnnormLogProb):
    """Constructs a kernel systematically resampling bits one after another."""

    def kernel(key, y):
        """Kernel systematically flipping all bits one after another."""

        def f(state, idx: int):
            """Auxiliary function performing a Gibbs bitflip at the specified
            site, folding the site index into the key."""
            subkey = jrandom.fold_in(key, idx)
            new_state = _gibbs_bitflip(
                key=subkey, log_prob_fn=log_prob_fn, y=state, idx=idx
            )
            return new_state, None

        new_state, _ = jax.lax.scan(f, y, jnp.arange(y.shape[0]))
        return new_state

    return kernel
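
Note that, unlike `construct_random_bitflip_kernel`, this constructor returns the kernel un-jitted, so a caller may want to wrap it; a sketch with a hypothetical toy model:

import jax
import jax.numpy as jnp
import jax.random as jrandom
from jnotype.energy._sampling import construct_systematic_bitflip_kernel

def log_q(y):
    return jnp.dot(jnp.ones(y.shape[0]), y)  # toy unnormalized log-probability

kernel = jax.jit(construct_systematic_bitflip_kernel(log_q))
y = kernel(jrandom.PRNGKey(42), jnp.zeros(10, dtype=jnp.int32))  # one full sweep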


def _gibbs_blockflip(
    key: jax.Array,
    log_prob_fn: _UnnormLogProb,
    y: _DataPoint,
    sites: Int[Array, " size"],
) -> _DataPoint:
    """Performs a blocked Gibbs update, jointly resampling
    all bits at `sites`.

    Note:
        This requires `2**len(sites)` evaluations of the
        log-probability (and a similar amount of memory),
        so not too many sites should be updated jointly.
    """
    # Generate all possible configurations for the block
    block_size = sites.shape[0]
    all_configs = generate_all_binary_vectors(block_size)

    # Generate (unnormalized) log-probs for all possible configurations:
    def logp(config):
        y_candidate = y.at[sites].set(config)
        return log_prob_fn(y_candidate)

    logits = jax.vmap(logp)(all_configs)

    # Select the new configuration
    new_block_idx = jax.random.categorical(key, logits=logits)
    new_config = all_configs[new_block_idx]

    return y.at[sites].set(new_config)


def construct_random_blockflip_kernel(log_prob_fn: _UnnormLogProb, block_size: int):
    """Constructs a kernel resampling a random block of size `block_size`.

    Note:
        This requires `2**block_size` evaluations of the log-probability
        function (and a similar amount of memory), so only relatively
        small blocks can be used.
    """

    def kernel(key, y):
        """Kernel jointly resampling a random block of bits."""
        G = y.shape[0]

        key1, key2 = jrandom.split(key)
        sites = jrandom.choice(
            key1,
            G,
            shape=(block_size,),
            replace=False,
        )
        return _gibbs_blockflip(
            key=key2,  # type: ignore
            log_prob_fn=log_prob_fn,
            y=y,
            sites=sites,
        )

    return jax.jit(kernel)
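
To show how these kernels are meant to be driven, here is a sketch of a chain runner built on `jax.lax.scan` (the helper `run_chain` and the toy weights are hypothetical, not part of this PR):

import jax
import jax.numpy as jnp
import jax.random as jrandom
from jnotype.energy._sampling import construct_random_bitflip_kernel

def log_q(y):
    return jnp.dot(jnp.array([1.0, -1.0, 0.5]), y)  # hypothetical weights

kernel = construct_random_bitflip_kernel(log_q)

def run_chain(key, y0, n_steps):
    """Drives the kernel with jax.lax.scan, keeping every state."""

    def step(y, subkey):
        y_new = kernel(subkey, y)
        return y_new, y_new

    _, chain = jax.lax.scan(step, y0, jrandom.split(key, n_steps))
    return chain

chain = run_chain(jrandom.PRNGKey(0), jnp.zeros(3, dtype=jnp.int32), 1000)  # shape (1000, 3)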
9 changes: 9 additions & 0 deletions src/jnotype/sampling/__init__.py
@@ -8,10 +8,19 @@
)
from jnotype.sampling._sampler import AbstractGibbsSampler

from jnotype.sampling._jax_loop import (
    sampling_loop,
    compose_kernels,
    iterated_kernel_thinning,
)

__all__ = [
    "DatasetInterface",
    "ListDataset",
    "XArrayChunkedDataset",
    "AbstractChunkedDataset",
    "AbstractGibbsSampler",
    "sampling_loop",
    "compose_kernels",
    "iterated_kernel_thinning",
]
2 changes: 2 additions & 0 deletions src/jnotype/sampling/_chunker.py
@@ -76,6 +76,8 @@ def _extract_samples(self, label: str) -> np.ndarray:

    @property
    def dataset(self) -> xr.Dataset:
        """Generates an xarray data set."""

        attrs = {
            "thinning": self.thinning,
        } | self._attrs