diff --git a/src/pyggdrasil/tree_inference/_config.py b/src/pyggdrasil/tree_inference/_config.py
index bc9aeecc..f5a876ea 100644
--- a/src/pyggdrasil/tree_inference/_config.py
+++ b/src/pyggdrasil/tree_inference/_config.py
@@ -83,7 +83,23 @@ class MoveProbConfigOptions(Enum):
 
 
 class McmcConfig(BaseModel):
-    """Config for MCMC sampler."""
+    """Config for MCMC sampler.
+
+    Attributes:
+
+    move_probs: MoveProbConfig
+        move probabilities for MCMC sampler
+    fpr: float
+        false positive rate
+    fnr: float
+        false negative rate
+    n_samples: int
+        number of samples to draw
+    burn_in: int
+        number of samples to discard as burn-in
+    thinning: int
+        thinning factor for samples
+    """
 
     move_probs: MoveProbConfig = MoveProbConfigOptions.DEFAULT.value
     fpr: confloat(gt=0, lt=1) = 1.24e-06  # type: ignore
diff --git a/tests/tree_inference/test_file_id.py b/tests/tree_inference/test_file_id.py
index a29d8b69..269008eb 100644
--- a/tests/tree_inference/test_file_id.py
+++ b/tests/tree_inference/test_file_id.py
@@ -11,6 +11,7 @@
     CellSimulationId,
     TreeType,
     McmcRunId,
+    ErrorCombinations,
 )
 
 
@@ -161,9 +162,41 @@ def test_huntrees_tree_id_from_str() -> None:
 def test_mcmc_tree_id_from_str() -> None:
     """Tests for tree id."""
 
-    str = "iT_m_6_5_99_oT_r_6_42"
+    test_str = "iT_m_6_5_99_oT_r_6_42"
 
-    test_id: TreeId = TreeId.from_str(str)  # type: ignore
+    test_id: TreeId = TreeId.from_str(test_str)  # type: ignore
 
     assert test_id.tree_type == TreeType.MCMC
     assert test_id.n_nodes == 6
+
+
+def test_mcmc_id_from_string_manual() -> None:
+    test_str = (
+        "MCMC_35-CS_42-T_r_31_42-1000_1e-06_1e-06_0.0_f"
+        "_UXR-iT_r_31_35-MC_1e-06_1e-06_2000_0_1-MPC_0.1_0.65_0.25"
+    )
+
+    true_tree_id = TreeId.from_str("T_r_31_42")
+
+    cs_id = CellSimulationId(
+        42,
+        true_tree_id,  # type: ignore
+        1000,
+        1e-06,
+        1e-06,
+        0.0,
+        False,
+        CellAttachmentStrategy.UNIFORM_EXCLUDE_ROOT,
+    )
+
+    init_tree_id = TreeId.from_str("T_r_31_35")
+
+    move_probs = MoveProbConfig()
+    err = ErrorCombinations.IDEAL.value
+    mcmc_config = McmcConfig(
+        fnr=err.fnr, fpr=err.fpr, move_probs=move_probs, n_samples=2000
+    )
+
+    test_id: McmcRunId = McmcRunId(35, cs_id, init_tree_id, mcmc_config)  # type: ignore
+
+    assert str(test_id) == test_str
diff --git a/workflows/mark03.smk b/workflows/mark03.smk
index 584ba9a6..babd0156 100644
--- a/workflows/mark03.smk
+++ b/workflows/mark03.smk
@@ -14,10 +14,13 @@ import pyggdrasil as yg
 
 from pyggdrasil.tree_inference import CellSimulationId, TreeType, TreeId, McmcConfig
 
+# TODO (gordonkoehn): Issue #121: many rules are similar here, all rely on the plotting of the log/prob iteration,
+#  and metric iteration. We should make a generic rule for this, and then have the other rules reuse it.
+
 #####################
 # Environment variables
-DATADIR = "../data"
-# DATADIR = "/cluster/work/bewi/members/gkoehn/data"
+#DATADIR = "../data"
+DATADIR = "/cluster/work/bewi/members/gkoehn/data"
 #####################
 experiment = "mark03"
 
@@ -180,6 +183,7 @@ def make_all_mark03():
                 ).id()
                 # make filepaths for each metric
                 for each_metric in metrics:
+                    # with huntress
                     filepaths.append(
                         filepath
                         + mc
@@ -191,6 +195,18 @@ def make_all_mark03():
                         + each_metric
                         + "_iter.svg"
                     )
