Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Gelman Rubin Statistics to MARK02 #163

Merged
merged 14 commits into from
Aug 31, 2023
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pydantic = "^1.10.7"
cyvcf2 = "^0.30.20"
tqdm = "^4.65.0"
networkx = "^3.0"
arviz = "^0.16.1"


[tool.poetry.group.dev.dependencies]
Expand Down
10 changes: 9 additions & 1 deletion src/pyggdrasil/analyze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,12 @@

from pyggdrasil.analyze._metrics import Metrics

__all__ = ["to_pure_mcmc_data", "check_run_for_tree", "analyze_mcmc_run", "Metrics"]
from pyggdrasil.analyze._rhat import rhats

__all__ = [
"to_pure_mcmc_data",
"check_run_for_tree",
"analyze_mcmc_run",
"Metrics",
"rhats",
]
73 changes: 73 additions & 0 deletions src/pyggdrasil/analyze/_rhat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Implements Gelman-Rubin convergence diagnostics - R-hat."""

import arviz as az
import xarray as xr
import numpy as np


def truncate_arrays(arrays: np.ndarray, length: int) -> np.ndarray:
    """Keep only the leading ``length`` entries of every array.

    Args:
        arrays: collection of arrays, one per chain
        length: number of leading entries to keep from each array

    Returns:
        array holding the shortened arrays in their original order
    """
    # the same prefix slice is applied to every chain
    prefix = slice(length)
    return np.array([chain[prefix] for chain in arrays])


def rhats(chains: np.ndarray) -> np.ndarray:
    """Compute rank normalized split R-hat for growing prefixes of the chains.

    Sometimes referred to as the potential scale reduction factor /
    Gelman-Rubin statistic.

    Uses ArviZ's ``rhat`` with its default settings, i.e. the "rank" method
    recommended by Vehtari et al. (2019).

    The rank normalized R-hat diagnostic tests for lack of convergence by
    comparing the variance between multiple chains to the variance within
    each chain.
    If convergence has been achieved, the between-chain and within-chain
    variances should be identical.

    Args:
        chains: array of arrays to calculate R-hat for
            (minimum of 2 chains, minimum of 4 draws)

    Returns:
        1-D array of R-hat values, one per truncation length from 4 up to
        the length of the shortest chain. The array is therefore 3 shorter
        than the shortest chain, and entry ``k`` uses the first ``k + 4``
        draws of every chain.

    Raises:
        ValueError: if no chains are given or the shortest chain has fewer
            than 4 draws.

    Note:
        - May return NaN if the chains are too short and all values are the same
    """
    # minimal number of draws for which R-hat is computed
    min_length = 4

    if len(chains) == 0:
        raise ValueError("At least one chain is required to compute R-hat.")

    # chains may differ in length; only their common prefix can be compared
    max_length = min(len(chain) for chain in chains)
    if max_length < min_length:
        raise ValueError(
            f"Each chain needs at least {min_length} draws to compute R-hat, "
            f"but the shortest chain has {max_length}."
        )

    # Evaluate one truncation length at a time instead of materialising
    # every truncated copy of the chains up front.
    values = []
    for length in range(min_length, max_length + 1):
        truncated = truncate_arrays(chains, length)
        # a plain array is stored by ArviZ under the default variable name "x"
        dataset = az.convert_to_dataset(truncated)
        values.append(float(az.rhat(dataset)["x"]))

    return np.array(values, dtype=float)
4 changes: 4 additions & 0 deletions src/pyggdrasil/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
save_metric_iteration,
save_log_p_iteration,
save_top_trees_plots,
save_rhat_iteration,
save_rhat_iteration_AD_DL,
)

__all__ = [
Expand All @@ -20,4 +22,6 @@
"plot_tree_mcmc_sample",
"plot_tree_no_print",
"save_top_trees_plots",
"save_rhat_iteration",
"save_rhat_iteration_AD_DL",
gordonkoehn marked this conversation as resolved.
Show resolved Hide resolved
]
77 changes: 77 additions & 0 deletions src/pyggdrasil/visualize/_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,80 @@ def save_top_trees_plots(data: PureMcmcData, output_dir: Path) -> None:

with open(output_dir / "top_tree_info.json", "w") as f:
json.dump(info, f)


def save_rhat_iteration(
    iteration: list[int],
    rhats: list[float],
    out_fp: Path,
) -> None:
    """Save plot of rhat vs iteration number to disk.

    The plot is written as SVG. Parent directories of ``out_fp`` are created
    if they do not exist yet.

    Args:
        iteration: Iteration numbers.
        rhats: R hat values for each iteration.
        out_fp: Output file path.
    """
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("Iteration")  # type: ignore
        ax.set_ylabel(r"$\hat{R}$")  # type: ignore
        ax.plot(iteration, rhats, color="black")  # type: ignore
        # horizontal reference lines at 1.1 and 1.01
        # see limits https://arxiv.org/pdf/1903.08008.pdf
        # drawn on `ax` explicitly rather than on pyplot's current axes
        ax.axhline(y=1.1, color="b", linestyle="--", linewidth=0.5)  # type: ignore
        ax.axhline(y=1.01, color="r", linestyle="-", linewidth=0.5)  # type: ignore
        ax.tick_params(axis="y", labelcolor="black")  # type: ignore
        # ensure the output directory exists
        out_fp.parent.mkdir(parents=True, exist_ok=True)
        # save the figure
        fig.savefig(out_fp, format="svg")  # type: ignore
    finally:
        # release the figure so repeated calls do not accumulate open figures
        plt.close(fig)


def save_rhat_iteration_AD_DL(
    iteration: list[int],
    rhats_AD: list[float],
    rhats_DL: list[float],
    out_fp: Path,
) -> None:
    """Save plot of rhat vs iteration number for AD and DL to disk.

    Both series are drawn on the same axes with a legend. The plot is
    written as SVG. Parent directories of ``out_fp`` are created if they do
    not exist yet.

    Args:
        iteration: Iteration numbers.
        rhats_AD: R hat values for each iteration for AD.
        rhats_DL: R hat values for each iteration for DL.
        out_fp: Output file path.
    """
    fig, ax = plt.subplots()
    try:
        ax.set_xlabel("Iteration")  # type: ignore
        ax.set_ylabel(r"$\hat{R}$")  # type: ignore
        ax.plot(iteration, rhats_AD, color="darkgreen", label="AD")  # type: ignore
        ax.plot(iteration, rhats_DL, color="darkorange", label="DL")  # type: ignore
        # horizontal reference lines at 1.1 and 1.01
        # see limits https://arxiv.org/pdf/1903.08008.pdf
        # drawn on `ax` explicitly rather than on pyplot's current axes
        ax.axhline(y=1.1, color="b", linestyle="--", linewidth=0.5)  # type: ignore
        ax.axhline(y=1.01, color="r", linestyle="-", linewidth=0.5)  # type: ignore
        ax.tick_params(axis="y", labelcolor="black")  # type: ignore
        ax.legend(loc="upper right")  # type: ignore
        # ensure the output directory exists
        out_fp.parent.mkdir(parents=True, exist_ok=True)
        # save the figure
        fig.savefig(out_fp, format="svg")  # type: ignore
    finally:
        # release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
100 changes: 100 additions & 0 deletions tests/analyze/test_rhat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Tests the Gelman-Rubin convergence diagnostics."""


