From 9f6ed37d988c9c93742ea20876b83bedd38886ae Mon Sep 17 00:00:00 2001 From: gordonkoehn Date: Wed, 26 Jul 2023 14:57:27 +0200 Subject: [PATCH] fix metric plots --- workflows/mark03.smk | 192 +++++++++++++++++++------------------------ 1 file changed, 84 insertions(+), 108 deletions(-) diff --git a/workflows/mark03.smk b/workflows/mark03.smk index 9c97f562..fcdea162 100644 --- a/workflows/mark03.smk +++ b/workflows/mark03.smk @@ -276,6 +276,85 @@ def make_combined_metric_iteration_in(): return input, tree_type +def legend_without_duplicate_labels(figure): + """Add a legend to a figure without duplicate labels.""" + handles, labels = plt.gca().get_legend_handles_labels() + by_label = dict(zip(labels, handles)) + figure.legend(by_label.values(), by_label.keys(), loc='upper right') + + +def plot_iteration_metric(all_chain_metrics : list[str], metric : str, output_path : str, initial_tree_type : list) : + """Make combined metric iteration plot. + + Args: + all_chain_metrics: list[str] + A list of filepaths to the metric json files. + metric: str + The metric to plot. + output_path: str + The output path to save the plot to. + initial_tree_type: list + A list of the initial tree types, in the same order as the input. + + Returns: + None + """ + + # load the data + distances_chains = [] + # get the initial tree type, same order as the input + initial_tree_type = initial_tree_type + # for each chain + for each_chain_metric in all_chain_metrics: + # load the distances + _, distances = yg.serialize.read_metric_result(Path(each_chain_metric)) + # append to the list + distances_chains.append(distances) + + # Create a figure and axis + fig, ax = plt.subplots() + + # Define the list of colors to repeat + colors = {"h": "red", "s": "green", "d": "blue", "r": "orange", "m": "purple"} + labels = { + "h": "Huntress", + "s": "Star", + "d": "Deep", + "r": "Random", + "m": "MCMC5", + } + + # Define opacity and line style + alpha = 0.4 + line_style = "solid" + + # Plot each entry of distance chain as a line with a color unique to the + # initial tree type onto one axis + + # Plot each entry of distance chain as a line with a color unique to the + # initial tree type onto one axis + for i, distances in enumerate(distances_chains): + color = colors[initial_tree_type[i]] + ax.plot( + distances, + color=color, + label=f"{labels[initial_tree_type[i]]}", + alpha=alpha, + linestyle=line_style, + ) + + # Set labels and title + ax.set_ylabel(f"Distance/Similarity: {metric}") + ax.set_xlabel("Iteration") + + # Add a legend of fixed legend position and size + #ax.legend(loc="upper right") + legend_without_duplicate_labels(plt) + + # save the histogram + fig.savefig(Path(output_path)) + + rule combined_metric_iteration_plot: """Make combined metric iteration plot. @@ -292,58 +371,7 @@ rule combined_metric_iteration_plot: combined_metric_iter="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/" "T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}_iter.svg", run: - # load the data - distances_chains = [] - # get the initial tree type, same order as the input - initial_tree_type = make_combined_metric_iteration_in()[1] - # for each chain - for each_chain_metric in input.all_chain_metrics: - # load the distances - _, distances = yg.serialize.read_metric_result(Path(each_chain_metric)) - # append to the list - distances_chains.append(distances) - - # Create a figure and axis - fig, ax = plt.subplots() - - # Define the list of colors to repeat - colors = {"h": "red", "s": "green", "d": "blue", "r": "orange", "m": "purple"} - labels = { - "h": "Huntress", - "s": "Star", - "d": "Deep", - "r": "Random", - "m": "MCMC5", - } - - # Define opacity and line style - alpha = 0.6 - line_style = "solid" - - # Plot each entry of distance chain as a line with a color unique to the - # initial tree type onto one axis - - # Plot each entry of distance chain as a line with a color unique to the - # initial tree type onto one axis - for i, distances in enumerate(distances_chains): - color = colors[initial_tree_type[i]] - ax.plot( - distances, - color=color, - label=f"{labels[initial_tree_type[i]]}", - alpha=alpha, - linestyle=line_style, - ) - - # Set labels and title - ax.set_ylabel(f"Distance/Similarity: {wildcards.metric}") - ax.set_xlabel("Iteration") - - # Add a legend of fixed legend position and size - ax.legend(loc="upper right") - - # save the histogram - fig.savefig(Path(output.combined_metric_iter)) + plot_iteration_metric(input.all_chain_metrics, wildcards.metric, output.combined_metric_iter, make_combined_metric_iteration_in()[1]) def make_combined_log_prob_iteration_in(): @@ -444,7 +472,7 @@ rule combined_logProb_iteration_plot: } # Define opacity and line style - alpha = 0.6 + alpha = 0.4 line_style = "solid" # Plot each entry of distance chain as a line with a color unique to the @@ -530,7 +558,7 @@ rule combined_metric_iteration_plot_noHuntress: """ input: # calls analyze_metric rule - all_chain_metrics=make_combined_metric_iteration_in()[0], + all_chain_metrics=make_combined_metric_iteration_in_noHuntress()[0], wildcard_constraints: # metric wildcard cannot be log_prob metric=r"(?!(log_prob))\w+", @@ -538,59 +566,7 @@ rule combined_metric_iteration_plot_noHuntress: combined_metric_iter="{DATADIR}/{experiment}/plots/{mcmc_config_id}/{mutation_data_id}/" "T_{base_tree_type}_{n_nodes,\d+}_{base_tree_seed,\d+}/{metric}_iter_noHuntress.svg", run: - # load the data - distances_chains = [] - # get the initial tree type, same order as the input - initial_tree_type = make_combined_metric_iteration_in()[1] - # for each chain - for each_chain_metric in input.all_chain_metrics: - # load the distances - _, distances = yg.serialize.read_metric_result(Path(each_chain_metric)) - # append to the list - distances_chains.append(distances) - - # Create a figure and axis - fig, ax = plt.subplots() - - # Define the list of colors to repeat - colors = {"h": "red", "s": "green", "d": "blue", "r": "orange", "m": "purple"} - labels = { - "h": "Huntress", - "s": "Star", - "d": "Deep", - "r": "Random", - "m": "MCMC5", - } - - # Define opacity and line style - alpha = 0.6 - line_style = "solid" - - # Plot each entry of distance chain as a line with a color unique to the - # initial tree type onto one axis - - # Plot each entry of distance chain as a line with a color unique to the - # initial tree type onto one axis - for i, distances in enumerate(distances_chains): - color = colors[initial_tree_type[i]] - ax.plot( - distances, - color=color, - label=f"{labels[initial_tree_type[i]]}", - alpha=alpha, - linestyle=line_style, - ) - - # Set labels and title - ax.set_ylabel(f"Distance/Similarity: {wildcards.metric}") - ax.set_xlabel("Iteration") - - # Add a legend of fixed legend position and size - ax.legend(loc="upper right") - - # save the histogram - fig.savefig(Path(output.combined_metric_iter)) - + plot_iteration_metric(input.all_chain_metrics,wildcards.metric,output.combined_metric_iter, make_combined_metric_iteration_in_noHuntress()[1]) def make_combined_log_prob_iteration_in_noHuntress(): @@ -681,7 +657,7 @@ rule combined_logProb_iteration_plot_noHuntress: } # Define opacity and line style - alpha = 0.6 + alpha = 0.4 line_style = "solid" # Plot each entry of distance chain as a line with a color unique to the