+                    # without huntress
+                    filepaths.append(
+                        filepath
+                        + mc
+                        + "/"
+                        + str(cs)
+                        + "/"
+                        + str(true_tree_id)
+                        + "/"
+                        + each_metric
+                        + "_iter_noHuntress.svg"
+                    )
     return filepaths
 
 
@@ -452,3 +468,249 @@ rule combined_logProb_iteration_plot:
 
         # save the histogram
         fig.savefig(Path(output.combined_logP_iter))
+
+
+def make_combined_metric_iteration_in_noHuntress():
+    """Make input for the combined_metric_iteration_plot_noHuntress rule.
+
+    Initial points whose tree type is "h" (huntress) are skipped.
+
+    Returns:
+        tuple of (input file patterns, initial tree types), in matching order
+    """
+    input = []
+    tree_type = []
+
+    for mcmc_seed, init_tree_type, init_tree_seed in initial_points:
+        # make variables strings dependent on tree type
+        # catch the case where init_tree_type is star tree
+        if init_tree_type == "s":
+            input.append(
+                "{DATADIR}/mark03/analysis/MCMC_"
+                + str(mcmc_seed)
+                + "-{mutation_data_id}-iT_"
+                + str(init_tree_type)
+                + r"_{n_nodes,\d+}"
+                + r"-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json"
+            )
+        # catch the case where init_tree_type is huntress tree
+        elif init_tree_type == "h":
+            continue
+        # if mcmc tree
+        elif init_tree_type == "m":
+            # the last two digits of the initial tree seed are the mcmc move seed
+            mcmc_move_seed = init_tree_seed % 100
+            input.append(
+                "{DATADIR}/mark03/analysis/MCMC_"
+                + str(mcmc_seed)
+                + "-{mutation_data_id}-"
+                + "iT_m_{n_nodes}_"
+                + str(n_mcmc_tree_moves)
+                + "_"
+                + str(mcmc_move_seed)
+                + r"_oT_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}"
+                + "-{mcmc_config_id}"
+                + r"/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json"
+            )
+        # all other cases
+        else:
+            input.append(
+                "{DATADIR}/mark03/analysis/MCMC_"
+                + str(mcmc_seed)
+                + "-{mutation_data_id}-iT_"
+                + str(init_tree_type)
+                + r"_{n_nodes,\d+}_"
+                + str(init_tree_seed)
+                + r"-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json"
+            )
+        tree_type.append(init_tree_type)
+
+    return input, tree_type
+
+
+rule combined_metric_iteration_plot_noHuntress:
+    """Make combined metric iteration plot - no Huntress.
+
+    For each metric, make a plot with all the chains, where
+    each initial tree type is a different color.
+    """
+    input:
+        # one metric file per chain; huntress-initialised chains are excluded
+        all_chain_metrics=make_combined_metric_iteration_in_noHuntress()[0],
+    wildcard_constraints:
+        # metric wildcard cannot be log_prob
+        metric=r"(?!(log_prob))\w+",
+    output:
+        combined_metric_iter="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/"
+        r"T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}_iter_noHuntress.svg",
+    run:
+        # load the data
+        distances_chains = []
+        # get the initial tree type, same order as the input (huntress excluded)
+        initial_tree_type = make_combined_metric_iteration_in_noHuntress()[1]
+        # for each chain
+        for each_chain_metric in input.all_chain_metrics:
+            # load the distances
+            _, distances = yg.serialize.read_metric_result(Path(each_chain_metric))
+            # append to the list
+            distances_chains.append(distances)
+
+        # Create a figure and axis
+        fig, ax = plt.subplots()
+
+        # Define the list of colors to repeat
+        colors = {"h": "red", "s": "green", "d": "blue", "r": "orange", "m": "purple"}
+        labels = {
+            "h": "Huntress",
+            "s": "Star",
+            "d": "Deep",
+            "r": "Random",
+            "m": "MCMC5",
+        }
+
+        # Define opacity and line style
+        alpha = 0.6
+        line_style = "solid"
+
+        # Plot each entry of distance chain as a line with a color unique to the
+        # initial tree type onto one axis
+        for i, distances in enumerate(distances_chains):
+            color = colors[initial_tree_type[i]]
+            ax.plot(
+                distances,
+                color=color,
+                label=f"{labels[initial_tree_type[i]]}",
+                alpha=alpha,
+                linestyle=line_style,
+            )
+
+        # Set labels and title
+        ax.set_ylabel(f"Distance/Similarity: {wildcards.metric}")
+        ax.set_xlabel("Iteration")
+
+        # Add a legend of fixed legend position and size
+        ax.legend(loc="upper right")
+
+        # save the histogram
+        fig.savefig(Path(output.combined_metric_iter))
+
+
+
+def make_combined_log_prob_iteration_in_noHuntress():
+    """Make input for the combined_logProb_iteration_plot_noHuntress rule.
+
+    Initial points whose tree type is "h" (huntress) are skipped.
+
+    Returns:
+        list of input file patterns, one per remaining initial point
+    """
+    input = []
+
+    for mcmc_seed, init_tree_type, init_tree_seed in initial_points:
+        # make variables strings dependent on tree type
+        # catch the case where init_tree_type is star tree
+        if init_tree_type == "s":
+            input.append(
+                "{DATADIR}/mark03/analysis/MCMC_"
+                + str(mcmc_seed)
+                + "-{mutation_data_id}-iT_"
+                + str(init_tree_type)
+                + r"_{n_nodes,\d+}"
+                + "-{mcmc_config_id}/log_prob.json"
+            )
+        # catch the case where init_tree_type is huntress tree
+        elif init_tree_type == "h":
+            continue
+        # if mcmc tree
+        elif init_tree_type == "m":
+            # the last two digits of the initial tree seed are the mcmc move seed
+            mcmc_move_seed = init_tree_seed % 100
+            input.append(
+                "{DATADIR}/mark03/analysis/MCMC_"
+                + str(mcmc_seed)
+                + "-{mutation_data_id}-"
+                + "iT_m_{n_nodes}_"
+                + str(n_mcmc_tree_moves)
+                + "_"
+                + str(mcmc_move_seed)
+                + r"_oT_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}"
+                + "-{mcmc_config_id}"
+                + r"/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob.json"
+            )
+
+        # all other cases
+        else:
+            input.append(
+                "{DATADIR}/mark03/analysis/MCMC_"
+                + str(mcmc_seed)
+                + "-{mutation_data_id}-iT_"
+                + str(init_tree_type)
+                + r"_{n_nodes,\d+}_"
+                + str(init_tree_seed)
+                + r"-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob.json"
+            )
+    return input
+
+
+rule combined_logProb_iteration_plot_noHuntress:
+    """Make combined logProb iteration plot - excludes huntress."""
+    input:
+        # one log_prob file per chain; huntress-initialised chains are excluded
+        all_chain_logProb=make_combined_log_prob_iteration_in_noHuntress(),
+    output:
+        combined_logP_iter=r"{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob_iter_noHuntress.svg",
+    run:
+        # load the data
+        logP_chains = []
+        # get the initial tree type, same order as the input (huntress excluded)
+        initial_tree_type = make_combined_metric_iteration_in_noHuntress()[1]
+        # for each chain
+        for each_chain_metric in input.all_chain_logProb:
+            # load the log probabilities
+            _, logP = yg.serialize.read_metric_result(Path(each_chain_metric))
+            # append to the list
+            logP_chains.append(logP)
+
+        # Create a figure and axis
+        fig, ax = plt.subplots()
+
+        # Colors keyed by the initial tree type codes used in initial_points
+        colors = {
+            "s": "green",
+            "d": "blue",
+            "r": "orange",
+            "m": "purple",
+        }
+
+        labels = {
+            "s": "Star",
+            "d": "Deep",
+            "r": "Random",
+            "m": "MCMC5",
+        }
+
+        # Define opacity and line style
+        alpha = 0.6
+        line_style = "solid"
+
+        # Plot each log probability chain as a line with a color unique to the
+        # initial tree type onto one axis
+        for i, logP in enumerate(logP_chains):
+            color = colors[initial_tree_type[i]]
+            ax.plot(
+                logP,
+                color=color,
+                label=f"{labels[initial_tree_type[i]]}",
+                alpha=alpha,
+                linestyle=line_style,
+            )
+
+        # Set labels and title
+        ax.set_ylabel(f"Log Probability:" + r"$\log(P(D|T,\theta))$")
+        ax.set_xlabel("Iteration")
+
+        # Add a legend of fixed legend position
+        ax.legend(loc="upper right")
+
+        # save the histogram
+        fig.savefig(Path(output.combined_logP_iter))
\ No newline at end of file