Skip to content

Commit

Permalink
fix metric plots
Browse files Browse the repository at this point in the history
  • Loading branch information
gordonkoehn committed Jul 26, 2023
1 parent 85a1a6f commit 9f6ed37
Showing 1 changed file with 84 additions and 108 deletions.
192 changes: 84 additions & 108 deletions workflows/mark03.smk
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,85 @@ def make_combined_metric_iteration_in():
return input, tree_type


def legend_without_duplicate_labels(figure):
    """Add a legend to a figure without duplicate labels.

    Args:
        figure: The object the legend is drawn on. It must provide
            ``legend(handles, labels, loc=...)``. If it also provides
            ``gca()`` (as the ``matplotlib.pyplot`` module and ``Figure``
            objects do), the legend entries are read from the axes returned
            by that call; otherwise ``figure`` itself must provide
            ``get_legend_handles_labels()`` (as an ``Axes`` does).

    Returns:
        None
    """
    # Resolve the axes from the object that was passed in rather than from
    # pyplot's global "current axes", so the legend always describes the plot
    # the caller handed over.
    current_axes_getter = getattr(figure, "gca", None)
    axes = current_axes_getter() if callable(current_axes_getter) else figure

    handles, labels = axes.get_legend_handles_labels()
    # A dict keyed by label keeps one entry per label (the last handle seen
    # for that label) in order of first appearance.
    by_label = dict(zip(labels, handles))
    figure.legend(by_label.values(), by_label.keys(), loc="upper right")


def plot_iteration_metric(all_chain_metrics: list[str], metric: str, output_path: str, initial_tree_type: list):
    """Make combined metric iteration plot.

    Draws one line per chain on a single axis, colored and labelled by the
    chain's initial tree type, and saves the figure to ``output_path``.

    Args:
        all_chain_metrics: list[str]
            A list of filepaths to the metric json files.
        metric: str
            The metric to plot; used in the y-axis label.
        output_path: str
            The output path to save the plot to.
        initial_tree_type: list
            A list of the initial tree types, in the same order as the input.
            Each entry must be one of the codes "h", "s", "d", "r" or "m".

    Returns:
        None

    Raises:
        KeyError: If an entry of ``initial_tree_type`` is not a known code.
        IndexError: If ``initial_tree_type`` has fewer entries than
            ``all_chain_metrics``.
    """
    # Load the values of every chain; the first element returned by the
    # reader is not needed for plotting.
    distances_chains = [
        yg.serialize.read_metric_result(Path(each_chain_metric))[1]
        for each_chain_metric in all_chain_metrics
    ]

    # Color and legend label for each initial tree type code
    colors = {"h": "red", "s": "green", "d": "blue", "r": "orange", "m": "purple"}
    labels = {
        "h": "Huntress",
        "s": "Star",
        "d": "Deep",
        "r": "Random",
        "m": "MCMC5",
    }

    # Define opacity and line style
    alpha = 0.4
    line_style = "solid"

    # Create a figure and axis
    fig, ax = plt.subplots()
    try:
        # Plot each entry of distance chain as a line with a color unique to
        # the initial tree type onto one axis
        for i, distances in enumerate(distances_chains):
            tree_type = initial_tree_type[i]
            ax.plot(
                distances,
                color=colors[tree_type],
                label=labels[tree_type],
                alpha=alpha,
                linestyle=line_style,
            )

        # Set labels and title
        ax.set_ylabel(f"Distance/Similarity: {metric}")
        ax.set_xlabel("Iteration")

        # Many chains share a tree type, so collapse repeated labels into a
        # single legend entry; the legend is attached to this plot's axes.
        legend_without_duplicate_labels(ax)

        # save the plot
        fig.savefig(Path(output_path))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)


