-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding Gelman Rubin Statistics to MARK02 (#163)
* setup mark02 for rerun * fix paths * untested legend outside of plot * add avriz * add docstring * make rhats part of API * make calculate rhat rule * fix rule rhat * Rhatplot * add multiple AD DL * integrated mark02 rhat * adjust to euler
- Loading branch information
1 parent
506728c
commit 9d1d1a8
Showing
9 changed files
with
408 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
"""Implements Gelman-Rubin convergence diagnostics - R-hat.""" | ||
|
||
import arviz as az | ||
import xarray as xr | ||
import numpy as np | ||
|
||
|
||
def truncate_arrays(arrays: np.ndarray, length: int) -> np.ndarray: | ||
"""Truncate arrays to given length. | ||
Args: | ||
arrays: array of arrays to truncate | ||
length: length to truncate arrays to | ||
Returns: | ||
truncated arrays | ||
""" | ||
truncated_arrays = [arr[:length] for arr in arrays] | ||
|
||
return np.array(truncated_arrays) | ||
|
||
|
||
def rhats(chains: np.ndarray) -> np.ndarray: | ||
"""Compute estimate of rank normalized split R-hat for a set of chains. | ||
Sometimes referred to as the potential scale reduction factor / | ||
Gelman-Rubin statistic. | ||
Used the “rank” method recommended by Vehtari et al. (2019) | ||
The rank normalized R-hat diagnostic tests for lack of convergence by | ||
comparing the variance between multiple chains to the variance within | ||
each chain. | ||
If convergence has been achieved, the between-chain and within-chain | ||
variances should be identical. | ||
Args: | ||
chains: array of arrays to calculate R-hat for | ||
(minimum of 2 chains, minimum of 4 draws) | ||
Returns: | ||
R-hat for given chains from index 4 to length, | ||
returns list that is 4 shorter than the length of the chains | ||
Note: | ||
- May return NaN if the chains are too short and all values are the same | ||
""" | ||
|
||
# minimal length of chains | ||
min_length = 4 | ||
|
||
# Generate all possible truncation lengths | ||
max_length = min(len(array) for array in chains) | ||
truncation_lengths = range(min_length, max_length + 1) | ||
|
||
# Truncate arrays to all possible lengths | ||
truncated_chains = [ | ||
truncate_arrays(chains, length) for length in truncation_lengths | ||
] | ||
|
||
# make sure that the arrays are in the correct format | ||
truncated_chains = [az.convert_to_dataset(arr) for arr in truncated_chains] | ||
|
||
# Calculate R-hat for all possible truncation lengths | ||
rhats = [az.rhat(az.convert_to_dataset(arr)) for arr in truncated_chains] | ||
|
||
# Return R-hat for all possible truncation lengths | ||
combined_dataset = xr.concat(rhats, dim="") # type: ignore | ||
|
||
# Convert the combined dataset to a NumPy array | ||
rhats = combined_dataset["x"].to_series().to_numpy() | ||
|
||
return rhats |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
"""Tests the Gelman-Rubin convergence diagnostics.""" | ||
|
||
|
||
import numpy as np | ||
|
||
import pyggdrasil.analyze._rhat as rhat | ||
|
||
|
||
def test_rhat_basic(): | ||
"""Tests the return shape and one of case of R-hat.""" | ||
# given two chains that converge to the same value 8 | ||
chains = np.array( | ||
[ | ||
np.array( | ||
[ | ||
1, | ||
1, | ||
3, | ||
5, | ||
6, | ||
6, | ||
6, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
] | ||
), | ||
np.array( | ||
[ | ||
5, | ||
2, | ||
3, | ||
4, | ||
5, | ||
6, | ||
7, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
8, | ||
] | ||
), | ||
] | ||
) | ||
# calculate the rhat | ||
rhats = rhat.rhats(chains) | ||
|
||
# Does it return the correct number of rhats? 4 less than the length of the chains | ||
chains_len = len(chains[0]) | ||
assert rhats.shape == (chains_len - 3,) | ||
|
||
# Does it approach 1 as the number of draws increases? | ||
assert rhats[-1] < 1.3 # 1.2 is often used as a cutoff for convergence |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.