Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise experiment Mark02 #120

Merged
merged 5 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/pyggdrasil/tree_inference/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,23 @@ class MoveProbConfigOptions(Enum):


class McmcConfig(BaseModel):
"""Config for MCMC sampler."""
"""Config for MCMC sampler.

Attributes:

move_probs: MoveProbConfig
move probabilities for MCMC sampler
fpr: float
false positive rate
fnr: float
false negative rate
n_samples: int
number of samples to draw
burn_in: int
number of samples to discard as burn-in
thinning: int
thinning factor for samples
"""

move_probs: MoveProbConfig = MoveProbConfigOptions.DEFAULT.value
fpr: confloat(gt=0, lt=1) = 1.24e-06 # type: ignore
Expand Down
37 changes: 35 additions & 2 deletions tests/tree_inference/test_file_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
CellSimulationId,
TreeType,
McmcRunId,
ErrorCombinations,
)


Expand Down Expand Up @@ -161,9 +162,41 @@ def test_huntrees_tree_id_from_str() -> None:
def test_mcmc_tree_id_from_str() -> None:
"""Tests for tree id."""

str = "iT_m_6_5_99_oT_r_6_42"
test_str = "iT_m_6_5_99_oT_r_6_42"

test_id: TreeId = TreeId.from_str(str) # type: ignore
test_id: TreeId = TreeId.from_str(test_str) # type: ignore

assert test_id.tree_type == TreeType.MCMC
assert test_id.n_nodes == 6


def test_mcmc_id_from_string_manual() -> None:
test_str = (
"MCMC_35-CS_42-T_r_31_42-1000_1e-06_1e-06_0.0_f"
"_UXR-iT_r_31_35-MC_1e-06_1e-06_2000_0_1-MPC_0.1_0.65_0.25"
)

true_tree_id = TreeId.from_str("T_r_31_42")

cs_id = CellSimulationId(
42,
true_tree_id, # type: ignore
1000,
1e-06,
1e-06,
0.0,
False,
CellAttachmentStrategy.UNIFORM_EXCLUDE_ROOT,
)

init_tree_id = TreeId.from_str("T_r_31_35")

move_probs = MoveProbConfig()
err = ErrorCombinations.IDEAL.value
mcmc_config = McmcConfig(
fnr=err.fnr, fpr=err.fpr, move_probs=move_probs, n_samples=2000
)

test_id: McmcRunId = McmcRunId(35, cs_id, init_tree_id, mcmc_config) # type: ignore

assert str(test_id) == test_str
254 changes: 252 additions & 2 deletions workflows/mark03.smk
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ from pyggdrasil.tree_inference import CellSimulationId, TreeType, TreeId, McmcCo

#####################
# Environment variables
DATADIR = "../data"
# DATADIR = "/cluster/work/bewi/members/gkoehn/data"
#DATADIR = "../data"
DATADIR = "/cluster/work/bewi/members/gkoehn/data"

#####################
experiment = "mark03"
Expand Down Expand Up @@ -180,6 +180,7 @@ def make_all_mark03():
).id()
# make filepaths for each metric
for each_metric in metrics:
# with huntress
filepaths.append(
filepath
+ mc
Expand All @@ -191,6 +192,18 @@ def make_all_mark03():
+ each_metric
+ "_iter.svg"
)
# without huntress
filepaths.append(
filepath
+ mc
+ "/"
+ str(cs)
+ "/"
+ str(true_tree_id)
+ "/"
+ each_metric
+ "_iter_noHuntress.svg"
)
gordonkoehn marked this conversation as resolved.
Show resolved Hide resolved
return filepaths


Expand Down Expand Up @@ -452,3 +465,240 @@ rule combined_logProb_iteration_plot:

# save the histogram
fig.savefig(Path(output.combined_logP_iter))


def make_combined_metric_iteration_in_noHuntress():
"""Make input for combined_metric_iteration rule. - no huntress"""
input = []
tree_type = []

for mcmc_seed, init_tree_type, init_tree_seed in initial_points:
# make variables strings dependent on tree type
# catch the case where init_tree_type is star tree
if init_tree_type == "s":
input.append(
"{DATADIR}/mark03/analysis/MCMC_"
+ str(mcmc_seed)
+ "-{mutation_data_id}-iT_"
+ str(init_tree_type)
+ "_{n_nodes,\d+}"
+ "-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json"
)
# catch the case where init_tree_type is huntress tree
elif init_tree_type == "h":
continue
# if mcmc tree
elif init_tree_type == "m":
# split the mcmc seed int into 2 parts: tree_seed, mcmc_seed
tree_seed, mcmc_move_seed = init_tree_seed // 100, init_tree_seed % 100
input.append(
"{DATADIR}/mark03/analysis/MCMC_"
+ str(mcmc_seed)
+ "-{mutation_data_id}-"
+ "iT_m_{n_nodes}_"
+ str(n_mcmc_tree_moves)
+ "_"
+ str(mcmc_move_seed)
+ "_oT_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}"
+ "-{mcmc_config_id}"
+ "/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json"
)
# all other cases
else:
input.append(
"{DATADIR}/mark03/analysis/MCMC_"
+ str(mcmc_seed)
+ "-{mutation_data_id}-iT_"
+ str(init_tree_type)
+ "_{n_nodes,\d+}_"
+ str(init_tree_seed)
+ "-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}.json"
)
tree_type.append(init_tree_type)

return input, tree_type


