This repo features a high-speed JAX implementation of the Proximal Policy Optimisation (PPO) algorithm. The algorithm is provided as a single-file implementation so that all the design choices are clear.
*Training reward curves.*
The algorithm is run by executing the Python script from the repository root. A custom config file can be specified as follows:
python3 ppx/systems/ppo.py --config-name=ppo_MinAtar.yaml
Since Hydra is used for managing configurations, override parameters can be passed as arguments to this command. The default parameters can be changed in the relevant config file.
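For example, an individual value can be overridden with Hydra's `key=value` syntax, e.g. `python3 ppx/systems/ppo.py --config-name=ppo_MinAtar.yaml seed=42` (the key name `seed` here is only illustrative; the available keys are those defined in the chosen config file).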
The notebooks/ directory contains simple .ipynb files that provide basic plotting functions.
We recommend managing dependencies using a virtual environment, which can be created with the following commands:
python3.9 -m venv venv
source venv/bin/activate
Install dependencies using the requirements.txt file:
pip install -r requirements.txt
The codebase can then be installed as an editable pip package with the following command:
pip install -e .
Note that JAX must be installed separately for the specific device used. For straightforward CPU usage, run:
pip install -U "jax[cpu]"
To use JAX on your accelerators (GPU or TPU), you can find more details in the JAX documentation.
PPO-EWMA: a batch-size-invariant algorithm that uses exponentially weighted moving averages to remove the dependence on the batch-size hyperparameter.
- The next steps are tests with the learning-rate adjustment and the advantage-normalisation adjustment.
- The images below show the variance between batch sizes. The right-hand image shows the current results using EWMA: the performance is slightly higher, but the variance is greater.
*Training curves: PPO (left) vs. PPO-EWMA (right).*
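As a rough sketch (not the exact code in this repo), the core of PPO-EWMA is an exponentially weighted moving average over the policy parameters, which can be expressed as a `tree_map` over the parameter pytree; the function name and `decay` value below are illustrative:

```python
# Minimal sketch of the EWMA parameter update used by PPO-EWMA
# (illustrative only; not this repo's implementation).
import jax


def update_ewma_params(ewma_params, params, decay=0.99):
    """Exponentially weighted moving average over a pytree of parameters.

    `decay` is an illustrative value; in the PPO-EWMA paper it is tied to
    the batch size so that the effective averaging horizon stays constant.
    """
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ewma_params, params
    )


# After each optimiser step:
#   params      -> the freshly updated policy parameters
#   ewma_params -> update_ewma_params(ewma_params, params)
# The averaged parameters act as the proximal policy in the PPO ratio,
# which is what removes the dependence on the batch size.
```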
- Add an env wrapper to use the Jumanji-style step method, which returns a `state` and a `TimeStep` (a rough sketch follows this list).
- Add a KL-divergence PPO algorithm.
- Add tests with different learning rates. Include the effect of learning rate annealing.
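As a rough sketch of the first item above (assumptions only, not the planned implementation), a Jumanji-style wrapper would make `step` return the environment state together with a `TimeStep`; the underlying environment API assumed below is Gymnax-like and would need adapting:

```python
# Illustrative sketch of a Jumanji-style wrapper; the wrapped environment's
# reset/step signatures are assumptions, not this repo's actual API.
from typing import Any, NamedTuple, Tuple


class TimeStep(NamedTuple):
    """Simplified stand-in for Jumanji's TimeStep (which also carries a step_type)."""
    observation: Any
    reward: float
    discount: float


class JumanjiStyleWrapper:
    def __init__(self, env):
        # `env` is assumed to expose reset(key) -> (obs, state) and
        # step(key, state, action) -> (obs, state, reward, done, info).
        self._env = env

    def reset(self, key) -> Tuple[Any, TimeStep]:
        obs, state = self._env.reset(key)
        return state, TimeStep(observation=obs, reward=0.0, discount=1.0)

    def step(self, key, state, action) -> Tuple[Any, TimeStep]:
        obs, state, reward, done, _ = self._env.step(key, state, action)
        return state, TimeStep(
            observation=obs, reward=reward, discount=1.0 - float(done)
        )
```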
The code is based on the format of Mava and is inspired by PureJaxRL and CleanRL.