Skip to content

Commit

Permalink
Utilities for graphical posterior predictive checking (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz authored Oct 22, 2024
1 parent 1d7e25a commit f9569b1
Show file tree
Hide file tree
Showing 6 changed files with 612 additions and 12 deletions.
14 changes: 14 additions & 0 deletions src/jnotype/checks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,26 @@
calculate_mcc,
calculate_mutation_frequencies,
calculate_number_of_mutations_histogram,
convert_genotypes_to_integers,
convert_integers_to_genotypes,
calculate_atoms_occurrence,
subsample_pytree,
simulate_summary_statistic,
)
from jnotype.checks._plots import rc_context, rcParams, plot_summary_statistic

__all__ = [
"plot_histograms",
"calculate_quantiles",
"calculate_mcc",
"calculate_mutation_frequencies",
"calculate_number_of_mutations_histogram",
"rc_context",
"rcParams",
"plot_summary_statistic",
"convert_genotypes_to_integers",
"convert_integers_to_genotypes",
"calculate_atoms_occurrence",
"subsample_pytree",
"simulate_summary_statistic",
]
26 changes: 16 additions & 10 deletions src/jnotype/checks/_histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
from jaxtyping import Float, Array


def calculate_quantiles(
samples: np.ndarray,
quantiles: np.ndarray,
) -> np.ndarray:
samples: Float[Array, "n_samples dimension"],
quantiles: Float[Array, " n_quantiles"],
) -> Float[Array, "n_quantiles dimension"]:
"""Calculates quantiles.
Args:
Expand All @@ -20,15 +24,15 @@ def calculate_quantiles(
quantile value for each dimension,
shape (n_quantiles, dimension)
"""
return np.quantile(samples, axis=0, q=quantiles)
return jnp.quantile(samples, axis=0, q=quantiles)


def apply_histogram(
draws: np.ndarray,
draws: Float[Array, "n_datasets n_values"],
bins: Union[int, Sequence[float], np.ndarray],
density: bool,
) -> np.ndarray:
"""Maps `np.histogram` over several vectors of values
"""Maps `jnp.histogram` over several vectors of values
to contruct several histograms.
Args:
Expand All @@ -39,11 +43,13 @@ def apply_histogram(
Returns:
histogram counts, shape (n_samples, bins)
"""
_, bins = np.histogram(draws[0], bins=bins)
_, bins = jnp.histogram(draws[0], bins=bins, density=density)

def f(sample):
"""Auxiliary function used for jax.vmap"""
return jnp.histogram(sample, bins=bins, density=density)[0]

return np.asarray(
[np.histogram(vect, bins=bins, density=density)[0] for vect in draws]
)
return jax.vmap(f)(draws)


def plot_histograms(
Expand Down
Loading

0 comments on commit f9569b1

Please sign in to comment.