diff --git a/pyproject.toml b/pyproject.toml index 91aa31ba..e576a9ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ pydantic = "^1.10.7" cyvcf2 = "^0.30.20" tqdm = "^4.65.0" networkx = "^3.0" +arviz = "^0.16.1" [tool.poetry.group.dev.dependencies] diff --git a/src/pyggdrasil/analyze/__init__.py b/src/pyggdrasil/analyze/__init__.py index c11a1221..10f94428 100644 --- a/src/pyggdrasil/analyze/__init__.py +++ b/src/pyggdrasil/analyze/__init__.py @@ -6,4 +6,12 @@ from pyggdrasil.analyze._metrics import Metrics -__all__ = ["to_pure_mcmc_data", "check_run_for_tree", "analyze_mcmc_run", "Metrics"] +from pyggdrasil.analyze._rhat import rhats + +__all__ = [ + "to_pure_mcmc_data", + "check_run_for_tree", + "analyze_mcmc_run", + "Metrics", + "rhats", +] diff --git a/src/pyggdrasil/analyze/_rhat.py b/src/pyggdrasil/analyze/_rhat.py new file mode 100644 index 00000000..eac330b5 --- /dev/null +++ b/src/pyggdrasil/analyze/_rhat.py @@ -0,0 +1,73 @@ +"""Implements Gelman-Rubin convergence diagnostics - R-hat.""" + +import arviz as az +import xarray as xr +import numpy as np + + +def truncate_arrays(arrays: np.ndarray, length: int) -> np.ndarray: + """Truncate arrays to given length. + + Args: + arrays: array of arrays to truncate + length: length to truncate arrays to + + Returns: + truncated arrays + """ + truncated_arrays = [arr[:length] for arr in arrays] + + return np.array(truncated_arrays) + + +def rhats(chains: np.ndarray) -> np.ndarray: + """Compute estimate of rank normalized split R-hat for a set of chains. + + Sometimes referred to as the potential scale reduction factor / + Gelman-Rubin statistic. + + Used the “rank” method recommended by Vehtari et al. (2019) + + The rank normalized R-hat diagnostic tests for lack of convergence by + comparing the variance between multiple chains to the variance within + each chain. + If convergence has been achieved, the between-chain and within-chain + variances should be identical. + + Args: + chains: array of arrays to calculate R-hat for + (minimum of 2 chains, minimum of 4 draws) + + Returns: + R-hat for given chains from index 4 to length, + returns list that is 4 shorter than the length of the chains + + Note: + - May return NaN if the chains are too short and all values are the same + """ + + # minimal length of chains + min_length = 4 + + # Generate all possible truncation lengths + max_length = min(len(array) for array in chains) + truncation_lengths = range(min_length, max_length + 1) + + # Truncate arrays to all possible lengths + truncated_chains = [ + truncate_arrays(chains, length) for length in truncation_lengths + ] + + # make sure that the arrays are in the correct format + truncated_chains = [az.convert_to_dataset(arr) for arr in truncated_chains] + + # Calculate R-hat for all possible truncation lengths + rhats = [az.rhat(az.convert_to_dataset(arr)) for arr in truncated_chains] + + # Return R-hat for all possible truncation lengths + combined_dataset = xr.concat(rhats, dim="") # type: ignore + + # Convert the combined dataset to a NumPy array + rhats = combined_dataset["x"].to_series().to_numpy() + + return rhats diff --git a/src/pyggdrasil/visualize/__init__.py b/src/pyggdrasil/visualize/__init__.py index 40ca8c96..15931926 100644 --- a/src/pyggdrasil/visualize/__init__.py +++ b/src/pyggdrasil/visualize/__init__.py @@ -11,6 +11,8 @@ save_metric_iteration, save_log_p_iteration, save_top_trees_plots, + save_rhat_iteration, + save_rhat_iteration_AD_DL, ) __all__ = [ @@ -20,4 +22,6 @@ "plot_tree_mcmc_sample", "plot_tree_no_print", "save_top_trees_plots", + "save_rhat_iteration", + "save_rhat_iteration_AD_DL", ] diff --git a/src/pyggdrasil/visualize/_mcmc.py b/src/pyggdrasil/visualize/_mcmc.py index 75cee30b..013ad981 100644 --- a/src/pyggdrasil/visualize/_mcmc.py +++ b/src/pyggdrasil/visualize/_mcmc.py @@ -160,3 +160,80 @@ def save_top_trees_plots(data: PureMcmcData, output_dir: Path) -> None: with open(output_dir / "top_tree_info.json", "w") as f: json.dump(info, f) + + +def save_rhat_iteration( + iteration: list[int], + rhats: list[float], + out_fp: Path, +) -> None: + """Save plot of rhat vs iteration number to disk. + + Args: + iteration: list[int] + Iteration numbers. + out_fp: Path + Output file path. + rhats: ndarray + R hat values for each iteration. + """ + + # make matplotlib figure, given the axes + + fig, ax = plt.subplots() + ax.set_xlabel("Iteration") # type: ignore + # get name of distance measure + ax.set_ylabel(r"$\hat{R}$") # type: ignore + ax.plot(iteration, rhats, color="black") # type: ignore + # specifying horizontal line type + # see limits https://arxiv.org/pdf/1903.08008.pdf + plt.axhline(y=1.1, color="b", linestyle="--", linewidth=0.5) # type: ignore + plt.axhline(y=1.01, color="r", linestyle="-", linewidth=0.5) # type: ignore + ax.tick_params(axis="y", labelcolor="black") # type: ignore + # ensure the output directory exists + # strip the filename from the output path + output_dir = out_fp.parent + output_dir.mkdir(parents=True, exist_ok=True) + # save the figure + fig.savefig(out_fp, format="svg") # type: ignore + + +def save_rhat_iteration_AD_DL( + iteration: list[int], + rhats_AD: list[float], + rhats_DL: list[float], + out_fp: Path, +) -> None: + """Save plot of rhat vs iteration number to disk. + + Args: + iteration: list[int] + Iteration numbers. + out_fp: Path + Output file path. + rhats_AD: ndarray + R hat values for each iteration for AD. + rhats_DL: ndarray + R hat values for each iteration for DL. + """ + + # make matplotlib figure, given the axes + + fig, ax = plt.subplots() + ax.set_xlabel("Iteration") # type: ignore + # get name of distance measure + ax.set_ylabel(r"$\hat{R}$") # type: ignore + ax.plot(iteration, rhats_AD, color="darkgreen", label="AD") # type: ignore + ax.plot(iteration, rhats_DL, color="darkorange", label="DL") # type: ignore + # specifying horizontal line type + # see limits https://arxiv.org/pdf/1903.08008.pdf + plt.axhline(y=1.1, color="b", linestyle="--", linewidth=0.5) # type: ignore + plt.axhline(y=1.01, color="r", linestyle="-", linewidth=0.5) # type: ignore + ax.tick_params(axis="y", labelcolor="black") # type: ignore + ax.legend(loc="upper right") # type: ignore + # ensure the output directory exists + # strip the filename from the output path + output_dir = out_fp.parent + output_dir.mkdir(parents=True, exist_ok=True) + # save the figure + fig.savefig(out_fp, format="svg") # type: ignore diff --git a/tests/analyze/test_rhat.py b/tests/analyze/test_rhat.py new file mode 100644 index 00000000..55b5905b --- /dev/null +++ b/tests/analyze/test_rhat.py @@ -0,0 +1,100 @@ +"""Tests the Gelman-Rubin convergence diagnostics.""" + + +import numpy as np + +import pyggdrasil.analyze._rhat as rhat + + +def test_rhat_basic(): + """Tests the return shape and one of case of R-hat.""" + # given two chains that converge to the same value 8 + chains = np.array( + [ + np.array( + [ + 1, + 1, + 3, + 5, + 6, + 6, + 6, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + ] + ), + np.array( + [ + 5, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + 8, + ] + ), + ] + ) + # calculate the rhat + rhats = rhat.rhats(chains) + + # Does it return the correct number of rhats? 4 less than the length of the chains + chains_len = len(chains[0]) + assert rhats.shape == (chains_len - 3,) + + # Does it approach 1 as the number of draws increases? + assert rhats[-1] < 1.3 # 1.2 is often used as a cutoff for convergence diff --git a/workflows/analyze.smk b/workflows/analyze.smk index 01cc0bd9..9d59af68 100644 --- a/workflows/analyze.smk +++ b/workflows/analyze.smk @@ -100,3 +100,49 @@ rule true_trees_found: f.write(f"True trees found in iterations:") for i in true_iterations: f.write(f"\n{i}") + + +rule calculate_rhats_4chains: + """Calculate Rhat for 4 chains, each with a different init tree and seed but same data and config.""" + input: + mcmc_metric_samples1="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed1,\d+}-{mutation_data_id}-i{init_tree_id1}-{mcmc_config_id}/{base_tree_id}/{metric}.json", + mcmc_metric_samples2="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed2,\d+}-{mutation_data_id}-i{init_tree_id2}-{mcmc_config_id}/{base_tree_id}/{metric}.json", + mcmc_metric_samples3="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed3,\d+}-{mutation_data_id}-i{init_tree_id3}-{mcmc_config_id}/{base_tree_id}/{metric}.json", + mcmc_metric_samples4="{DATADIR}/{experiment}/analysis/MCMC_{mcmc_seed4,\d+}-{mutation_data_id}-i{init_tree_id4}-{mcmc_config_id}/{base_tree_id}/{metric}.json", + wildcard_constraints: + mutation_data_id = "CS.*", + mcmc_config_id= "MC_(?:(?!/).)+", + output: + result="{DATADIR}/{experiment}/analysis/rhat/{base_tree_id}/{metric}/rhat4-MCMCseeds_s{mcmc_seed1}_s{mcmc_seed2}_s{mcmc_seed3}_s{mcmc_seed4}-{mutation_data_id}-iTrees_i{init_tree_id1}_i{init_tree_id2}_i{init_tree_id3}_i{init_tree_id4}-{mcmc_config_id}/rhat.json", + run: + import json + import numpy as np + # load the data + # chain 1 + # load data + fp = Path(input.mcmc_metric_samples1) + iteration, result1 = yg.serialize.read_metric_result(fp) + result1 = np.array(result1) + # chain 2 + fp = Path(input.mcmc_metric_samples2) + _, result2 = yg.serialize.read_metric_result(fp) + result2 = np.array(result2) + # chain 3 + fp = Path(input.mcmc_metric_samples3) + _, result3 = yg.serialize.read_metric_result(fp) + result3 = np.array(result3) + # chain 4 + fp = Path(input.mcmc_metric_samples4) + _, result4 = yg.serialize.read_metric_result(fp) + result4 = np.array(result4) + + # calculate rhat - returns the 4-length array of rhats + chains = np.array([result1, result2, result3, result4]) + rhat = yg.analyze.rhats(chains) + + # write the result + fp = Path(output.result) + # save with iteration numbers, struncate the first 3 iterations + iteration = iteration[3:] + yg.serialize.save_metric_result(iteration, list(rhat), fp) + diff --git a/workflows/mark02.smk b/workflows/mark02.smk index a5c9f02a..a15807a9 100644 --- a/workflows/mark02.smk +++ b/workflows/mark02.smk @@ -16,14 +16,14 @@ from pyggdrasil.tree_inference import CellSimulationId, TreeType, TreeId, McmcCo ##################### # Environment variables -#DATADIR = "../data" +#DATADIR = "../data.nosync" DATADIR = "/cluster/work/bewi/members/gkoehn/data" ##################### experiment="mark02" # Metrics: Distances / Similarity Measure to use -metrics = ["MP3", "AD", "DL"] # also AD <-- configure distances here +metrics = ["AD", "DL"] # also MP3 <-- configure distances here ##################### # Error Parameters @@ -135,6 +135,60 @@ def make_all_mark02(): filepath + mc + "/" + str(cs) + "/" + str(true_tree_id) + "/" + each_metric + ".svg" ) + # make combined AD_DL convergence paths + #../data.nosync/mark02/plots/MC_0.1_0.1_2000_0_1-MPC_0.1_0.65_0.25/CS_42-T_r_10_34-1000_0.1_0.1_0.0_f_UXR/T_r_10_34/AD_DL/rhat4-MCMCseeds_s42_s34_s12_s79-iTrees_iT_r_10_31_iT_r_10_32_iT_r_10_12_iT_r_10_89/rhat.svg + # make cell simulation ids + #cell_simulation_id_ls = [] + for true_tree_id in tree_id_ls: + for n_cell in n_cells: + for error_name, error in errors.items(): + #cell_simulation_id_ls.append( + cs = CellSimulationId( + seed=CS_seed, + tree_id=true_tree_id, + n_cells=n_cell, + fpr=error["fpr"], + fnr=error["fnr"], + na_rate=rate_na, + observe_homozygous = observe_homozygous, + strategy=cell_attachment_strategy + ) + #) + + mc = McmcConfig( + n_samples=n_samples, + fpr=error["fpr"], + fnr=error["fnr"] + ).id() + + mcmc_seeds = [point[0] for point in initial_points] + nodes = true_tree_id.n_nodes + init_trees_ids = [] + for point in initial_points: + init_tree_type = point[1] + init_tree_seed = point[2] + init_tree_id = TreeId(TreeType(init_tree_type), nodes, init_tree_seed) + init_trees_ids.append(init_tree_id) + + + # make combined AD_DL convergence paths + filepaths.append( + filepath + mc + "/" + str(cs) + "/" + str(true_tree_id) + "/AD_DL/rhat4-MCMCseeds_s" + str(mcmc_seeds[0]) +"_s"+ str(mcmc_seeds[1]) +"_s"+ str(mcmc_seeds[2]) +"_s"+ str(mcmc_seeds[3]) +"-iTrees_i"+str(init_trees_ids[0])+"_i"+ str(init_trees_ids[1])+"_i"+ str(init_trees_ids[2]) + "_i"+ str(init_trees_ids[3]) + "/rhat.svg" + ) + filepaths.append( + filepath + mc + "/" + str(cs) + "/" + str(true_tree_id) + "/AD/rhat4-MCMCseeds_s" + str( + mcmc_seeds[0]) + "_s" + str(mcmc_seeds[1]) + "_s" + str(mcmc_seeds[2]) + "_s" + str( + mcmc_seeds[3]) + "-iTrees_i" + str(init_trees_ids[0]) + "_i" + str( + init_trees_ids[1]) + "_i" + str(init_trees_ids[2]) + "_i" + str(init_trees_ids[3]) + "/rhat.svg" + ) + filepaths.append( + filepath + mc + "/" + str(cs) + "/" + str(true_tree_id) + "/DL/rhat4-MCMCseeds_s" + str( + mcmc_seeds[0]) + "_s" + str(mcmc_seeds[1]) + "_s" + str(mcmc_seeds[2]) + "_s" + str( + mcmc_seeds[3]) + "-iTrees_i" + str(init_trees_ids[0]) + "_i" + str( + init_trees_ids[1]) + "_i" + str(init_trees_ids[2]) + "_i" + str(init_trees_ids[3]) + "/rhat.svg" + ) + print(filepaths[-1]) + return filepaths @@ -242,17 +296,17 @@ rule combined_chain_histogram: sublist = [float(x) for x in sublist] # Create histogram for the sublist with the color and label - ax.hist(sublist,bins='auto', range = (0,1),alpha=0.5,color=colors[color_index],label=labels[i]) + ax.hist(sublist,bins=100, range = (0,1),alpha=0.5,color=colors[color_index],label=labels[i]) # Set labels and title ax.set_xlabel(f"Similarity: {wildcards.metric}") ax.set_ylabel('Frequency') # Add a legend - ax.legend() + ax.legend(bbox_to_anchor = (1.04, 0.5), loc = "center left", borderaxespad = 0) # save the histogram - fig.savefig(Path(output.combined_chain_histogram)) + fig.savefig(Path(output.combined_chain_histogram), bbox_inches="tight") diff --git a/workflows/tree_inference.smk b/workflows/tree_inference.smk index 918ee7f7..e0eaa445 100644 --- a/workflows/tree_inference.smk +++ b/workflows/tree_inference.smk @@ -19,9 +19,9 @@ from pyggdrasil.tree_inference import ( ############################################### ## Relative path from DATADIR to the repo root -#REPODIR = "/cluster/work/bewi/members/gkoehn/repos/PYggdrasil" -REPODIR = ".." -#DATADIR = "/cluster/work/bewi/members/gkoehn/data" +REPODIR = "/cluster/work/bewi/members/gkoehn/repos/PYggdrasil" +#REPODIR = ".." +DATADIR = "/cluster/work/bewi/members/gkoehn/data" ############################################### diff --git a/workflows/visualize.smk b/workflows/visualize.smk index a6522553..a59483b9 100644 --- a/workflows/visualize.smk +++ b/workflows/visualize.smk @@ -140,3 +140,42 @@ rule plot_tree_relabeled: yg.visualize.plot_tree( true_tree, save_name=out_fp.name.__str__(), save_dir=out_fp.parent, print_options=print_options, rename_labels=mapping_dict ) + + +rule plot_rhat: + """Plot the rhat values for each parameter over iterations of an mcmc run""" + input: + rhat="{DATADIR}/{experiment}/analysis/rhat/{base_tree_id}/{metric}/rhat4-MCMCseeds_s{mcmc_seed1}_s{mcmc_seed2}_s{mcmc_seed3}_s{mcmc_seed4}-{mutation_data_id}-iTrees_i{init_tree_id1}_i{init_tree_id2}_i{init_tree_id3}_i{init_tree_id4}-{mcmc_config_id}/rhat.json", + output: + plot="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/{base_tree_id}/{metric}/rhat4-MCMCseeds_s{mcmc_seed1}_s{mcmc_seed2}_s{mcmc_seed3}_s{mcmc_seed4}-iTrees_i{init_tree_id1}_i{init_tree_id2}_i{init_tree_id3}_i{init_tree_id4}/rhat.svg", + wildcard_constraints: + # metric cannot be "AD_DL" because that is a special case + metric = r"(?!(AD_DL))\w+" + run: + in_fp = Path(input.rhat) + with open(in_fp) as f: + data = json.load(f) + out_fp = Path(output.plot) + yg.visualize.save_rhat_iteration(data["iteration"], data["result"], out_fp=out_fp) + + +rule plot_rhat_AD_DL: + """Plot the rhat values for each parameter over iterations of an mcmc run""" + input: + rhat_AD="{DATADIR}/{experiment}/analysis/rhat/{base_tree_id}/AD/rhat4-MCMCseeds_s{mcmc_seed1}_s{mcmc_seed2}_s{mcmc_seed3}_s{mcmc_seed4}-{mutation_data_id}-iTrees_i{init_tree_id1}_i{init_tree_id2}_i{init_tree_id3}_i{init_tree_id4}-{mcmc_config_id}/rhat.json", + rhat_DL="{DATADIR}/{experiment}/analysis/rhat/{base_tree_id}/DL/rhat4-MCMCseeds_s{mcmc_seed1}_s{mcmc_seed2}_s{mcmc_seed3}_s{mcmc_seed4}-{mutation_data_id}-iTrees_i{init_tree_id1}_i{init_tree_id2}_i{init_tree_id3}_i{init_tree_id4}-{mcmc_config_id}/rhat.json", + + output: + plot="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/{base_tree_id}/AD_DL/rhat4-MCMCseeds_s{mcmc_seed1}_s{mcmc_seed2}_s{mcmc_seed3}_s{mcmc_seed4}-iTrees_i{init_tree_id1}_i{init_tree_id2}_i{init_tree_id3}_i{init_tree_id4}/rhat.svg", + run: + # load AD + in_fp = Path(input.rhat_AD) + with open(in_fp) as f: + data_AD = json.load(f) + # load DL + in_fp = Path(input.rhat_DL) + with open(in_fp) as f: + data_DL = json.load(f) + # make output path + out_fp = Path(output.plot) + yg.visualize.save_rhat_iteration_AD_DL(data_AD["iteration"], data_AD["result"], data_DL["result"], out_fp=out_fp) \ No newline at end of file