import numpy as np

import pyggdrasil.analyze._rhat as rhat


def test_rhat_basic():
    """Checks the shape of the R-hat series and its value for the full chains."""
    # two chains of 34 draws each: 7 differing draws, then both stay at 8
    plateau = [8] * 27
    chains = np.array(
        [
            [1, 1, 3, 5, 6, 6, 6] + plateau,
            [5, 2, 3, 4, 5, 6, 7] + plateau,
        ]
    )

    # calculate the R-hat series
    rhat_series = rhat.rhats(chains)

    # the series is expected to be 3 shorter than the chains
    n_draws = len(chains[0])
    assert rhat_series.shape == (n_draws - 3,)

    # the value computed from all draws must stay below the loose bound 1.3
    assert rhat_series[-1] < 1.3
gordonkoehn marked this conversation as resolved.
Show resolved Hide resolved
46 changes: 46 additions & 0 deletions workflows/analyze.smk
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,49 @@ rule true_trees_found:
f.write(f"True trees found in iterations:")
for i in true_iterations:
f.write(f"\n{i}")


rule calculate_rhats_4chains:
    """Calculate Rhat for 4 chains, each with a different init tree and seed but same data and config."""
    input:
        mcmc_metric_samples1="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed1,\d+}-{mutation_data_id}-i{init_tree_id1}-{mcmc_config_id}/{base_tree_id}/{metric}.json",
        mcmc_metric_samples2="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed2,\d+}-{mutation_data_id}-i{init_tree_id2}-{mcmc_config_id}/{base_tree_id}/{metric}.json",
        mcmc_metric_samples3="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed3,\d+}-{mutation_data_id}-i{init_tree_id3}-{mcmc_config_id}/{base_tree_id}/{metric}.json",
        mcmc_metric_samples4="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed4,\d+}-{mutation_data_id}-i{init_tree_id4}-{mcmc_config_id}/{base_tree_id}/{metric}.json",
    wildcard_constraints:
        mutation_data_id = "CS.*",
        mcmc_config_id= "MC_(?:(?!/).)+",
    output:
        result="{DATADIR}/{experiment}/analysis/rhat/{base_tree_id}/{metric}/rhat4-MCMCseeds_s{mcmc_seed1}_s{mcmc_seed2}_s{mcmc_seed3}_s{mcmc_seed4}-{mutation_data_id}-iTrees_i{init_tree_id1}_i{init_tree_id2}_i{init_tree_id3}_i{init_tree_id4}-{mcmc_config_id}/rhat.json",
    run:
        import numpy as np

        # metric files of the four chains, in chain order
        sample_files = [
            input.mcmc_metric_samples1,
            input.mcmc_metric_samples2,
            input.mcmc_metric_samples3,
            input.mcmc_metric_samples4,
        ]

        # load every chain; the iteration numbers are taken from chain 1
        iteration = None
        results = []
        for sample_file in sample_files:
            chain_iteration, result = yg.serialize.read_metric_result(Path(sample_file))
            if iteration is None:
                iteration = chain_iteration
            results.append(np.array(result))

        # calculate rhat - one value per truncation length, starting at 4 draws
        chains = np.array(results)
        rhat = yg.analyze.rhats(chains)

        # The first rhat value uses 4 draws, so it belongs to the 4th
        # iteration: drop the first 3 iteration numbers and keep exactly as
        # many as there are rhat values, so both series have the same length.
        iteration = iteration[3 : 3 + len(rhat)]

        # write the result
        fp = Path(output.result)
        yg.serialize.save_metric_result(iteration, list(rhat), fp)

Loading
Loading