Commit 822dcb1: update
Kaiyu Yang committed Apr 4, 2024
1 parent a8fe037 commit 822dcb1
Showing 6 changed files with 38 additions and 11 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -306,8 +306,8 @@ After the tactic generator is trained, we combine it with best-first search to p

For models without retrieval, run:
```bash
-python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-cpus 8 --with-gpus
-python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-cpus 8 --with-gpus
+python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 8 --num-gpus 1
+python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 8 --num-gpus 1
```

For models with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):
17 changes: 15 additions & 2 deletions generator/confs/cli_lean4_novel_premises.yaml
@@ -17,6 +17,19 @@ trainer:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
+   - class_path: pytorch_lightning.callbacks.ModelCheckpoint
+     init_args:
+       verbose: true
+       save_top_k: 1
+       save_last: true
+       monitor: Pass@1_val
+       mode: max
+   - class_path: pytorch_lightning.callbacks.EarlyStopping
+     init_args:
+       monitor: Pass@1_val
+       patience: 2
+       mode: max
+       verbose: true

model:
  model_name: google/byt5-small
@@ -26,9 +39,9 @@ model:
  length_penalty: 0.0
  ret_ckpt_path: null
  eval_num_retrieved: 100
- eval_num_workers: 8
+ eval_num_workers: 6 # Lower this number if you don't have 80GB GPU memory.
  eval_num_gpus: 1
- eval_num_theorems: 400
+ eval_num_theorems: 300 # Lowering this number will make validation faster (but noisier).

data:
  data_path: data/leandojo_benchmark_4/novel_premises/
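Both new callbacks monitor a metric named `Pass@1_val`, so the Lightning module has to log a value under exactly that name during validation; with `mode: max`, `save_top_k: 1`, and `patience: 2`, the run keeps the single best checkpoint (plus the last one) and stops after two validation rounds without improvement. Below is a minimal, hypothetical sketch of what logging such a metric could look like; the class and the way Pass@1 is computed here are illustrative assumptions, not ReProver's actual code.

```python
# Hypothetical sketch (not ReProver's implementation): log a `Pass@1_val`
# metric so the ModelCheckpoint and EarlyStopping callbacks above can monitor it.
import pytorch_lightning as pl
import torch


class GeneratorWithPassAt1(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 8)  # stand-in for the real generator
        self.val_proved = []  # 1.0 if a theorem was proved on the first attempt, else 0.0

    def validation_step(self, batch, batch_idx):
        # The real generator would run proof search here; this sketch just
        # records a per-theorem success flag carried in the batch.
        self.val_proved.extend(batch["proved"].float().tolist())

    def on_validation_epoch_end(self):
        pass_at_1 = torch.tensor(self.val_proved).mean()
        # The metric name must match the `monitor` key in the YAML config above.
        self.log("Pass@1_val", pass_at_1, prog_bar=True)
        self.val_proved.clear()
```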
17 changes: 15 additions & 2 deletions generator/confs/cli_lean4_random.yaml
@@ -17,6 +17,19 @@ trainer:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: step
+   - class_path: pytorch_lightning.callbacks.ModelCheckpoint
+     init_args:
+       verbose: true
+       save_top_k: 1
+       save_last: true
+       monitor: Pass@1_val
+       mode: max
+   - class_path: pytorch_lightning.callbacks.EarlyStopping
+     init_args:
+       monitor: Pass@1_val
+       patience: 2
+       mode: max
+       verbose: true

model:
  model_name: google/byt5-small
@@ -26,9 +39,9 @@ model:
  length_penalty: 0.0
  ret_ckpt_path: null
  eval_num_retrieved: 100
- eval_num_workers: 8
+ eval_num_workers: 6 # Lower this number if you don't have 80GB GPU memory.
  eval_num_gpus: 1
- eval_num_theorems: 400
+ eval_num_theorems: 300 # Lowering this number will make validation faster (but noisier).

data:
  data_path: data/leandojo_benchmark_4/random/
1 change: 1 addition & 0 deletions prover/evaluate.py
@@ -208,6 +208,7 @@ def main() -> None:
    args = parser.parse_args()

    assert args.ckpt_path or args.tactic
+   assert args.num_gpus <= args.num_workers

    logger.info(f"PID: {os.getpid()}")
    logger.info(args)
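The added assertion encodes the constraint that you cannot request more GPUs than prover workers, matching the new `--num-workers`/`--num-gpus` flags in the README commands above. A rough, hypothetical sketch of how those flags and the check might fit together (the real parser in `prover/evaluate.py` has more options):

```python
# Illustrative sketch, not the repository's actual argument parser. Flag names
# follow the README commands above; everything else is an assumption.
import argparse


def parse_eval_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Evaluate a prover checkpoint.")
    parser.add_argument("--data-path", type=str, required=True)
    parser.add_argument("--ckpt_path", type=str, default=None)
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--num-workers", type=int, default=1)
    parser.add_argument("--num-gpus", type=int, default=0)
    args = parser.parse_args()

    # Presumably each GPU is assigned to at most one worker, so requesting more
    # GPUs than workers would leave some idle; reject that configuration.
    assert args.num_gpus <= args.num_workers
    return args


if __name__ == "__main__":
    print(parse_eval_args())
```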
6 changes: 3 additions & 3 deletions prover/proof_search.py
@@ -100,8 +100,8 @@ def search(
with torch.no_grad():
    try:
        self._best_first_search()
-   except DojoCrashError:
-       logger.warning(f"Dojo crashed when proving {thm}")
+   except DojoCrashError as ex:
+       logger.warning(f"Dojo crashed with {ex} when proving {thm}")
        pass

if self.root.status == Status.PROVED:
@@ -389,7 +389,7 @@ def __init__(
if ckpt_path is None:
    tac_gen = FixedTacticGenerator(tactic, module)
else:
-   device = torch.device("cuda") if with_gpus else torch.device("cpu")
+   device = torch.device("cuda") if num_gpus > 0 else torch.device("cpu")
    tac_gen = RetrievalAugmentedGenerator.load(
        ckpt_path, device=device, freeze=True
    )
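Two small behavioral changes land here: the device is now derived from an integer `num_gpus` instead of a boolean flag, and the caught `DojoCrashError` is included in the warning message. A standalone, hypothetical sketch of both patterns (the helper names and the generic `RuntimeError` stand in for the real classes):

```python
# Standalone sketch of the two patterns changed above; helper names and the
# generic RuntimeError are illustrative, not the repository's actual classes.
import torch
from loguru import logger


def pick_device(num_gpus: int) -> torch.device:
    # Any positive GPU count selects CUDA; 0 keeps the tactic generator on CPU.
    return torch.device("cuda") if num_gpus > 0 else torch.device("cpu")


def run_search(search_fn, thm: str) -> None:
    try:
        search_fn()
    except RuntimeError as ex:
        # Logging the exception object itself makes crash reports actionable.
        logger.warning(f"Search crashed with {ex} when proving {thm}")


if __name__ == "__main__":
    print(pick_device(0))  # -> cpu

    def failing_search():
        raise RuntimeError("simulated crash")

    run_search(failing_search, "example_theorem")
```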
4 changes: 2 additions & 2 deletions scripts/download_data.py
@@ -7,10 +7,10 @@


LEANDOJO_BENCHMARK_4_URL = (
-   "https://zenodo.org/records/10823489/files/leandojo_benchmark_4.tar.gz"
+   "https://zenodo.org/records/10929138/files/leandojo_benchmark_4.tar.gz?download=1"
)
DOWNLOADS = {
-   LEANDOJO_BENCHMARK_4_URL: "c45383c1a94b0ab17395401fc8b03f36",
+   LEANDOJO_BENCHMARK_4_URL: "84a75ce552b31731165d55542b1aaca9",
}
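Both the Zenodo URL and the expected MD5 digest change together, so a stale archive will fail verification. A hypothetical sketch of the download-and-verify step (the real `scripts/download_data.py` may structure this differently; only the URL and MD5 above come from the diff):

```python
# Illustrative download-and-verify sketch; only the URL and the MD5 digest are
# taken from the diff above, the rest is an assumption about the workflow.
import hashlib
import urllib.request

LEANDOJO_BENCHMARK_4_URL = (
    "https://zenodo.org/records/10929138/files/leandojo_benchmark_4.tar.gz?download=1"
)
EXPECTED_MD5 = "84a75ce552b31731165d55542b1aaca9"


def md5_of(path: str, chunk_size: int = 1 << 20) -> str:
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_and_verify(url: str, dst: str, expected_md5: str) -> None:
    urllib.request.urlretrieve(url, dst)
    actual = md5_of(dst)
    if actual != expected_md5:
        raise RuntimeError(f"MD5 mismatch for {dst}: {actual} != {expected_md5}")


if __name__ == "__main__":
    download_and_verify(
        LEANDOJO_BENCHMARK_4_URL, "leandojo_benchmark_4.tar.gz", EXPECTED_MD5
    )
```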


