From 121aaf6a9710201c224c545cbe2092a8a2fd9a79 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Sat, 25 Nov 2023 09:30:04 -0800 Subject: [PATCH] new wandb to plot script --- scripts/wandb_to_plot.py | 62 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 scripts/wandb_to_plot.py diff --git a/scripts/wandb_to_plot.py b/scripts/wandb_to_plot.py new file mode 100644 index 00000000..49fa12f9 --- /dev/null +++ b/scripts/wandb_to_plot.py @@ -0,0 +1,62 @@ +import argparse +import re +from necessary import necessary + +with necessary(["plotly", "wandb"]): + import plotly + import wandb + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument( + "-t", "--wandb-team", type=str, required=True, help="Name of the wandb team to use, e.g. 'ai2-llm'" + ) + ap.add_argument( + "-p", + "--wandb-project", + type=str, + required=True, + help="Name of the wandb project to use, e.g. 'olmo-small'", + ) + ap.add_argument( + "-n", "--wandb-name", type=str, required=True, help="Run name or regex to use, e.g. '3T-lower-tie-.*'" + ) + ap.add_argument( + "-x", + "--x-axis", + type=str, + default="throughput/total_tokens", + ) + ap.add_argument( + "-y", + "--y-axis", + nargs="+", + type=str, + default=["train/Perplexity"] + ) + return ap.parse_args() + + +def main(): + opts = parse_args() + + # make sure we're logged in + wandb.login() + + api = wandb.Api() + runs = api.runs(f"{opts.wandb_team}/{opts.wandb_project}") + + re_run_name = re.compile(opts.wandb_name) + + for wb_run in runs: + if not re_run_name.search(wb_run.name): + continue + + import ipdb + + ipdb.set_trace() + + +if __name__ == "__main__": + main()