Skip to content

Commit

Permalink
Add smoke tests for divergences
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Sep 24, 2024
1 parent 0f7fa9f commit 8d6e8f5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/jnotype/energy/_dfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def discrete_fisher_divergence(log_q: LogProbFn, ys: DataSet) -> Float[Array, "
def _besag_pseudolikelihood_onepoint(
log_q: LogProbFn, y: DataPoint
) -> Float[Array, " "]:
"""Calculates Julian Besag's pseudolikelihood on a single data point.
"""Calculates Julian Besag's pseudo-log-likelihood on a single data point.
Namely,
$$
Expand All @@ -76,9 +76,9 @@ def besag_pseudolikelihood_sum(
log_q: LogProbFn,
ys: DataSet,
) -> Float[Array, " "]:
"""Besag pseudolikelihood calculated over the whole data set.
"""Besag pseudo-log-likelihood calculated over the whole data set.
Note that pseudolikelihood is additive.
Note that pseudo-log-likelihood is additive.
"""
n_points = ys.shape[0]
return n_points * besag_pseudolikelihood_mean(log_q, ys)
Expand All @@ -88,10 +88,10 @@ def besag_pseudolikelihood_mean(
log_q: LogProbFn,
ys: DataSet,
) -> Float[Array, " "]:
"""Average Besag pseudolikelihood.
"""Average Besag pseudo-log-likelihood.
Note:
As the pseudolikelihood is additive, for generalised
As the pseudo-log-likelihood is additive, for generalised
Bayesian inference one should multiply by the data set size.
See Also:
Expand Down
56 changes: 56 additions & 0 deletions tests/energy/test_dfd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Tests for discrete Fisher divergence and pseudolikelihood methods."""

import jax.numpy as jnp

import pytest
import jnotype.energy._dfd as dfd


def linear_model(params, y):
    """Log-probability (up to normalisation) of a linear independent-sites model.

    Args:
        params: parameter vector of shape (G,)
        y: single binary data point of shape (G,)

    Returns:
        Scalar inner product of `params` and `y`.
    """
    # Inner product <params, y>, equivalent to sum(params * y).
    return jnp.dot(params, y)


def quadratic_model(params, y):
    """Log-probability (up to normalisation) of a quadratic (Ising-type) model.

    Args:
        params: interaction matrix of shape (G, G)
        y: single binary data point of shape (G,)

    Returns:
        Scalar quadratic form y^T params y.
    """
    # Quadratic form, equivalent to einsum("ij,i,j->", params, y, y).
    return y @ params @ y


# Parameter/model pairs exercised by the smoke test: zero-initialised
# parameters for the linear model (vector) and quadratic model (matrix),
# at two different numbers of genomic sites each.
SETTINGS = [
    (jnp.zeros(n_sites), linear_model) for n_sites in (3, 5)
] + [
    (jnp.zeros((n_sites, n_sites)), quadratic_model) for n_sites in (3, 5)
]


@pytest.mark.parametrize("setting", SETTINGS)
@pytest.mark.parametrize(
    "divergence",
    [
        dfd.discrete_fisher_divergence,
        dfd.besag_pseudolikelihood_mean,
        dfd.besag_pseudolikelihood_sum,
    ],
)
def test_quasidivergence_smoke_test(setting, divergence, n_points: int = 2):
    """Smoke test: each (quasi-)divergence evaluates to a scalar on a
    batch of all-zero binary data points, for every model in SETTINGS."""
    params, model = setting
    n_sites = params.shape[0]

    # Batch of `n_points` all-zero data points of matching dimension.
    data = jnp.zeros((n_points, n_sites), dtype=int)

    # Bind the parameters so the divergence sees a function of y alone.
    value = divergence(lambda y: model(params, y), data)
    assert value.shape == ()

0 comments on commit 8d6e8f5

Please sign in to comment.