diff --git a/README.md b/README.md index 7b9c2d4fe..e8ed1702f 100644 --- a/README.md +++ b/README.md @@ -334,33 +334,7 @@ done # screen -ls | grep -E "tr[0-3]" | cut -d. -f1 | xargs -I {} screen -X -S {} quit ``` -This example opens up 4 screen sessions and runs the four commands with different LRs. This writes the log files `stories$i.log` with all the losses, which you can plot as you wish in Python. Here's a quick example script to plot the losses in a Jupyter notebook, obviously can become more sophisticated later: - -```python -import matplotlib.pyplot as plt -%matplotlib inline - -def parse_log(logfile): - # look for lines like e.g. "s:100 tel:1.6952", step 100, val 1.6952 - val_steps, val_losses = [], [] - with open(logfile, "r") as f: - lines = f.readlines() - for line in lines: - if "tel" in line: - parts = line.split() - step = parts[0].split(":")[1] - loss = parts[1].split(":")[1] - val_steps.append(int(step)) - val_losses.append(float(loss)) - return val_steps, val_losses - -results = [parse_log(f"stories{i}.log") for i in range(0, 4)] -for i, (val_steps, val_losses) in enumerate(results): - plt.plot(val_steps, val_losses, label="run {}".format(i)) -plt.xlabel("steps") -plt.ylabel("loss") -plt.legend() -``` +This example opens up 4 screen sessions and runs the four commands with different LRs. This writes the log files `stories$i.log` with all the losses, which you can plot as you wish in Python. A quick example of how to parse and plot these logfiles is in [dev/vislog.ipynb](dev/vislog.ipynb). 
"""Simple visualizer for log files written by the training loop.

Parses logfiles whose lines look like "s:100 tel:1.6952" (step 100,
stream "tel", value 1.6952) into per-stream series, then plots the
train/val loss and the HellaSwag eval accuracy, with the OpenAI GPT-2
checkpoint baselines drawn as horizontal lines where known.
"""

def parse_logfile(logfile):
    """Parse a training logfile into {stream: (steps, values)}.

    Each line must have exactly two whitespace-separated fields of the
    form "s:<step> <stream>:<value>"; raises ValueError otherwise.

    Returns:
        dict mapping stream name (str) to a pair of parallel lists
        (steps: list[int], values: list[float]).
    """
    streams = {}  # stream:str -> (steps[], values[])
    with open(logfile, "r") as f:
        for line in f:
            parts = line.split()
            # validate explicitly: assert would be silently stripped under -O
            if len(parts) != 2:
                raise ValueError(f"malformed log line: {line!r}")
            step = int(parts[0].split(":")[1])
            stream = parts[1].split(":")[0]
            val = float(parts[1].split(":")[1])
            # setdefault replaces the `if not stream in streams` idiom
            xs, ys = streams.setdefault(stream, ([], []))
            xs.append(step)
            ys.append(val)
    return streams

# parse_logfile("../log124M/main.log")


def main(sz="124M"):
    """Plot loss curves and HellaSwag eval for one model-size run.

    Args:
        sz: model size tag selecting both the baselines and the log
            directory (assumed layout: ../log{sz}/main.log).
    """
    import matplotlib.pyplot as plt  # local import: only needed for plotting

    # GPT-2 checkpoint baselines; None where not yet measured
    loss_baseline = {
        "124M": 3.424958,
        "350M": None,
        "774M": None,
        "1558M": None,
    }[sz]
    hella_baseline = {
        "124M": 0.2955,
        "350M": None,
        "774M": None,
        "1558M": None,
    }[sz]

    # assumes each model run is stored in this way
    logfile = f"../log{sz}/main.log"
    streams = parse_logfile(logfile)

    plt.figure(figsize=(16, 6))

    # Panel 1: losses: both train and val
    plt.subplot(121)
    xs, ys = streams["trl"]  # training loss
    plt.plot(xs, ys, label=f'llm.c ({sz}) train loss')
    print("Min Train Loss:", min(ys))
    xs, ys = streams["tel"]  # validation loss
    plt.plot(xs, ys, label=f'llm.c ({sz}) val loss')
    # horizontal line at GPT-2 baseline
    if loss_baseline is not None:
        plt.axhline(y=loss_baseline, color='r', linestyle='--', label=f"OpenAI GPT-2 ({sz}) checkpoint val loss")
    plt.xlabel("steps")
    plt.ylabel("loss")
    plt.yscale('log')
    plt.legend()
    plt.title("Loss")
    print("Min Validation Loss:", min(ys))

    # Panel 2: HellaSwag eval
    plt.subplot(122)
    xs, ys = streams["eval"]  # HellaSwag eval
    plt.plot(xs, ys, label=f"llm.c ({sz})")
    # horizontal line at GPT-2 baseline
    # fix: explicit None check, consistent with loss_baseline above and
    # correct even for a legitimate 0.0 baseline (which is falsy)
    if hella_baseline is not None:
        plt.axhline(y=hella_baseline, color='r', linestyle='--', label=f"OpenAI GPT-2 ({sz}) checkpoint")
    plt.xlabel("steps")
    plt.ylabel("accuracy")
    plt.legend()
    plt.title("HellaSwag eval")
    print("Max Hellaswag eval:", max(ys))


if __name__ == "__main__":
    main()