rule combined_metric_iteration_plot_noHuntress:
"""Make combined metric iteration plot - no Huntress.

For each metric, make a plot with all the chains, where
each initial tree type is a different color.
"""
input:
# calls analyze_metric rule
all_chain_metrics=make_combined_metric_iteration_in()[0],
wildcard_constraints:
# metric wildcard cannot be log_prob
metric=r"(?!(log_prob))\w+",
output:
combined_metric_iter="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/"
"T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}_iter_noHuntress.svg",
run:
# load the data
distances_chains = []
# get the initial tree type, same order as the input
initial_tree_type = make_combined_metric_iteration_in()[1]
# for each chain
for each_chain_metric in input.all_chain_metrics:
# load the distances
_, distances = yg.serialize.read_metric_result(Path(each_chain_metric))
# append to the list
distances_chains.append(distances)

# Create a figure and axis
fig, ax = plt.subplots()

# Define the list of colors to repeat
colors = {"h": "red", "s": "green", "d": "blue", "r": "orange", "m": "purple"}
labels = {
"h": "Huntress",
"s": "Star",
"d": "Deep",
"r": "Random",
"m": "MCMC5",
}

# Define opacity and line style
alpha = 0.6
line_style = "solid"

# Plot each entry of distance chain as a line with a color unique to the
# initial tree type onto one axis

# Plot each entry of distance chain as a line with a color unique to the
# initial tree type onto one axis
for i, distances in enumerate(distances_chains):
color = colors[initial_tree_type[i]]
ax.plot(
distances,
color=color,
label=f"{labels[initial_tree_type[i]]}",
alpha=alpha,
linestyle=line_style,
)

# Set labels and title
ax.set_ylabel(f"Distance/Similarity: {wildcards.metric}")
ax.set_xlabel("Iteration")

# Add a legend of fixed legend position and size
ax.legend(loc="upper right")

# save the histogram
fig.savefig(Path(output.combined_metric_iter))



def make_combined_log_prob_iteration_in_noHuntress():
"""Make input for combined_metric_iteration rule - no huntress."""
input = []

for mcmc_seed, init_tree_type, init_tree_seed in initial_points:
# make variables strings dependent on tree type
# catch the case where init_tree_type is star tree
if init_tree_type == "s":
input.append(
"{DATADIR}/mark03/analysis/MCMC_"
+ str(mcmc_seed)
+ "-{mutation_data_id}-iT_"
+ str(init_tree_type)
+ "_{n_nodes,\d+}"
+ "-{mcmc_config_id}/log_prob.json"
)
# catch the case where init_tree_type is huntress tree
elif init_tree_type == "h":
continue
# if mcmc tree
elif init_tree_type == "m":
# split the mcmc seed int into 2 parts: tree_seed, mcmc_seed
tree_seed, mcmc_move_seed = init_tree_seed // 100, init_tree_seed % 100
input.append(
"{DATADIR}/mark03/analysis/MCMC_"
+ str(mcmc_seed)
+ "-{mutation_data_id}-"
+ "iT_m_{n_nodes}_"
+ str(n_mcmc_tree_moves)
+ "_"
+ str(mcmc_move_seed)
+ "_oT_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}"
+ "-{mcmc_config_id}"
+ "/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob.json"
)

# all other cases
else:
input.append(
"{DATADIR}/mark03/analysis/MCMC_"
+ str(mcmc_seed)
+ "-{mutation_data_id}-iT_"
+ str(init_tree_type)
+ "_{n_nodes,\d+}_"
+ str(init_tree_seed)
+ "-{mcmc_config_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob.json"
)
return input


rule combined_logProb_iteration_plot_noHuntress:
"""Make combined logProb iteration plot. - excludes huntress"""
input:
# calls analyze_metric rule
all_chain_logProb=make_combined_log_prob_iteration_in_noHuntress(),
output:
combined_logP_iter="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/log_prob_iter_noHuntress.svg",
run:
# load the data
logP_chains = []
# get the initial tree type, same order as the input
initial_tree_type = make_combined_metric_iteration_in()[1]
# for each chain
for each_chain_metric in input.all_chain_logProb:
# load the distances
_, logP = yg.serialize.read_metric_result(Path(each_chain_metric))
# append to the list
logP_chains.append(logP)

# Create a figure and axis
fig, ax = plt.subplots()

# Define the list of colors to repeat
colors = {
"s": "green",
"d": "blue",
"r": "orange",
"mcmc": "purple",
}

labels = {
"s": "Star",
"d": "Deep",
"r": "Random",
"mcmc": "MCMC5",
}

# Define opacity and line style
alpha = 0.6
line_style = "solid"

# Plot each entry of distance chain as a line with a color unique to the
# initial tree type onto one axis
for i, logP in enumerate(logP_chains):
color = colors[initial_tree_type[i]]
ax.plot(
logP,
color=color,
label=f"{labels[initial_tree_type[i]]}",
alpha=alpha,
linestyle=line_style,
)

# Set labels and title
ax.set_ylabel(f"Log Probability:" + r"$\log(P(D|T,\theta))$")
ax.set_xlabel("Iteration")
gordonkoehn marked this conversation as resolved.
Show resolved Hide resolved

# Add a legend of fixed legend position
ax.legend(loc="upper right")

# save the histogram
fig.savefig(Path(output.combined_logP_iter))
4 changes: 2 additions & 2 deletions workflows/tree_inference.smk
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ from pyggdrasil.tree_inference import (
###############################################
## Relative path from DATADIR to the repo root

#REPODIR = "/cluster/work/bewi/members/gkoehn/repos/PYggdrasil"
REPODIR = ".."
REPODIR = "/cluster/work/bewi/members/gkoehn/repos/PYggdrasil"
#REPODIR = ".."

###############################################

Expand Down