From 4c6817624f68db1fcba6d67978c98607f552357d Mon Sep 17 00:00:00 2001 From: Sagar Vinodababu Date: Tue, 13 Feb 2024 16:34:04 -0700 Subject: [PATCH 1/2] more reliable file downloads --- CHANGELOG.md | 6 ++ README.md | 18 ++--- chess_transformers/configs/models/CT-E-20.py | 3 - chess_transformers/configs/models/CT-ED-45.py | 3 - .../configs/models/CT-EFT-20.py | 4 -- chess_transformers/play/human_play.ipynb | 23 ++++--- chess_transformers/play/utils.py | 69 ++++++++++++++++--- setup.py | 2 +- 8 files changed, 91 insertions(+), 37 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 509b7a8..d6b684f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## v0.2.1 + +### Changed + +* All model checkpoints, datasets, and logs have now been moved to Microsoft Azure Storage for more reliable access. + ## v0.2.0 ### Added diff --git a/README.md b/README.md index 7a521e3..ecfa354 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@

Chess Transformers

Teaching transformers to play chess

-

Version License

+

Version License


*Chess Transformers* is a library for training transformer models to play chess by learning from human games. @@ -76,8 +76,8 @@ Detailed evaluation results for each model are provided below. ### *CT-E-20* -[**Configuration File**](chess_transformers/configs/models/CT-E-20.py) | [**Checkpoint**](https://drive.google.com/file/d/18Er4LbdujG-qiPPoqORvMQVcsiFerqY4/view?usp=drive_link) | -[**TensorBoard Logs**](https://drive.google.com/drive/folders/1WwWJS4804uKrONPcCtyQSGlOSUlo05-0?usp=drive_link) +[**Configuration File**](chess_transformers/configs/models/CT-E-20.py) | [**Checkpoint**](https://chesstransformers.blob.core.windows.net/checkpoints/CT-E-20/averaged_CT-E-20.pt) | +[**TensorBoard Logs**](https://chesstransformers.blob.core.windows.net/logs/CT-E-20.zip) This is the encoder from the original transformer model in [*Vaswani et al. (2017)*](https://arxiv.org/abs/1706.03762) trained on the [*LE1222*](#le1222) dataset. A classification head at the **`turn`** token predicts the best half-move to be made (in UCI notation). @@ -118,8 +118,8 @@ These evaluation games can be viewed [here](chess_transformers/eval/games/CT-E-2 ### *CT-EFT-20* -[**Configuration File**](chess_transformers/configs/models/CT-EFT-20.py) | [**Checkpoint**](https://drive.google.com/file/d/1OHtg336ujlOjp5Kp0KjE1fAPF74aZpZD/view?usp=drive_link) | -[**TensorBoard Logs**](https://drive.google.com/drive/folders/1gD-msDgMlRqjB7Y0DIGWZxKgsxjqIwQp?usp=drive_link) +[**Configuration File**](chess_transformers/configs/models/CT-EFT-20.py) | [**Checkpoint**](https://chesstransformers.blob.core.windows.net/checkpoints/CT-EFT-20/averaged_CT-EFT-20.pt) | +[**TensorBoard Logs**](https://chesstransformers.blob.core.windows.net/logs/CT-EFT-20.zip) This is the encoder from the original transformer model in [*Vaswani et al. (2017)*](https://arxiv.org/abs/1706.03762) trained on the [*LE1222*](#le1222) dataset. Two classification heads operate upon the encoder outputs at all chessboard squares to predict the best candidates for the source (*From*) and destination (*To*) squares that correspond to the best half-move to be made. @@ -160,8 +160,8 @@ These evaluation games can be viewed [here](chess_transformers/eval/games/CT-EFT ### *CT-ED-45* -[**Configuration File**](chess_transformers/configs/models/CT-ED-45.py) | [**Checkpoint**](https://drive.google.com/file/d/1zasRpPmZQVtAqumet9XMy1FBpmxxiM4L/view?usp=drive_link) | -[**TensorBoard Logs**](https://drive.google.com/drive/folders/1LGsKMsjFRjQBS56UJZTjegFMl7amIbSw?usp=drive_link) +[**Configuration File**](chess_transformers/configs/models/CT-ED-45.py) | [**Checkpoint**](https://chesstransformers.blob.core.windows.net/checkpoints/CT-ED-45/averaged_CT-ED-45.pt) | +[**TensorBoard Logs**](https://chesstransformers.blob.core.windows.net/logs/CT-ED-45.zip) This is the original transformer model (encoder *and* decoder) in [*Vaswani et al. (2017)*](https://arxiv.org/abs/1706.03762) trained on the [*LE1222*](#le1222) dataset. A classification head after the last decoder layer predicts a sequence of half-moves, starting with the best half-move to be made next, followed by the likely course of the game an arbitrary number of half-moves into the future. @@ -219,7 +219,7 @@ On this data, we apply the following filters to keep only those games that: These 274,794 games consist of a total **13,287,522 half-moves** made by the winners of the games, which alone constitute the dataset. For each such half-move, the chessboard, turn (white or black), and castling rights of both players before the move are calculated, as well as the sequence of half-moves beginning with this half-move up to 10 half-moves into the future. Draw potential is not calculated. -[**Download here.**](https://drive.google.com/drive/folders/17VrTNUbGXqCnK0d5oU3YJBsl6jybfVRM?usp=drive_link) The data is zipped and will need to be extracted. +[**Download here.**](https://chesstransformers.blob.core.windows.net/data/LE1222.zip) The data is zipped and will need to be extracted. It consists of the following files: @@ -244,7 +244,7 @@ On this data, we apply the following filters to keep only those games that: These 2,751,394 games consist of a total **127,684,720 half-moves** made by the winners of the games, which alone constitute the dataset. For each such half-move, the chessboard, turn (white or black), and castling rights of both players before the move are calculated, as well as the sequence of half-moves beginning with this half-move up to 10 half-moves into the future. Draw potential is not calculated. -[**Download here.**](https://drive.google.com/drive/folders/17VrTNUbGXqCnK0d5oU3YJBsl6jybfVRM?usp=drive_link) The data is zipped and will need to be extracted. +[**Download here.**](https://chesstransformers.blob.core.windows.net/data/LE1222x.zip) The data is zipped and will need to be extracted. It consists of the following files: diff --git a/chess_transformers/configs/models/CT-E-20.py b/chess_transformers/configs/models/CT-E-20.py index 845013f..905fa92 100644 --- a/chess_transformers/configs/models/CT-E-20.py +++ b/chess_transformers/configs/models/CT-E-20.py @@ -99,9 +99,6 @@ FINAL_CHECKPOINT = ( "averaged_" + NAME + ".pt" ) # final checkpoint to be used for eval/inference -FINAL_CHECKPOINT_GDID = ( - "18Er4LbdujG-qiPPoqORvMQVcsiFerqY4" # Google Drive ID for download -) ################################ ########## Evaluation ########## diff --git a/chess_transformers/configs/models/CT-ED-45.py b/chess_transformers/configs/models/CT-ED-45.py index 1251724..20c3618 100644 --- a/chess_transformers/configs/models/CT-ED-45.py +++ b/chess_transformers/configs/models/CT-ED-45.py @@ -99,9 +99,6 @@ FINAL_CHECKPOINT = ( "averaged_" + NAME + ".pt" ) # final checkpoint to be used for eval/inference -FINAL_CHECKPOINT_GDID = ( - "1zasRpPmZQVtAqumet9XMy1FBpmxxiM4L" # Google Drive ID for download -) ################################ ########## Evaluation ########## diff --git a/chess_transformers/configs/models/CT-EFT-20.py b/chess_transformers/configs/models/CT-EFT-20.py index 8b5d4b9..1d16052 100644 --- a/chess_transformers/configs/models/CT-EFT-20.py +++ b/chess_transformers/configs/models/CT-EFT-20.py @@ -99,10 +99,6 @@ FINAL_CHECKPOINT = ( "averaged_" + NAME + ".pt" ) # final checkpoint to be used for eval/inference -FINAL_CHECKPOINT_GDID = ( - "1OHtg336ujlOjp5Kp0KjE1fAPF74aZpZD" # Google Drive ID for download -) - ################################ ########## Evaluation ########## diff --git a/chess_transformers/play/human_play.ipynb b/chess_transformers/play/human_play.ipynb index a1d7ae2..681f0e9 100644 --- a/chess_transformers/play/human_play.ipynb +++ b/chess_transformers/play/human_play.ipynb @@ -28,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -53,17 +53,24 @@ "output_type": "stream", "text": [ "\n", - "Cannot find model checkpoint on disk; will download.\n" + "Cannot find model checkpoint on disk; will download.\n", + "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Downloading...\n", - "From: https://drive.google.com/uc?id=1OHtg336ujlOjp5Kp0KjE1fAPF74aZpZD\n", - "To: /home/sgr/projects/chess-transformers/chess_transformers/checkpoints/CT-EFT-20/averaged_CT-EFT-20.pt\n", - "100%|██████████| 75.9M/75.9M [00:01<00:00, 73.0MB/s]\n" + "averaged_CT-EFT-20.pt: 100%|██████████| 72.4M/72.4M [00:18<00:00, 4.00MB/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Model loaded!\n", + "\n" ] } ], @@ -80,7 +87,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 4, "metadata": {}, "outputs": [ { diff --git a/chess_transformers/play/utils.py b/chess_transformers/play/utils.py index 27b15aa..00e968c 100644 --- a/chess_transformers/play/utils.py +++ b/chess_transformers/play/utils.py @@ -1,7 +1,7 @@ import os import sys import chess -import gdown +import urllib import pathlib import markdown import textwrap @@ -9,6 +9,7 @@ import chess.engine import torch.utils.data import torch.nn.functional as F +from tqdm import tqdm from datetime import date from tabulate import tabulate from bs4 import BeautifulSoup @@ -39,6 +40,36 @@ } +def download_file(url, output_path): + """ + Download the file at the given URL into a specified output path. + + Adapted from https://github.com/tqdm/tqdm#hooks-and-callbacks + + Args: + + url (str): The URL for the file to download. + + output_path (str): The path of the file to download into. + """ + + class TQDMUpTo(tqdm): + def update_to(self, b=1, bsize=1, tsize=None): + if tsize is not None: + self.total = tsize + return self.update(b * bsize - self.n) + + with TQDMUpTo( + unit="B", + unit_scale=True, + unit_divisor=1024, + miniters=1, + desc=url.split("/")[-1], + ) as t: + urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to) + t.total = t.n + + def load_model(CONFIG): """ Load model for inference. @@ -62,10 +93,25 @@ def load_model(CONFIG): checkpoint_folder.mkdir(parents=True, exist_ok=True) checkpoint_path = checkpoint_folder / CONFIG.FINAL_CHECKPOINT if not checkpoint_path.exists(): - print("\nCannot find model checkpoint on disk; will download.") - gdown.download( - id=CONFIG.FINAL_CHECKPOINT_GDID, output=str(checkpoint_path), quiet=False - ) + print("\nCannot find model checkpoint on disk; will download.\n") + download_file( + "ht" + + "tp" + + "s:" + + "//" + + "chess" + + "transformers" + + "." + + "blob" + + "." + + "core" + + "." + + "windows" + + "." + + "net" + + "/checkpoints/{}/{}".format(CONFIG.NAME, CONFIG.FINAL_CHECKPOINT), + str(checkpoint_path), + ) # scramble address against simple bots # Load checkpoint checkpoint = torch.load(str(checkpoint_path)) @@ -80,6 +126,8 @@ def load_model(CONFIG): ) model.eval() # eval mode disables dropout + print("\nModel loaded!\n") + return model @@ -100,9 +148,11 @@ def show_engine_options(engine): option.default, option.min, option.max, - "\n".join(textwrap.wrap(", ".join(option.var))) - if len(option.var) > 0 - else None, + ( + "\n".join(textwrap.wrap(", ".join(option.var))) + if len(option.var) > 0 + else None + ), ] ) display( @@ -123,7 +173,8 @@ def load_engine(path, show_options=True): path (str): The path to the engine file (an executable). - show_options (bool, optional): Print configurable UCI options for engine. + show_options (bool, optional): Print configurable UCI options + for engine. Returns: diff --git a/setup.py b/setup.py index c4b1821..22de9e7 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="chess-transformers", - version="0.2.0", + version="0.2.1", author="Sagar Vinodababu", author_email="sgrvinod@gmail.com", description="Chess Transformers", From ae0756015c6e7d1e90ad27493a6fb1af945f136c Mon Sep 17 00:00:00 2001 From: Sagar Vinodababu Date: Tue, 13 Feb 2024 16:36:12 -0700 Subject: [PATCH 2/2] removed gdown dependency --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 22de9e7..fec5026 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,6 @@ "torch==2.1.0", "tqdm==4.64.1", "scipy>=1.10.0", - "gdown==4.7.1" ], classifiers=[ "Development Status :: 3 - Alpha",