Merge pull request #63 from lean-dojo/dev
Update Checkpoints Format
yangky11 authored Jul 16, 2024
2 parents 00c3fb9 + e643d94 commit 87c13fd
Showing 22 changed files with 972 additions and 813 deletions.
63 changes: 39 additions & 24 deletions README.md
@@ -234,15 +234,18 @@ Check out [Lean Copilot](https://github.com/lean-dojo/LeanCopilot) if you want t
1. Download and install [Miniconda Python 3](https://docs.conda.io/en/latest/miniconda.html) (Anaconda should also work).
2. Create the conda environment and install Python dependencies:
```bash
conda create --yes --name ReProver python=3.10 ipython numpy
conda create --yes --name ReProver python=3.11 ipython
conda activate ReProver
pip install torch --index-url https://download.pytorch.org/whl/cu121 # Depending on your CUDA version; see https://pytorch.org/.
pip install tqdm loguru deepspeed "pytorch-lightning[extra]" transformers tensorboard openai rank_bm25 lean-dojo
pip install torch # Depending on your CUDA version; see https://pytorch.org/.
pip install tqdm loguru deepspeed "pytorch-lightning[extra]" transformers wandb openai rank_bm25 lean-dojo vllm
pip install git+https://github.com/pytorch/torchtune
```
3. Prepend the repo's root to the `PYTHONPATH` environment variable.
4. Make sure `wget` and `tar` are available. Then, run `python scripts/download_data.py` to download [LeanDojo Benchmark 4](https://zenodo.org/doi/10.5281/zenodo.8040109). The data will be saved to `./data`.
5. Satisfy the requirements of [LeanDojo](https://github.com/lean-dojo/LeanDojo#requirements).
6. Use [LeanDojo](https://github.com/lean-dojo/LeanDojo) to trace all repos in the datasets: `python scripts/trace_repos.py`. This step may take some time. Please refer to [LeanDojo's documentation](https://leandojo.readthedocs.io/en/latest/) if you encounter any issues.
7. Run `wandb login` to log in to Weights & Biases. A quick sanity check of the environment is sketched below.
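A quick way to verify the setup is to import the key dependencies; a minimal sketch (the package list simply mirrors the installation step above):
```python
# Sanity-check the environment (run inside the activated ReProver env).
import importlib

import torch

print(f"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")

# Each of these should import without errors after the pip installs above.
for pkg in ["pytorch_lightning", "transformers", "vllm", "wandb", "lean_dojo"]:
    importlib.import_module(pkg)
    print(f"OK: {pkg}")
```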



## Premise Selection
@@ -256,28 +259,29 @@ The config files for our experiments are in [./retrieval/confs](./retrieval/conf

Run `python retrieval/main.py fit --help` to see how to use the training script. For example:
```bash
python retrieval/main.py fit --config retrieval/confs/cli_lean4_random.yaml # Train the retriever on the `random` split of LeanDojo Benchmark 4.
python retrieval/main.py fit --config retrieval/confs/cli_lean4_novel_premises.yaml # Train the retriever on the `novel_premises` split of LeanDojo Benchmark 4.
mkdir logs # Create the directory for training logs.
python retrieval/main.py fit --config retrieval/confs/cli_lean4_random.yaml --trainer.logger.name train_retriever_random --trainer.logger.save_dir logs/train_retriever_random # Train the retriever on the `random` split of LeanDojo Benchmark 4.
python retrieval/main.py fit --config retrieval/confs/cli_lean4_novel_premises.yaml --trainer.logger.name train_retriever_novel_premises --trainer.logger.save_dir logs/train_retriever_novel_premises # Train the retriever on the `novel_premises` split of LeanDojo Benchmark 4.
```
The training script saves hyperparameters, model checkpoints, and other information to `./lightning_logs/EXP_ID/`, where `EXP_ID` is an arbitrary experiment ID that will be printed by the training script.
Hyperparameters and model checkpoints are saved in `./logs/train_retriever_*`, and you can monitor the training process on Weights & Biases.


### Retrieving Premises for All Proof States

After the models are trained, run the following commands to retrieve premises for all proof states in the dataset.
```bash
python retrieval/main.py predict --config retrieval/confs/cli_lean4_random.yaml --ckpt_path PATH_TO_RETRIEVER_CHECKPOINT
python retrieval/main.py predict --config retrieval/confs/cli_lean4_novel_premises.yaml --ckpt_path PATH_TO_RETRIEVER_CHECKPOINT
python retrieval/main.py predict --config retrieval/confs/cli_lean4_random.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_random --trainer.logger.save_dir logs/predict_retriever_random
python retrieval/main.py predict --config retrieval/confs/cli_lean4_novel_premises.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_novel_premises --trainer.logger.save_dir logs/predict_retriever_novel_premises
```
Retrieved premises are saved to `./lightning_logs/EXP_ID'/predictions.pickle`.
Here, `PATH_TO_RETRIEVER_CHECKPOINT` is the model checkpoint produced in the previous step. Retrieved premises are saved to `./logs/predict_retriever_*/predictions.pickle`.
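If you want to inspect the predictions programmatically, the pickle can be loaded directly; a minimal sketch (the exact schema of each entry is an assumption here; see `retrieval/main.py` for the authoritative format):
```python
import pickle

# Path produced by the `predict` run on the `random` split above.
with open("logs/predict_retriever_random/predictions.pickle", "rb") as f:
    predictions = pickle.load(f)

print(f"Loaded predictions for {len(predictions)} proof states")
print(predictions[0])  # fields depend on the retriever's output format
```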


### Evaluating the Retrieved Premises

After predictions are saved, evaluate them using metrics such as R@1, R@10, and MRR.
```bash
python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/random --preds-file PATH_TO_PREDICTIONS_PICKLE
python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises --preds-file PATH_TO_PREDICTIONS_PICKLE
python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/random --preds-file logs/predict_retriever_random/predictions.pickle
python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises --preds-file logs/predict_retriever_novel_premises/predictions.pickle
```
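For reference, these metrics reduce to simple functions of the ranked premise lists; a minimal sketch of recall@k and MRR, assuming each example pairs a ranked list of retrieved premises with the set of ground-truth premises (`retrieval/evaluate.py` is authoritative):
```python
from typing import List, Set

def recall_at_k(ranked: List[str], relevant: Set[str], k: int) -> float:
    """Fraction of ground-truth premises that appear in the top-k retrieved."""
    return len(set(ranked[:k]) & relevant) / len(relevant) if relevant else 0.0

def mrr(ranked: List[str], relevant: Set[str]) -> float:
    """Reciprocal rank of the first relevant premise (0 if none is retrieved)."""
    for i, premise in enumerate(ranked):
        if premise in relevant:
            return 1.0 / (i + 1)
    return 0.0

# Toy example: the only ground-truth premise is ranked third.
ranked = ["Nat.add_comm", "Nat.add_assoc", "Nat.zero_add"]
relevant = {"Nat.zero_add"}
print(recall_at_k(ranked, relevant, 1))   # 0.0  -> contributes to R@1
print(recall_at_k(ranked, relevant, 10))  # 1.0  -> contributes to R@10
print(mrr(ranked, relevant))              # 0.333...
```
Averaging these per-example scores over the test set gives the reported R@1, R@10, and MRR.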


@@ -286,40 +290,51 @@ python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/novel_premise

### Training the Tactic Generator

Similar to premise selection, you can run `python generator/main.py --help` and `python generator/main.py fit --help` to check the command line options.
Similar to premise selection, you can run `python generation/main.py --help` and `python generation/main.py fit --help` to check the command line options.

To train tactic generators without retrieval:
```bash
python generator/main.py fit --config generator/confs/cli_lean4_random.yaml # LeanDojo Benchmark 4, `random` split
python generator/main.py fit --config generator/confs/cli_lean4_novel_premises.yaml # LeanDojo Benchmark 4, `novel_premises` split
python generation/main.py fit --config generation/confs/cli_lean4_random.yaml --trainer.logger.name train_generator_random --trainer.logger.save_dir logs/train_generator_random # LeanDojo Benchmark 4, `random` split
python generation/main.py fit --config generation/confs/cli_lean4_novel_premises.yaml --trainer.logger.name train_generator_novel_premises --trainer.logger.save_dir logs/train_generator_novel_premises # LeanDojo Benchmark 4, `novel_premises` split
```
Hyperparameters and model checkpoints are saved in `./logs/train_generator_*`, and you can monitor the training process on Weights & Biases.

To train models augmented by retrieval, we need to provide a retriever checkpoint and its predictions on all proof states in the dataset:
```bash
python generator/main.py fit --config generator/confs/cli_lean4_random.yaml --model.ret_ckpt_path PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path PATH_TO_PREDICTIONS_PICKLE
python generator/main.py fit --config generator/confs/cli_lean4_novel_premises.yaml --model.ret_ckpt_path PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path PATH_TO_PREDICTIONS_PICKLE
python generation/main.py fit --config generation/confs/cli_lean4_random.yaml --model.ret_ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path logs/predict_retriever_random/predictions.pickle --trainer.logger.name train_reprover_random --trainer.logger.save_dir logs/train_reprover_random
python generation/main.py fit --config generation/confs/cli_lean4_novel_premises.yaml --model.ret_ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path logs/predict_retriever_novel_premises/predictions.pickle --trainer.logger.name train_reprover_novel_premises --trainer.logger.save_dir logs/train_reprover_novel_premises
```
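Conceptually, the retrieved premises are folded into the generator's input by prepending them to the proof state under a length budget; a simplified sketch, loosely following `format_augmented_state` in `common.py` (serialization details are simplified here):
```python
import random
from typing import List, Optional

def augment_state(
    state: str, premises: List[str], max_len: Optional[int] = None, p_drop: float = 0.0
) -> str:
    """Prepend retrieved premises to a proof state within a byte budget.

    Premises are randomly dropped with probability `p_drop` (a training-time
    augmentation) and skipped once the budget is exhausted.
    """
    if max_len is None:
        max_len = 10**18  # effectively no limit
    budget = max_len - len(state.encode("utf-8"))
    augmented, used = "", 0
    for p in premises:  # premises assumed sorted from most to least relevant
        if random.random() < p_drop:
            continue
        chunk = p + "\n\n"
        n = len(chunk.encode("utf-8"))
        if used + n <= budget:
            augmented = chunk + augmented  # most relevant ends up closest to the state
            used += n
    return augmented + state

print(augment_state("⊢ 0 + n = n", ["Nat.zero_add : ∀ (n : ℕ), 0 + n = n"]))
```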


### Theorem Proving Evaluation on LeanDojo Benchmark

After the tactic generator is trained, we combine it with best-first search to prove theorems by interacting with Lean.
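A minimal sketch of the best-first search loop; `generate_tactics` and `run_tactic` below are hypothetical stand-ins for the tactic generator and the LeanDojo interaction layer (see `prover/` for the real implementation):
```python
import heapq
from typing import List, Optional, Tuple

def generate_tactics(state: str) -> List[Tuple[str, float]]:
    """Hypothetical stand-in: ask the tactic generator for (tactic, log_prob) pairs."""
    return []

def run_tactic(state: str, tactic: str) -> Optional[str]:
    """Hypothetical stand-in: run the tactic in Lean via LeanDojo; return the
    next proof state, "QED" on a finished proof, or None if the tactic fails."""
    return None

def best_first_search(init_state: str, max_expansions: int = 100) -> bool:
    # Frontier ordered by cumulative negative log-probability, so the most
    # promising partial proof is always expanded first.
    frontier = [(0.0, init_state)]
    visited = {init_state}
    for _ in range(max_expansions):
        if not frontier:
            return False  # search space exhausted
        cost, state = heapq.heappop(frontier)
        for tactic, log_prob in generate_tactics(state):
            next_state = run_tactic(state, tactic)
            if next_state == "QED":
                return True  # proof found
            if next_state is not None and next_state not in visited:
                visited.add(next_state)
                heapq.heappush(frontier, (cost - log_prob, next_state))
    return False  # expansion budget exhausted
```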

For models without retrieval, run:
The evaluation script takes Hugging Face model checkpoints (either local or remote) as input. For remote models, you can simply use their names, e.g., [kaiyuy/leandojo-lean4-tacgen-byt5-small](https://huggingface.co/kaiyuy/leandojo-lean4-tacgen-byt5-small). For locally trained models, you first need to convert them from PyTorch Lightning checkpoints to Hugging Face checkpoints:
```bash
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1
python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1
python scripts/convert_checkpoint.py generator --src $PATH_TO_GENERATOR_CHECKPOINT --dst ./leandojo-lean4-tacgen-byt5-small
python scripts/convert_checkpoint.py retriever --src $PATH_TO_RETRIEVER_CHECKPOINT --dst ./leandojo-lean4-retriever-byt5-small
```
Here, `PATH_TO_GENERATOR_CHECKPOINT` and `PATH_TO_RETRIEVER_CHECKPOINT` are PyTorch Lightning checkpoints produced by the training script.
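The conversion essentially unwraps the Lightning module and exports the underlying Hugging Face model; a simplified sketch (the wrapper class and its attributes are assumptions for illustration; `scripts/convert_checkpoint.py` is authoritative):
```python
import pytorch_lightning as pl
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Hypothetical minimal stand-in for the training module, which wraps a
# Hugging Face model; the real classes live in this repo's training code.
class TacGen(pl.LightningModule):
    def __init__(self, model_name: str = "google/byt5-small") -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.generator = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# In practice: TacGen.load_from_checkpoint($PATH_TO_GENERATOR_CHECKPOINT).
module = TacGen()
# Export in Hugging Face format so prover/evaluate.py can load it by path.
module.generator.save_pretrained("./leandojo-lean4-tacgen-byt5-small")
module.tokenizer.save_pretrained("./leandojo-lean4-tacgen-byt5-small")
```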


To evaluate the model without retrieval, run (using the `random` data split as an example):
```bash
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --gen_ckpt_path ./leandojo-lean4-tacgen-byt5-small --split test --num-workers 5 --num-gpus 1
```
You may tweak `--num-workers` and `--num-gpus` to fit your hardware.
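As a quick smoke test independent of the prover, a tactic generator can also be queried directly through `transformers`; a minimal sketch using the remote checkpoint named above (beam settings are illustrative):
```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

name = "kaiyuy/leandojo-lean4-tacgen-byt5-small"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name)

# A proof state, serialized the way the model saw during training.
state = "n : ℕ\n⊢ gcd n n = n"
inputs = tokenizer(state, return_tensors="pt")
outputs = model.generate(
    **inputs, max_length=64, num_beams=4, num_return_sequences=4
)
for out in outputs:  # candidate tactics, most probable first
    print(tokenizer.decode(out, skip_special_tokens=True))
```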


For models with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):
For the model with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):
```bash
python retrieval/index.py --ckpt_path PATH_TO_RETRIEVER_CHECKPOINT --corpus-path data/leandojo_benchmark_4/corpus.jsonl --output-path PATH_TO_INDEXED_CORPUS
# Do it separately for two data splits.
python retrieval/index.py --ckpt_path ./leandojo-lean4-retriever-byt5-small --corpus-path data/leandojo_benchmark_4/corpus.jsonl --output-path $PATH_TO_INDEXED_CORPUS
```
It saves the indexed corpus as a pickle file to `PATH_TO_INDEXED_CORPUS`.
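Conceptually, retrieval against the indexed corpus is maximum inner-product search over the pre-computed premise embeddings; a toy sketch with made-up tensors (in the real pipeline both sides come from the trained retriever, and the index also stores the corpus itself):
```python
import torch

# Toy stand-ins: five premises embedded in a 4-dimensional space, plus one
# encoded proof state. Real embeddings come from the retriever's encoder.
premise_embs = torch.randn(5, 4)
state_emb = torch.randn(4)

# Score every premise against the state and keep the top-k. Pre-computing
# `premise_embs` once is what makes evaluation-time retrieval cheap.
scores = premise_embs @ state_emb
topk = torch.topk(scores, k=3)
print(topk.indices.tolist(), topk.values.tolist())
```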

Then, run:
```bash
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_REPROVER_CHECKPOINT --indexed-corpus-path PATH_TO_INDEXED_CORPUS --split test --num-cpus 8 --with-gpus
# Do it separately for two data splits.
python scripts/convert_checkpoint.py generator --src $PATH_TO_REPROVER_CHECKPOINT --dst ./leandojo-lean4-retriever-tacgen-byt5-small
python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --gen_ckpt_path ./leandojo-lean4-retriever-tacgen-byt5-small --ret_ckpt_path ./leandojo-lean4-retriever-byt5-small --indexed-corpus-path $PATH_TO_INDEXED_CORPUS --split test --num-workers 5 --num-gpus 1
```


68 changes: 9 additions & 59 deletions common.py
@@ -13,7 +13,7 @@
from pytorch_lightning.utilities.deepspeed import (
convert_zero_checkpoint_to_fp32_state_dict,
)
from transformers import get_cosine_schedule_with_warmup
from transformers import get_constant_schedule_with_warmup
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam
from typing import Optional, List, Dict, Any, Tuple, Generator
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy
@@ -354,48 +354,14 @@ def get_all_pos_premises(annot_tac, corpus: Corpus) -> List[Premise]:
return list(all_pos_premises)


_SPACES_REGEX = re.compile(r"\s+", re.DOTALL)


def normalize_spaces(s: str) -> str:
"""Repalce any consecutive block of whitespace characters in ``s`` with a single whitespace."""
return _SPACES_REGEX.sub(" ", s).strip()


def format_tactic(annot_tac: str, provenances, normalize: bool) -> str:
"""Use full names for the all <a>...</a>."""
if normalize:
annot_tac = normalize_spaces(annot_tac)
if len(provenances) == 0:
return annot_tac

tac = ""
marks = list(re.finditer(r"<a>(?P<ident>.+?)</a>", annot_tac))

for i, (m, prov) in enumerate(zip_strict(marks, provenances)):
last_end = marks[i - 1].end() if i > 0 else 0
tac += annot_tac[last_end : m.start()] + "<a>" + prov["full_name"] + "</a>"

tac += annot_tac[marks[-1].end() :]
return tac


def format_state(s: str) -> str:
m = re.match(r"\d+ goals", s)
if m is not None:
return s[m.end() :].strip()
else:
return s


def format_augmented_state(
s: str, premises: List[Premise], max_len: int, p_drop: float
s: str, premises: List[Premise], max_len: Optional[int] = None, p_drop: float = 0.0
) -> str:
"""Format a state with retrieved premises and drop some of them with probability ``p_drop``."""
s = format_state(s)

aug_s = ""
length = 0
if max_len is None:
max_len = 9999999999999999999999
max_premises_len = max_len - len(bytes(s.encode("utf-8")))

for p in premises:
@@ -429,22 +395,7 @@ def get_optimizers(
logger.info("Optimizing with AdamW")
optimizer = torch.optim.AdamW(parameters, lr=lr)

if trainer.max_steps != -1:
max_steps = trainer.max_steps
else:
assert trainer.max_epochs is not None
max_steps = (
trainer.max_epochs
* len(trainer.datamodule.train_dataloader())
// trainer.accumulate_grad_batches
)

scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=max_steps,
)

scheduler = get_constant_schedule_with_warmup(optimizer, warmup_steps)
return {
"optimizer": optimizer,
"lr_scheduler": {
@@ -462,14 +413,13 @@ def _is_deepspeed_checkpoint(path: str):

def load_checkpoint(model_cls, ckpt_path: str, device, freeze: bool):
"""Handle DeepSpeed checkpoints in model loading."""
if not _is_deepspeed_checkpoint(ckpt_path):
model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device)
else:
if _is_deepspeed_checkpoint(ckpt_path):
with tempfile.TemporaryDirectory() as dirname:
path = os.path.join(dirname, "lightning.cpkt")
convert_zero_checkpoint_to_fp32_state_dict(ckpt_path, path)
model = model_cls.load_from_checkpoint(path, strict=False)
model = model.to(device)
model = model_cls.load_from_checkpoint(path, strict=False).to(device)
else: # PyTorch Lightning checkpoints
model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device)
if freeze:
model.freeze()
return model
@@ -9,7 +9,11 @@ trainer:
stage: 2
offload_optimizer: false
cpu_checkpointing: false
gradient_clip_val: 1.0
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
name: null
save_dir: null
max_steps: 500000
check_val_every_n_epoch: 1
num_sanity_val_steps: 0
@@ -46,12 +50,10 @@ model:
data:
data_path: data/leandojo_benchmark_4/novel_premises/
corpus_path: data/leandojo_benchmark_4/corpus.jsonl
keep_marks: true
preds_path: null
batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices
eval_batch_size: 64
max_inp_seq_len: 2300
max_oup_seq_len: 512
p_drop: 0.5
normalize_tactics: true
num_workers: 2
@@ -9,7 +9,11 @@ trainer:
stage: 2
offload_optimizer: false
cpu_checkpointing: false
gradient_clip_val: 1.0
logger:
class_path: pytorch_lightning.loggers.WandbLogger
init_args:
name: null
save_dir: null
max_steps: 500000
check_val_every_n_epoch: 1
num_sanity_val_steps: 0
@@ -46,12 +50,10 @@ model:
data:
data_path: data/leandojo_benchmark_4/random/
corpus_path: data/leandojo_benchmark_4/corpus.jsonl
keep_marks: true
preds_path: null
batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices
eval_batch_size: 64
max_inp_seq_len: 2300
max_oup_seq_len: 512
p_drop: 0.5
normalize_tactics: true
num_workers: 2