From 8d6e8f580e6156075d98957ae4abc79df31d9ae7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Tue, 24 Sep 2024 18:04:02 +0200
Subject: [PATCH] Add smoke tests for divergences

---
 src/jnotype/energy/_dfd.py | 10 +++----
 tests/energy/test_dfd.py   | 56 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 61 insertions(+), 5 deletions(-)
 create mode 100644 tests/energy/test_dfd.py

diff --git a/src/jnotype/energy/_dfd.py b/src/jnotype/energy/_dfd.py
index 2b8be18..5004f9b 100644
--- a/src/jnotype/energy/_dfd.py
+++ b/src/jnotype/energy/_dfd.py
@@ -54,7 +54,7 @@ def discrete_fisher_divergence(log_q: LogProbFn, ys: DataSet) -> Float[Array, "
 def _besag_pseudolikelihood_onepoint(
     log_q: LogProbFn, y: DataPoint
 ) -> Float[Array, " "]:
-    """Calculates Julian Besag's pseudolikelihood on a single data point.
+    """Calculates Julian Besag's pseudo-log-likelihood on a single data point.
 
     Namely,
     $$
@@ -76,9 +76,9 @@ def besag_pseudolikelihood_sum(
     log_q: LogProbFn,
     ys: DataSet,
 ) -> Float[Array, " "]:
-    """Besag pseudolikelihood calculated over the whole data set.
+    """Besag pseudo-log-likelihood calculated over the whole data set.
 
-    Note that pseudolikelihood is additive.
+    Note that pseudo-log-likelihood is additive.
     """
     n_points = ys.shape[0]
     return n_points * besag_pseudolikelihood_mean(log_q, ys)
@@ -88,10 +88,10 @@ def besag_pseudolikelihood_mean(
     log_q: LogProbFn,
     ys: DataSet,
 ) -> Float[Array, " "]:
-    """Average Besag pseudolikelihood.
+    """Average Besag pseudo-log-likelihood.
 
     Note:
-        As the pseudolikelihood is additive, for generalised
+        As the pseudo-log-likelihood is additive, for generalised
         Bayesian inference one should multiply by the data set size.
 
     See Also:
diff --git a/tests/energy/test_dfd.py b/tests/energy/test_dfd.py
new file mode 100644
index 0000000..8cd07a8
--- /dev/null
+++ b/tests/energy/test_dfd.py
@@ -0,0 +1,56 @@
+"""Tests for discrete Fisher divergence and pseudo-log-likelihood methods."""
+
+import jax.numpy as jnp
+
+import pytest
+import jnotype.energy._dfd as dfd
+
+
+def linear_model(params, y):
+    """Linear (independent) model.
+
+    Args:
+        params: vector of shape (G,)
+        y: data point of shape (G,)
+    """
+    return jnp.sum(params * y)
+
+
+def quadratic_model(params, y):
+    """Quadratic (Ising) model.
+
+    Args:
+        params: matrix of shape (G, G)
+        y: data point of shape (G,)
+    """
+    return jnp.einsum("ij,i,j->", params, y, y)
+
+
+SETTINGS = [
+    (jnp.zeros(3), linear_model),
+    (jnp.zeros(5), linear_model),
+    (jnp.zeros((3, 3)), quadratic_model),
+    (jnp.zeros((5, 5)), quadratic_model),
+]
+
+
+@pytest.mark.parametrize("setting", SETTINGS)
+@pytest.mark.parametrize(
+    "divergence",
+    [
+        dfd.discrete_fisher_divergence,
+        dfd.besag_pseudolikelihood_mean,
+        dfd.besag_pseudolikelihood_sum,
+    ],
+)
+def test_quasidivergence_smoke_test(setting, divergence, n_points: int = 2):
+    theta, model = setting
+    G = theta.shape[0]
+
+    def logq(y):
+        return model(theta, y)
+
+    ys = jnp.zeros((n_points, G), dtype=int)
+
+    value = divergence(logq, ys)
+    assert value.shape == ()
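
Usage note (a minimal sketch, not part of the patch): the additivity property stated
in the updated docstrings means the summed Besag pseudo-log-likelihood equals the
data set size times its mean, since `besag_pseudolikelihood_sum` is implemented as
`n_points * besag_pseudolikelihood_mean(log_q, ys)`. The snippet below illustrates
this with an Ising-style model mirroring `quadratic_model` from the new test file;
the interaction matrix `theta` and the toy data set are assumptions chosen only for
illustration.

    import jax.numpy as jnp

    import jnotype.energy._dfd as dfd

    G = 4
    theta = 0.1 * jnp.eye(G)  # hypothetical interaction matrix, chosen arbitrarily


    def logq(y):
        """Unnormalised log-probability of an Ising-style model."""
        return jnp.einsum("ij,i,j->", theta, y, y)


    ys = jnp.ones((8, G), dtype=int)  # toy binary data set with 8 points

    total = dfd.besag_pseudolikelihood_sum(logq, ys)
    mean = dfd.besag_pseudolikelihood_mean(logq, ys)

    # The sum is defined as n_points * mean, so additivity holds exactly.
    assert jnp.allclose(total, ys.shape[0] * mean)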