Skip to content

Commit

Permalink
Adding Gelman Rubin Statistics to MARK02 (#163)
Browse files Browse the repository at this point in the history
* setup mark02 for rerun

* fix paths

* untested legend outside of plot

* add arviz

* add docstring

* make rhats part of API

* make calculate rhat rule

* fix rule rhat

* Rhatplot

* add multiple AD DL

* integrated mark02 rhat

* adjust to euler
  • Loading branch information
gordonkoehn authored Aug 31, 2023
1 parent 506728c commit 9d1d1a8
Show file tree
Hide file tree
Showing 9 changed files with 408 additions and 6 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pydantic = "^1.10.7"
cyvcf2 = "^0.30.20"
tqdm = "^4.65.0"
networkx = "^3.0"
arviz = "^0.16.1"


[tool.poetry.group.dev.dependencies]
Expand Down
10 changes: 9 additions & 1 deletion src/pyggdrasil/analyze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,12 @@

from pyggdrasil.analyze._metrics import Metrics

__all__ = ["to_pure_mcmc_data", "check_run_for_tree", "analyze_mcmc_run", "Metrics"]
from pyggdrasil.analyze._rhat import rhats

__all__ = [
"to_pure_mcmc_data",
"check_run_for_tree",
"analyze_mcmc_run",
"Metrics",
"rhats",
]
73 changes: 73 additions & 0 deletions src/pyggdrasil/analyze/_rhat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Implements Gelman-Rubin convergence diagnostics - R-hat."""

import arviz as az
import xarray as xr
import numpy as np


def truncate_arrays(arrays: np.ndarray, length: int) -> np.ndarray:
    """Cut every array of a collection down to its leading entries.

    Args:
        arrays: collection of arrays, iterated one array at a time
        length: number of leading entries to keep from each array

    Returns:
        array built from the shortened arrays, in input order
    """
    # one slice object shared by all arrays; equivalent to ``[:length]``
    head = slice(None, length)
    shortened = [chain[head] for chain in arrays]
    return np.array(shortened)


def rhats(chains: np.ndarray) -> np.ndarray:
    """Compute rank normalized split R-hat for growing prefixes of a set of chains.

    Sometimes referred to as the potential scale reduction factor /
    Gelman-Rubin statistic.
    Uses the "rank" method recommended by Vehtari et al. (2019), as
    implemented by ``arviz.rhat``.
    The rank normalized R-hat diagnostic tests for lack of convergence by
    comparing the variance between multiple chains to the variance within
    each chain.
    If convergence has been achieved, the between-chain and within-chain
    variances should be identical.

    Args:
        chains: array of arrays to calculate R-hat for
            (minimum of 2 chains, minimum of 4 draws)

    Returns:
        1-D array with one R-hat value per truncation length
        ``4, 5, ..., n``, where ``n`` is the length of the shortest chain.
        The array therefore has ``n - 3`` entries and entry ``k`` is computed
        from the first ``k + 4`` draws of every chain.

    Raises:
        ValueError: if the shortest chain has fewer than 4 draws.

    Note:
        - May return NaN if the chains are too short and all values are the same
    """

    # minimal number of draws R-hat is calculated for
    min_length = 4

    # the shortest chain limits how many draws can be compared
    max_length = min(len(chain) for chain in chains)
    if max_length < min_length:
        raise ValueError(
            f"Each chain needs at least {min_length} draws to calculate R-hat, "
            f"but the shortest chain has {max_length}."
        )

    # R-hat of the chains truncated to every possible length; each truncation
    # is converted to an arviz dataset exactly once
    rhat_per_length = [
        az.rhat(az.convert_to_dataset(truncate_arrays(chains, length)))
        for length in range(min_length, max_length + 1)
    ]

    # stack the scalar results along a new, explicitly named dimension
    combined_dataset = xr.concat(
        rhat_per_length, dim="truncation_length"
    )  # type: ignore

    # "x" is the variable name arviz assigns to a plain array input
    return combined_dataset["x"].to_series().to_numpy()
4 changes: 4 additions & 0 deletions src/pyggdrasil/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
save_metric_iteration,
save_log_p_iteration,
save_top_trees_plots,
save_rhat_iteration,
save_rhat_iteration_AD_DL,
)

__all__ = [
Expand All @@ -20,4 +22,6 @@
"plot_tree_mcmc_sample",
"plot_tree_no_print",
"save_top_trees_plots",
"save_rhat_iteration",
"save_rhat_iteration_AD_DL",
]
77 changes: 77 additions & 0 deletions src/pyggdrasil/visualize/_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,80 @@ def save_top_trees_plots(data: PureMcmcData, output_dir: Path) -> None:

with open(output_dir / "top_tree_info.json", "w") as f:
json.dump(info, f)


def save_rhat_iteration(
    iteration: list[int],
    rhats: list[float],
    out_fp: Path,
) -> None:
    """Save plot of rhat vs iteration number to disk as an SVG file.

    Args:
        iteration: list[int]
            Iteration numbers.
        rhats: list[float]
            R hat values, plotted against the iteration numbers.
        out_fp: Path
            Output file path; missing parent directories are created.
    """

    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("Iteration")  # type: ignore
        ax.set_ylabel(r"$\hat{R}$")  # type: ignore
        ax.plot(iteration, rhats, color="black")  # type: ignore
        # horizontal reference lines at 1.1 (dashed) and 1.01 (solid)
        # see limits https://arxiv.org/pdf/1903.08008.pdf
        # drawn on this figure's axes rather than pyplot's implicit current axes
        ax.axhline(y=1.1, color="b", linestyle="--", linewidth=0.5)  # type: ignore
        ax.axhline(y=1.01, color="r", linestyle="-", linewidth=0.5)  # type: ignore
        ax.tick_params(axis="y", labelcolor="black")  # type: ignore
        # ensure the output directory exists
        out_fp.parent.mkdir(parents=True, exist_ok=True)
        # save the figure
        fig.savefig(out_fp, format="svg")  # type: ignore
    finally:
        # pyplot keeps figures alive until they are closed explicitly;
        # release this one so repeated calls do not accumulate figures
        plt.close(fig)


def save_rhat_iteration_AD_DL(
    iteration: list[int],
    rhats_AD: list[float],
    rhats_DL: list[float],
    out_fp: Path,
) -> None:
    """Save plot of rhat vs iteration number for AD and DL to disk as an SVG file.

    Both series are drawn into the same axes and labelled "AD" and "DL"
    in the legend.

    Args:
        iteration: list[int]
            Iteration numbers.
        rhats_AD: list[float]
            R hat values for AD, plotted against the iteration numbers.
        rhats_DL: list[float]
            R hat values for DL, plotted against the iteration numbers.
        out_fp: Path
            Output file path; missing parent directories are created.
    """

    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("Iteration")  # type: ignore
        ax.set_ylabel(r"$\hat{R}$")  # type: ignore
        ax.plot(iteration, rhats_AD, color="darkgreen", label="AD")  # type: ignore
        ax.plot(iteration, rhats_DL, color="darkorange", label="DL")  # type: ignore
        # horizontal reference lines at 1.1 (dashed) and 1.01 (solid)
        # see limits https://arxiv.org/pdf/1903.08008.pdf
        # drawn on this figure's axes rather than pyplot's implicit current axes
        ax.axhline(y=1.1, color="b", linestyle="--", linewidth=0.5)  # type: ignore
        ax.axhline(y=1.01, color="r", linestyle="-", linewidth=0.5)  # type: ignore
        ax.tick_params(axis="y", labelcolor="black")  # type: ignore
        ax.legend(loc="upper right")  # type: ignore
        # ensure the output directory exists
        out_fp.parent.mkdir(parents=True, exist_ok=True)
        # save the figure
        fig.savefig(out_fp, format="svg")  # type: ignore
    finally:
        # pyplot keeps figures alive until they are closed explicitly;
        # release this one so repeated calls do not accumulate figures
        plt.close(fig)
100 changes: 100 additions & 0 deletions tests/analyze/test_rhat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Tests the Gelman-Rubin convergence diagnostics."""


import numpy as np

import pyggdrasil.analyze._rhat as rhat


def test_rhat_basic():
    """Tests the return shape and one case of R-hat."""
    # given two chains of 34 draws each that both settle at the value 8
    settled = [8] * 27
    chains = np.array(
        [
            [1, 1, 3, 5, 6, 6, 6] + settled,
            [5, 2, 3, 4, 5, 6, 7] + settled,
        ]
    )

    # calculate the rhat
    estimates = rhat.rhats(chains)

    # Does it return the correct number of rhats?
    # One value per truncation length 4..n, i.e. 3 fewer than the chain length
    n_draws = len(chains[0])
    assert estimates.shape == (n_draws - 3,)

    # Does it approach 1 as the number of draws increases?
    # loose bound of 1.3; 1.2 is often used as a cutoff for convergence
    assert estimates[-1] < 1.3
46 changes: 46 additions & 0 deletions workflows/analyze.smk
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,49 @@ rule true_trees_found:
f.write(f"True trees found in iterations:")
for i in true_iterations:
f.write(f"\n{i}")


rule calculate_rhats_4chains:
    """Calculate Rhat for 4 chains, each with a different init tree and seed but same data and config."""
    input:
        # raw strings: "\d" is a regex token, not a Python escape sequence
        mcmc_metric_samples1=r"{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed1,\d+}-{mutation_data_id}-i{init_tree_id1}-{mcmc_config_id}/{base_tree_id}/{metric}.json",
        mcmc_metric_samples2=r"{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed2,\d+}-{mutation_data_id}-i{init_tree_id2}-{mcmc_config_id}/{base_tree_id}/{metric}.json",
        mcmc_metric_samples3=r"{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed3,\d+}-{mutation_data_id}-i{init_tree_id3}-{mcmc_config_id}/{base_tree_id}/{metric}.json",
        mcmc_metric_samples4=r"{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed4,\d+}-{mutation_data_id}-i{init_tree_id4}-{mcmc_config_id}/{base_tree_id}/{metric}.json",
    wildcard_constraints:
        mutation_data_id="CS.*",
        mcmc_config_id="MC_(?:(?!/).)+",
    output:
        result="{DATADIR}/{experiment}/analysis/rhat/{base_tree_id}/{metric}/rhat4-MCMCseeds_s{mcmc_seed1}_s{mcmc_seed2}_s{mcmc_seed3}_s{mcmc_seed4}-{mutation_data_id}-iTrees_i{init_tree_id1}_i{init_tree_id2}_i{init_tree_id3}_i{init_tree_id4}-{mcmc_config_id}/rhat.json",
    run:
        import numpy as np

        # metric sample files of the four chains, in chain order
        sample_files = [
            input.mcmc_metric_samples1,
            input.mcmc_metric_samples2,
            input.mcmc_metric_samples3,
            input.mcmc_metric_samples4,
        ]

        # load (iteration, metric values) for every chain
        loaded = [
            yg.serialize.read_metric_result(Path(sample_file))
            for sample_file in sample_files
        ]

        # iteration numbers are taken from the first chain only
        iteration = loaded[0][0]
        chains = np.array([np.array(result) for _, result in loaded])

        # calculate rhat - one value per truncation length, starting at 4 draws
        rhat = yg.analyze.rhats(chains)

        # write the result
        fp = Path(output.result)
        # the first rhat uses 4 draws, so drop the first 3 iteration numbers
        # to pair every rhat with the last iteration it includes
        iteration = iteration[3:]
        yg.serialize.save_metric_result(iteration, list(rhat), fp)

Loading

0 comments on commit 9d1d1a8

Please sign in to comment.