rule combined_metric_iteration_plot:
"""Make combined metric iteration plot.
Expand All @@ -292,58 +371,7 @@ rule combined_metric_iteration_plot:
combined_metric_iter="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/"
"T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}_iter.svg",
run:
# load the data
distances_chains = []
# get the initial tree type, same order as the input
initial_tree_type = make_combined_metric_iteration_in()[1]
# for each chain
for each_chain_metric in input.all_chain_metrics:
# load the distances
_, distances = yg.serialize.read_metric_result(Path(each_chain_metric))
# append to the list
distances_chains.append(distances)

# Create a figure and axis
fig, ax = plt.subplots()

# Define the list of colors to repeat
colors = {"h": "red", "s": "green", "d": "blue", "r": "orange", "m": "purple"}
labels = {
"h": "Huntress",
"s": "Star",
"d": "Deep",
"r": "Random",
"m": "MCMC5",
}

# Define opacity and line style
alpha = 0.6
line_style = "solid"

# Plot each entry of distance chain as a line with a color unique to the
# initial tree type onto one axis

# Plot each entry of distance chain as a line with a color unique to the
# initial tree type onto one axis
for i, distances in enumerate(distances_chains):
color = colors[initial_tree_type[i]]
ax.plot(
distances,
color=color,
label=f"{labels[initial_tree_type[i]]}",
alpha=alpha,
linestyle=line_style,
)

# Set labels and title
ax.set_ylabel(f"Distance/Similarity: {wildcards.metric}")
ax.set_xlabel("Iteration")

# Add a legend of fixed legend position and size
ax.legend(loc="upper right")

# save the histogram
fig.savefig(Path(output.combined_metric_iter))
plot_iteration_metric(input.all_chain_metrics, wildcards.metric, output.combined_metric_iter, make_combined_metric_iteration_in()[1])


def make_combined_log_prob_iteration_in():
Expand Down Expand Up @@ -444,7 +472,7 @@ rule combined_logProb_iteration_plot:
}

# Define opacity and line style
alpha = 0.6
alpha = 0.4
line_style = "solid"

# Plot each entry of distance chain as a line with a color unique to the
Expand Down Expand Up @@ -530,67 +558,15 @@ rule combined_metric_iteration_plot_noHuntress:
"""
input:
# calls analyze_metric rule
all_chain_metrics=make_combined_metric_iteration_in()[0],
all_chain_metrics=make_combined_metric_iteration_in_noHuntress()[0],
wildcard_constraints:
# metric wildcard cannot be log_prob
metric=r"(?!(log_prob))\w+",
output:
combined_metric_iter="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/"
"T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}_iter_noHuntress.svg",
run:
# load the data
distances_chains = []
# get the initial tree type, same order as the input
initial_tree_type = make_combined_metric_iteration_in()[1]
# for each chain
for each_chain_metric in input.all_chain_metrics:
# load the distances
_, distances = yg.serialize.read_metric_result(Path(each_chain_metric))
# append to the list
distances_chains.append(distances)

# Create a figure and axis
fig, ax = plt.subplots()

# Define the list of colors to repeat
colors = {"h": "red", "s": "green", "d": "blue", "r": "orange", "m": "purple"}
labels = {
"h": "Huntress",
"s": "Star",
"d": "Deep",
"r": "Random",
"m": "MCMC5",
}

# Define opacity and line style
alpha = 0.6
line_style = "solid"

# Plot each entry of distance chain as a line with a color unique to the
# initial tree type onto one axis

# Plot each entry of distance chain as a line with a color unique to the
# initial tree type onto one axis
for i, distances in enumerate(distances_chains):
color = colors[initial_tree_type[i]]
ax.plot(
distances,
color=color,
label=f"{labels[initial_tree_type[i]]}",
alpha=alpha,
linestyle=line_style,
)

# Set labels and title
ax.set_ylabel(f"Distance/Similarity: {wildcards.metric}")
ax.set_xlabel("Iteration")

# Add a legend of fixed legend position and size
ax.legend(loc="upper right")

# save the histogram
fig.savefig(Path(output.combined_metric_iter))

plot_iteration_metric(input.all_chain_metrics,wildcards.metric,output.combined_metric_iter, make_combined_metric_iteration_in_noHuntress()[1])


def make_combined_log_prob_iteration_in_noHuntress():
Expand Down Expand Up @@ -681,7 +657,7 @@ rule combined_logProb_iteration_plot_noHuntress:
}

# Define opacity and line style
alpha = 0.6
alpha = 0.4
line_style = "solid"

# Plot each entry of distance chain as a line with a color unique to the
Expand Down

0 comments on commit 9f6ed37

Please sign in to comment.