Merge pull request #8 from sgrvinod/0.2.1
0.2.1
sgrvinod authored Feb 13, 2024
2 parents 06d6c17 + ae07560 commit 1dcf110
Showing 8 changed files with 91 additions and 38 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,11 @@
# Change Log

## v0.2.1

### Changed

* All model checkpoints, datasets, and logs have now been moved to Microsoft Azure Storage for more reliable access.

## v0.2.0

### Added
18 changes: 9 additions & 9 deletions README.md
@@ -4,7 +4,7 @@

<h1 align="center"><i>Chess Transformers</i></h1>
<p align="center"><i>Teaching transformers to play chess</i></p>
<p align="center"> <a href="https://github.com/sgrvinod/chess-transformers/releases/tag/v0.2.0"><img alt="Version" src="https://img.shields.io/github/v/tag/sgrvinod/chess-transformers?label=version"></a> <a href="https://github.com/sgrvinod/chess-transformers/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/sgrvinod/chess-transformers?label=license"></a></p>
<p align="center"> <a href="https://github.com/sgrvinod/chess-transformers/releases/tag/v0.2.1"><img alt="Version" src="https://img.shields.io/github/v/tag/sgrvinod/chess-transformers?label=version"></a> <a href="https://github.com/sgrvinod/chess-transformers/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/sgrvinod/chess-transformers?label=license"></a></p>
<br>

*Chess Transformers* is a library for training transformer models to play chess by learning from human games.
@@ -76,8 +76,8 @@ Detailed evaluation results for each model are provided below.

### *CT-E-20*

[**Configuration File**](chess_transformers/configs/models/CT-E-20.py) | [**Checkpoint**](https://drive.google.com/file/d/18Er4LbdujG-qiPPoqORvMQVcsiFerqY4/view?usp=drive_link) |
[**TensorBoard Logs**](https://drive.google.com/drive/folders/1WwWJS4804uKrONPcCtyQSGlOSUlo05-0?usp=drive_link)
[**Configuration File**](chess_transformers/configs/models/CT-E-20.py) | [**Checkpoint**](https://chesstransformers.blob.core.windows.net/checkpoints/CT-E-20/averaged_CT-E-20.pt) |
[**TensorBoard Logs**](https://chesstransformers.blob.core.windows.net/logs/CT-E-20.zip)

This is the encoder from the original transformer model in [*Vaswani et al. (2017)*](https://arxiv.org/abs/1706.03762) trained on the [*LE1222*](#le1222) dataset. A classification head at the **`turn`** token predicts the best half-move to be made (in UCI notation).

@@ -118,8 +118,8 @@ These evaluation games can be viewed [here](chess_transformers/eval/games/CT-E-2

### *CT-EFT-20*

[**Configuration File**](chess_transformers/configs/models/CT-EFT-20.py) | [**Checkpoint**](https://drive.google.com/file/d/1OHtg336ujlOjp5Kp0KjE1fAPF74aZpZD/view?usp=drive_link) |
[**TensorBoard Logs**](https://drive.google.com/drive/folders/1gD-msDgMlRqjB7Y0DIGWZxKgsxjqIwQp?usp=drive_link)
[**Configuration File**](chess_transformers/configs/models/CT-EFT-20.py) | [**Checkpoint**](https://chesstransformers.blob.core.windows.net/checkpoints/CT-EFT-20/averaged_CT-EFT-20.pt) |
[**TensorBoard Logs**](https://chesstransformers.blob.core.windows.net/logs/CT-EFT-20.zip)

This is the encoder from the original transformer model in [*Vaswani et al. (2017)*](https://arxiv.org/abs/1706.03762) trained on the [*LE1222*](#le1222) dataset. Two classification heads operate upon the encoder outputs at all chessboard squares to predict the best candidates for the source (*From*) and destination (*To*) squares that correspond to the best half-move to be made.
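
The two heads each produce a distribution over squares, so something has to combine them into one move. A plausible way, sketched below, is to score each legal move by the sum of the log-probabilities of its source and destination squares and keep the best one. This scoring rule is an assumption for illustration and is not confirmed by this commit; the function name `pick_move`, the dictionary inputs, and the toy probabilities are all hypothetical.

```python
import math


def pick_move(from_log_probs, to_log_probs, legal_moves):
    """Return the legal UCI move with the highest combined From/To score.

    from_log_probs, to_log_probs: dicts mapping square names such as "e2"
        to log-probabilities (hypothetical format, for illustration only).
    legal_moves: iterable of UCI strings such as "e2e4" or "e7e8q".
    """

    def score(move):
        # A UCI move is source square + destination square (+ optional promotion).
        source, destination = move[:2], move[2:4]
        return from_log_probs[source] + to_log_probs[destination]

    return max(legal_moves, key=score)


# Toy distributions over three candidate squares each (made-up numbers).
from_lp = {"e2": math.log(0.6), "g1": math.log(0.3), "d2": math.log(0.1)}
to_lp = {"e4": math.log(0.5), "f3": math.log(0.4), "d4": math.log(0.1)}

best = pick_move(from_lp, to_lp, ["e2e4", "g1f3", "d2d4"])
print(best)
```

Restricting the choice to legal moves means an illegal From/To pairing can never be selected, even when each square is individually likely.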

@@ -160,8 +160,8 @@ These evaluation games can be viewed [here](chess_transformers/eval/games/CT-EFT

### *CT-ED-45*

[**Configuration File**](chess_transformers/configs/models/CT-ED-45.py) | [**Checkpoint**](https://drive.google.com/file/d/1zasRpPmZQVtAqumet9XMy1FBpmxxiM4L/view?usp=drive_link) |
[**TensorBoard Logs**](https://drive.google.com/drive/folders/1LGsKMsjFRjQBS56UJZTjegFMl7amIbSw?usp=drive_link)
[**Configuration File**](chess_transformers/configs/models/CT-ED-45.py) | [**Checkpoint**](https://chesstransformers.blob.core.windows.net/checkpoints/CT-ED-45/averaged_CT-ED-45.pt) |
[**TensorBoard Logs**](https://chesstransformers.blob.core.windows.net/logs/CT-ED-45.zip)

This is the original transformer model (encoder *and* decoder) in [*Vaswani et al. (2017)*](https://arxiv.org/abs/1706.03762) trained on the [*LE1222*](#le1222) dataset. A classification head after the last decoder layer predicts a sequence of half-moves, starting with the best half-move to be made next, followed by the likely course of the game an arbitrary number of half-moves into the future.

@@ -219,7 +219,7 @@ On this data, we apply the following filters to keep only those games that:

These 274,794 games consist of a total **13,287,522 half-moves** made by the <ins>winners</ins> of the games, which alone constitute the dataset. For each such half-move, the chessboard, turn (white or black), and castling rights of both players before the move are calculated, as well as the sequence of half-moves beginning with this half-move up to 10 half-moves into the future. Draw potential is not calculated.
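
The move-sequence part of each sample can be pictured as a sliding window over the game that starts at each of the winner's half-moves. The sketch below covers only that part; board state and castling rights would need a chess library and are left out. The function name `winner_samples` and the `horizon` parameter are hypothetical, and whether the dataset counts the current half-move inside the limit of 10 is an assumption here.

```python
def winner_samples(moves, winner, horizon=10):
    """Yield (ply_index, turn, move_window) for each half-move made by the winner.

    moves: list of UCI half-moves for the whole game, White moving first.
    winner: "w" or "b".
    horizon: maximum window length, counting the current half-move
        (an assumption; the README wording leaves this open).
    """
    first = 0 if winner == "w" else 1  # White moves on even plies, Black on odd
    for i in range(first, len(moves), 2):
        yield i, winner, moves[i : i + horizon]


game = ["e2e4", "e7e5", "g1f3", "b8c6", "f1c4", "g8f6"]
samples = list(winner_samples(game, "b", horizon=3))
for sample in samples:
    print(sample)
```

Windows near the end of the game are shorter than `horizon`, which is consistent with the "up to 10" wording above.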

[**Download here.**](https://drive.google.com/drive/folders/17VrTNUbGXqCnK0d5oU3YJBsl6jybfVRM?usp=drive_link) The data is zipped and will need to be extracted.
[**Download here.**](https://chesstransformers.blob.core.windows.net/data/LE1222.zip) The data is zipped and will need to be extracted.

It consists of the following files:

@@ -244,7 +244,7 @@ On this data, we apply the following filters to keep only those games that:

These 2,751,394 games consist of a total **127,684,720 half-moves** made by the <ins>winners</ins> of the games, which alone constitute the dataset. For each such half-move, the chessboard, turn (white or black), and castling rights of both players before the move are calculated, as well as the sequence of half-moves beginning with this half-move up to 10 half-moves into the future. Draw potential is not calculated.

[**Download here.**](https://drive.google.com/drive/folders/17VrTNUbGXqCnK0d5oU3YJBsl6jybfVRM?usp=drive_link) The data is zipped and will need to be extracted.
[**Download here.**](https://chesstransformers.blob.core.windows.net/data/LE1222x.zip) The data is zipped and will need to be extracted.

It consists of the following files:

3 changes: 0 additions & 3 deletions chess_transformers/configs/models/CT-E-20.py
@@ -99,9 +99,6 @@
FINAL_CHECKPOINT = (
"averaged_" + NAME + ".pt"
) # final checkpoint to be used for eval/inference
FINAL_CHECKPOINT_GDID = (
"18Er4LbdujG-qiPPoqORvMQVcsiFerqY4" # Google Drive ID for download
)

################################
########## Evaluation ##########
3 changes: 0 additions & 3 deletions chess_transformers/configs/models/CT-ED-45.py
@@ -99,9 +99,6 @@
FINAL_CHECKPOINT = (
"averaged_" + NAME + ".pt"
) # final checkpoint to be used for eval/inference
FINAL_CHECKPOINT_GDID = (
"1zasRpPmZQVtAqumet9XMy1FBpmxxiM4L" # Google Drive ID for download
)

################################
########## Evaluation ##########
4 changes: 0 additions & 4 deletions chess_transformers/configs/models/CT-EFT-20.py
@@ -99,10 +99,6 @@
FINAL_CHECKPOINT = (
"averaged_" + NAME + ".pt"
) # final checkpoint to be used for eval/inference
FINAL_CHECKPOINT_GDID = (
"1OHtg336ujlOjp5Kp0KjE1fAPF74aZpZD" # Google Drive ID for download
)


################################
########## Evaluation ##########
23 changes: 15 additions & 8 deletions chess_transformers/play/human_play.ipynb
@@ -28,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -45,25 +45,32 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Cannot find model checkpoint on disk; will download.\n"
"Cannot find model checkpoint on disk; will download.\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading...\n",
"From: https://drive.google.com/uc?id=1OHtg336ujlOjp5Kp0KjE1fAPF74aZpZD\n",
"To: /home/sgr/projects/chess-transformers/chess_transformers/checkpoints/CT-EFT-20/averaged_CT-EFT-20.pt\n",
"100%|██████████| 75.9M/75.9M [00:01<00:00, 73.0MB/s]\n"
"averaged_CT-EFT-20.pt: 100%|██████████| 72.4M/72.4M [00:18<00:00, 4.00MB/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Model loaded!\n",
"\n"
]
}
],
@@ -80,7 +87,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 4,
"metadata": {},
"outputs": [
{
69 changes: 60 additions & 9 deletions chess_transformers/play/utils.py
@@ -1,14 +1,15 @@
import os
import sys
import chess
import gdown
import urllib
import pathlib
import markdown
import textwrap
import chess.pgn
import chess.engine
import torch.utils.data
import torch.nn.functional as F
from tqdm import tqdm
from datetime import date
from tabulate import tabulate
from bs4 import BeautifulSoup
@@ -39,6 +40,36 @@
}


def download_file(url, output_path):
    """
    Download the file at the given URL into a specified output path.
    Adapted from https://github.com/tqdm/tqdm#hooks-and-callbacks
    Args:
        url (str): The URL for the file to download.
        output_path (str): The path of the file to download into.
    """

    class TQDMUpTo(tqdm):
        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            return self.update(b * bsize - self.n)

    with TQDMUpTo(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        miniters=1,
        desc=url.split("/")[-1],
    ) as t:
        urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
        t.total = t.n
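
The `update_to` hook receives a cumulative block count rather than a byte delta, so it subtracts the bytes already counted before calling `update`. The sketch below replays that arithmetic without any network access. `ProgressCounter` is a hypothetical stand-in for the tqdm subclass, `simulate_download` is a made-up helper, and the call pattern (one call before reading, then one after each block) reflects how `urllib.request.urlretrieve` is understood to invoke its `reporthook`.

```python
class ProgressCounter:
    """Hypothetical stand-in for the tqdm subclass: tracks bytes counted so far."""

    def __init__(self):
        self.n = 0          # bytes counted so far
        self.total = None   # total size reported by the server, if any

    def update(self, delta):
        self.n += delta

    def update_to(self, b=1, bsize=1, tsize=None):
        # b is the number of blocks transferred so far, not the number of new blocks.
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)


def simulate_download(total_bytes, block_size):
    """Replay the reporthook calls for a file of total_bytes read in blocks."""
    counter = ProgressCounter()
    num_blocks = -(-total_bytes // block_size)  # ceiling division
    for block_num in range(num_blocks + 1):     # block 0 is the initial call
        counter.update_to(block_num, block_size, total_bytes)
    return counter


counter = simulate_download(total_bytes=2500, block_size=1024)
print(counter.n, counter.total)

# The last block is usually partial, so the count can overshoot the real size.
# Mirroring download_file, align the total with the count once the transfer ends.
counter.total = counter.n
```

This overshoot is one likely reason for the final `t.total = t.n` in `download_file`: it keeps the finished bar from reporting more bytes than its total.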


def load_model(CONFIG):
"""
Load model for inference.
@@ -62,10 +93,25 @@ def load_model(CONFIG):
    checkpoint_folder.mkdir(parents=True, exist_ok=True)
    checkpoint_path = checkpoint_folder / CONFIG.FINAL_CHECKPOINT
    if not checkpoint_path.exists():
        print("\nCannot find model checkpoint on disk; will download.")
        gdown.download(
            id=CONFIG.FINAL_CHECKPOINT_GDID, output=str(checkpoint_path), quiet=False
        )
        print("\nCannot find model checkpoint on disk; will download.\n")
        download_file(
            "ht"
            + "tp"
            + "s:"
            + "//"
            + "chess"
            + "transformers"
            + "."
            + "blob"
            + "."
            + "core"
            + "."
            + "windows"
            + "."
            + "net"
            + "/checkpoints/{}/{}".format(CONFIG.NAME, CONFIG.FINAL_CHECKPOINT),
            str(checkpoint_path),
        )  # scramble address against simple bots

# Load checkpoint
checkpoint = torch.load(str(checkpoint_path))
@@ -80,6 +126,8 @@ def load_model(CONFIG):
)
model.eval() # eval mode disables dropout

print("\nModel loaded!\n")

return model
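
With the Google Drive IDs removed, the checkpoint address is derived entirely from the model name and the checkpoint file name. The sketch below spells out that pattern using the host that appears in the README links. The names `BASE_URL` and `checkpoint_url` are hypothetical and not part of the library, and the address is written plainly here rather than split into fragments as in `load_model`.

```python
# Host taken from the README links in this commit; the constant name is hypothetical.
BASE_URL = "https://chesstransformers.blob.core.windows.net"


def checkpoint_url(name, final_checkpoint=None):
    """Build the download URL for a model's final checkpoint.

    name: model name, e.g. "CT-EFT-20" (CONFIG.NAME in the configs).
    final_checkpoint: file name; defaults to the "averaged_<NAME>.pt"
        convention used by FINAL_CHECKPOINT in the model configs.
    """
    if final_checkpoint is None:
        final_checkpoint = "averaged_" + name + ".pt"
    return "{}/checkpoints/{}/{}".format(BASE_URL, name, final_checkpoint)


print(checkpoint_url("CT-EFT-20"))
```

Because the URL follows from `NAME` and `FINAL_CHECKPOINT`, a new model config needs no extra download field, which fits the removal of `FINAL_CHECKPOINT_GDID` from the three configs.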


@@ -100,9 +148,11 @@ def show_engine_options(engine):
option.default,
option.min,
option.max,
"\n".join(textwrap.wrap(", ".join(option.var)))
if len(option.var) > 0
else None,
(
"\n".join(textwrap.wrap(", ".join(option.var)))
if len(option.var) > 0
else None
),
]
)
display(
@@ -123,7 +173,8 @@ def load_engine(path, show_options=True):
path (str): The path to the engine file (an executable).
show_options (bool, optional): Print configurable UCI options for engine.
show_options (bool, optional): Print configurable UCI options
for engine.
Returns:
3 changes: 1 addition & 2 deletions setup.py
@@ -6,7 +6,7 @@

setup(
name="chess-transformers",
version="0.2.0",
version="0.2.1",
author="Sagar Vinodababu",
author_email="[email protected]",
description="Chess Transformers",
@@ -30,7 +30,6 @@
"torch==2.1.0",
"tqdm==4.64.1",
"scipy>=1.10.0",
"gdown==4.7.1"
],
classifiers=[
"Development Status :: 3 - Alpha",
