This repo is intended as an easily-understandable, minimal reproduction of the LCZero training setup.
The goal is to eliminate complexity and expose the core net code to allow people to try different approaches and architectures, without needing to have a deep understanding of RL or the Leela codebase to begin.
Briefly, it's a reimplementation of DeepMind's AlphaGo Zero, applied to chess instead of Go. Leela is considered one of the top two chess engines in the world, and usually battles Stockfish, a more classical engine, for the #1 spot in computer chess competitions.
Take a look at tf_net.py and tf_layers.py. Current Leela nets are all ResNets with Squeeze-Excitation blocks, a reimplementation of which is provided here. This isn't a bad architecture for the problem by any means - chess is played on a fixed 8x8 grid, and that lends itself well to convolutional networks.
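For readers less familiar with them, here is a minimal Keras sketch of a squeeze-excitation residual block in the channels-first layout. This is a generic illustration of the technique, not the repo's exact `tf_layers.py` implementation - the filter counts, activation placement, and SE bottleneck ratio are all assumptions.

```python
import tensorflow as tf
from tensorflow.keras import layers

def se_residual_block(x, filters, se_ratio=8):
    # Two 3x3 convs with batchnorm, channels-first as in the Leela nets
    y = layers.Conv2D(filters, 3, padding="same",
                      data_format="channels_first", use_bias=False)(x)
    y = layers.BatchNormalization(axis=1)(y)
    y = layers.ReLU()(y)
    y = layers.Conv2D(filters, 3, padding="same",
                      data_format="channels_first", use_bias=False)(y)
    y = layers.BatchNormalization(axis=1)(y)

    # Squeeze: global-average-pool each channel down to a single number
    s = layers.GlobalAveragePooling2D(data_format="channels_first")(y)
    # Excite: a small bottleneck MLP produces per-channel scale factors
    s = layers.Dense(filters // se_ratio, activation="relu")(s)
    s = layers.Dense(filters, activation="sigmoid")(s)
    s = layers.Reshape((filters, 1, 1))(s)
    y = y * s  # rescale each channel

    # Residual connection
    return layers.ReLU()(x + y)
```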
A lot of the fancy tricks used in network training are usually some form of regularization or smuggling inductive bias in to reduce overfitting. That's because most real-world supervised learning problems are constrained by data. Leela is not! Training data is generated by self-play, and vast amounts of it exist. The goal of Leela training is to minimize training loss while keeping inference time as low as possible. The more accurate Leela's predictions, and the more quickly they can be generated, the better her search and therefore her performance.
This requires quite a different approach from the standard architectures from supervised learning competitions. Regularization methods like dropout, which are universal in most supervised computer vision nets, are not helpful in this regime. Also, the field of computer vision has changed substantially since the AlphaGo Zero paper, with vision transformers now massively ascendant. There's a huge amount of literature on boosting the performance of transformer models - can any of that be applied to make a significantly stronger Leela model?
The existing architecture can be seen in tf_net.py. Briefly, it consists of an input reshape and convolution, followed by a stack of residual convolution blocks that make up the bulk of the network, and finally three output heads. These heads represent Leela's estimates for the policy (which move to make), the value (how good the current position is) and the moves left (how much longer the game will last after this position). These three outputs are used to guide Leela's search.
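To make that layout concrete, here's a hypothetical skeleton of the overall structure. The layer sizes, the 1858-way policy encoding and the WDL-style value head are illustrative assumptions, not a copy of tf_net.py.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Input planes -> conv trunk -> three output heads guiding search
inputs = tf.keras.Input((112, 8, 8))  # input planes, channels-first (assumed shape)
x = layers.Conv2D(128, 3, padding="same", data_format="channels_first")(inputs)
# ... the stack of residual SE blocks would go here ...
flat = layers.Flatten()(x)
policy = layers.Dense(1858, name="policy")(flat)    # logits over the move encoding
value = layers.Dense(3, name="value")(flat)         # win/draw/loss logits
moves_left = layers.Dense(1, activation="relu", name="moves_left")(flat)
model = tf.keras.Model(inputs, [policy, value, moves_left])
```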
The code here should be simple enough to modify easily if you're familiar with TF/Keras. There's also a PyTorch
version, which may be slightly (but not too far) behind the TF one.
Note that the output heads like ConvolutionalPolicyHead assume that you're passing them a channels-first tensor, like you'd get in a convolutional network, with shape (batch, channels, width, height). If you're trying a totally different architecture, you may prefer to try the other heads in tf_layers.py, like DensePolicyHead, instead. These assume the input is a flat tensor of shape (batch, hidden_dim).
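If your trunk still produces a channels-first tensor but you want a flat head, a Flatten layer bridges the two. This sketch assumes a hypothetical 128-channel trunk and a DensePolicyHead-style stack; the 512 hidden units and 1858 policy outputs are illustrative, not taken from the repo.

```python
import tensorflow as tf
from tensorflow.keras import layers

trunk_out = tf.keras.Input((128, 8, 8))     # (batch, channels, width, height)
flat = layers.Flatten()(trunk_out)          # (batch, 128 * 8 * 8) = (batch, 8192)
hidden = layers.Dense(512, activation="relu")(flat)
policy_logits = layers.Dense(1858)(hidden)  # assumed size of the move encoding
model = tf.keras.Model(trunk_out, policy_logits)
```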
There's a standard dataset here. Each tarball contains one day of training data from the test60 run. All data has been rescored and recompressed with zstandard, and is ready for training. As the data is quite large (~250GB), I suggest downloading and extracting only a single file to start, and only downloading the whole set when your model is ready and you want to get accurate benchmark results.
As for loading the data, the input pipeline is in new_data_pipeline.py. You shouldn't need to modify it, but you can look around if you like! For testing and benchmarking, you can also run that script alone to get sample inputs from the pipeline without initializing any modules.
The pipeline can manage about 50k pos/s before the shuffle buffer becomes the bottleneck - this should be enough to saturate any single GPU even for a very small network. It might become a bottleneck if you want to train a very small net on a very fast GPU fleet, but that sounds like a weird thing to do.
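If you want a rough way to check throughput numbers like the 50k pos/s above, you can time any tf.data iterator. The dataset below is a synthetic stand-in (the shapes and batch size are assumptions); swap in the dataset object built by new_data_pipeline.py to measure the real thing.

```python
import time
import tensorflow as tf

batch_size = 1024
# Stand-in dataset: 200 identical batches of zero-filled input planes
dataset = tf.data.Dataset.from_tensors(
    tf.zeros((batch_size, 112, 8, 8), dtype=tf.float32)
).repeat(200)

start = time.perf_counter()
n_batches = 0
for _ in dataset:
    n_batches += 1
elapsed = time.perf_counter() - start
print(f"{n_batches * batch_size / elapsed:,.0f} positions/sec")
```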
Just run tf_train.py! You'll need a dataset: either the standard dataset linked above, or, if you're familiar with Leela, any directory of v6 chunks.
A sample invocation would look something like this:
```shell
python tf_train.py \
    --dataset_path /path/to/extracted/tar/files/ \
    --mixed_precision \
    --num_workers 6 \
    --batch_size 1024 \
    --tensorboard_dir tensorboard_logs/ \
    --save_dir 128x10 \
    --reduce_lr_every_n_epochs 30
```
Just run pt_train.py! It takes the same arguments as tf_train.py, but training currently seems slower, in terms of both epoch speed and the iterations needed to reach the same loss. I suspect there are still a couple of bugs, or layers that aren't quite equivalent, but I'm working on it!
If you examine the TF code, you might notice that there's no L2 regularization, but there is a very weird hand-rolled constraint on the norms of several weight tensors. The reason for this is straightforward: the primary purpose of L2 penalties in most neural nets is not regularization at all, especially for nets like Leela where training data is effectively infinite and overfitting is not a concern. Instead, the reason for using L2 penalties is to ensure that the weights in layers preceding batchnorms do not blow up. There's an excellent explanation of this phenomenon here.
Tuning weight decay can be tricky, however, and it's doubly tricky because standard L2 penalties fail to decay weights correctly when using the Adam optimizer. As a result, it's extremely easy to get an invisible weight blowup when training with Adam, which will silently decay your effective learning rate. To work around this problem, I add a maximum norm constraint to all weight tensors preceding batchnorms. The maximum norm is set to be the expected norm at initialization. This eliminates an annoying hyperparameter and performs better as well.
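As a sketch of the idea (not the repo's actual constraint code), a Keras weight constraint can cap a tensor's norm at its expected value under a He/variance-scaling initializer. The constraint class and helper below are hypothetical names written for illustration.

```python
import numpy as np
import tensorflow as tf

class MaxInitNorm(tf.keras.constraints.Constraint):
    """Cap the Frobenius norm of a weight tensor at a fixed maximum."""

    def __init__(self, max_norm):
        self.max_norm = max_norm

    def __call__(self, w):
        norm = tf.norm(w)
        # Rescale only when the norm exceeds the cap, leaving smaller
        # weights untouched (unlike plain weight decay)
        scale = tf.minimum(1.0, self.max_norm / (norm + 1e-7))
        return w * scale

def expected_he_norm(shape):
    # For a He-initialized kernel, each entry has variance 2 / fan_in,
    # so E[||W||^2] = n_entries * 2 / fan_in
    fan_in = np.prod(shape[:-1])  # kH * kW * C_in for a conv kernel
    n = np.prod(shape)
    return float(np.sqrt(n * 2.0 / fan_in))

kernel_shape = (3, 3, 128, 128)
conv = tf.keras.layers.Conv2D(
    128, 3, padding="same", data_format="channels_first", use_bias=False,
    kernel_constraint=MaxInitNorm(expected_he_norm(kernel_shape)))
```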
Here are some sample losses following training to convergence with the standard training set. You can use these to benchmark against when testing new architectures. Standard architectures are described as (num_filters)x(num_blocks)b.
| Architecture | Policy loss | Value loss | Moves left loss | Total loss |
|---|---|---|---|---|
| 128x10b | 1.8880 | 0.6773 | 0.3921 | 3.1677 |