forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Simple Jupyter notebook example showing how to parse and visualize the log files
- Loading branch information
Showing
2 changed files
with
125 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Simple visualizer for log files written by the training loop" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"%matplotlib inline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def parse_logfile(logfile):\n", | ||
"    \"\"\"Parse a training log file into per-stream (steps, values) series.\n", | ||
"\n", | ||
"    Every non-blank line must hold exactly two whitespace-separated fields of\n", | ||
"    the form `<prefix>:<step> <stream>:<value>`. The text before the colon in\n", | ||
"    the first field is ignored; `<step>` is read as an int and `<value>` as a\n", | ||
"    float. Blank lines are skipped.\n", | ||
"\n", | ||
"    Args:\n", | ||
"        logfile: path of the log file to read.\n", | ||
"\n", | ||
"    Returns:\n", | ||
"        dict mapping each stream name to a tuple `(steps, values)` of two\n", | ||
"        parallel lists, in the order the entries appear in the file.\n", | ||
"\n", | ||
"    Raises:\n", | ||
"        AssertionError: a non-blank line does not have exactly two fields.\n", | ||
"        IndexError: a field contains no ':' separator.\n", | ||
"        ValueError: the step or value cannot be converted to a number.\n", | ||
"    \"\"\"\n", | ||
"    streams = {}  # stream:str -> (steps[], values[])\n", | ||
"    with open(logfile, \"r\") as f:\n", | ||
"        for lineno, line in enumerate(f, start=1):\n", | ||
"            parts = line.split()\n", | ||
"            if not parts:\n", | ||
"                # tolerate blank lines, e.g. a trailing newline at end of file\n", | ||
"                continue\n", | ||
"            assert len(parts) == 2, (\n", | ||
"                f\"{logfile}:{lineno}: expected 2 fields, got {len(parts)}: {line!r}\"\n", | ||
"            )\n", | ||
"            step = int(parts[0].split(\":\")[1])\n", | ||
"            value_fields = parts[1].split(\":\")\n", | ||
"            stream = value_fields[0]\n", | ||
"            val = float(value_fields[1])\n", | ||
"            # first sighting of a stream creates its (steps, values) pair\n", | ||
"            xs, ys = streams.setdefault(stream, ([], []))\n", | ||
"            xs.append(step)\n", | ||
"            ys.append(val)\n", | ||
"    return streams\n", | ||
"\n", | ||
"# parse_logfile(\"../log124M/main.log\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Model size to visualize; selects the log directory and the baselines below.\n", | ||
"sz = \"124M\"\n", | ||
"\n", | ||
"# Reference values drawn as dashed lines, labelled as the OpenAI GPT-2 checkpoint.\n", | ||
"# None means no baseline is recorded for that size. .get() is used so a size\n", | ||
"# without an entry simply plots without a baseline instead of raising KeyError.\n", | ||
"loss_baseline = {\n", | ||
"    \"124M\": 3.424958,\n", | ||
"    \"350M\": None,\n", | ||
"    \"774M\": None,\n", | ||
"    \"1558M\": None,\n", | ||
"}.get(sz)\n", | ||
"hella_baseline = {\n", | ||
"    \"124M\": 0.2955,\n", | ||
"    \"350M\": None,\n", | ||
"    \"774M\": None,\n", | ||
"    \"1558M\": None,\n", | ||
"}.get(sz)\n", | ||
"\n", | ||
"# assumes each model run is stored in this way\n", | ||
"logfile = f\"../log{sz}/main.log\"\n", | ||
"streams = parse_logfile(logfile)\n", | ||
"\n", | ||
"plt.figure(figsize=(16, 6))\n", | ||
"\n", | ||
"# Panel 1: losses: both train and val\n", | ||
"plt.subplot(121)\n", | ||
"xs, ys = streams[\"trl\"] # training loss\n", | ||
"plt.plot(xs, ys, label=f'llm.c ({sz}) train loss')\n", | ||
"print(\"Min Train Loss:\", min(ys))\n", | ||
"xs, ys = streams[\"tel\"] # validation loss\n", | ||
"plt.plot(xs, ys, label=f'llm.c ({sz}) val loss')\n", | ||
"print(\"Min Validation Loss:\", min(ys))\n", | ||
"# horizontal line at GPT-2 baseline\n", | ||
"if loss_baseline is not None:\n", | ||
"    plt.axhline(y=loss_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint val loss\")\n", | ||
"plt.xlabel(\"steps\")\n", | ||
"plt.ylabel(\"loss\")\n", | ||
"plt.yscale('log')\n", | ||
"plt.legend()\n", | ||
"plt.title(\"Loss\")\n", | ||
"\n", | ||
"# Panel 2: HellaSwag eval\n", | ||
"plt.subplot(122)\n", | ||
"xs, ys = streams[\"eval\"] # HellaSwag eval\n", | ||
"plt.plot(xs, ys, label=f\"llm.c ({sz})\")\n", | ||
"# horizontal line at GPT-2 baseline\n", | ||
"# compare against None (not truthiness) so the check matches the loss panel\n", | ||
"# and a baseline of 0.0 would still be drawn\n", | ||
"if hella_baseline is not None:\n", | ||
"    plt.axhline(y=hella_baseline, color='r', linestyle='--', label=f\"OpenAI GPT-2 ({sz}) checkpoint\")\n", | ||
"plt.xlabel(\"steps\")\n", | ||
"plt.ylabel(\"accuracy\")\n", | ||
"plt.legend()\n", | ||
"plt.title(\"HellaSwag eval\")\n", | ||
"print(\"Max Hellaswag eval:\", max(ys))\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "pytorch3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |