Skip to content

Commit

Permalink
running
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaiyu Yang committed Mar 30, 2024
1 parent 84e07a1 commit a8fe037
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 11 deletions.
7 changes: 4 additions & 3 deletions generator/confs/cli_lean4_novel_premises.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ trainer:
cpu_checkpointing: false
gradient_clip_val: 1.0
max_steps: 500000
check_val_every_n_epoch: 2
check_val_every_n_epoch: 1
num_sanity_val_steps: 0
callbacks:
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
Expand All @@ -26,8 +26,9 @@ model:
length_penalty: 0.0
ret_ckpt_path: null
eval_num_retrieved: 100
eval_num_cpus: 12
eval_num_theorems: 0
eval_num_workers: 8
eval_num_gpus: 1
eval_num_theorems: 400

data:
data_path: data/leandojo_benchmark_4/novel_premises/
Expand Down
7 changes: 4 additions & 3 deletions generator/confs/cli_lean4_random.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ trainer:
cpu_checkpointing: false
gradient_clip_val: 1.0
max_steps: 500000
check_val_every_n_epoch: 2
check_val_every_n_epoch: 1
num_sanity_val_steps: 0
callbacks:
- class_path: pytorch_lightning.callbacks.LearningRateMonitor
Expand All @@ -26,8 +26,9 @@ model:
length_penalty: 0.0
ret_ckpt_path: null
eval_num_retrieved: 100
eval_num_cpus: 12
eval_num_theorems: 0
eval_num_workers: 8
eval_num_gpus: 1
eval_num_theorems: 400

data:
data_path: data/leandojo_benchmark_4/random/
Expand Down
13 changes: 9 additions & 4 deletions generator/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def __init__(
warmup_steps: int,
num_beams: int,
eval_num_retrieved: int,
eval_num_cpus: int,
eval_num_workers: int,
eval_num_gpus: int,
eval_num_theorems: int,
max_inp_seq_len: int,
max_oup_seq_len: int,
Expand All @@ -97,7 +98,8 @@ def __init__(
self.num_beams = num_beams
self.length_penalty = length_penalty
self.eval_num_retrieved = eval_num_retrieved
self.eval_num_cpus = eval_num_cpus
self.eval_num_workers = eval_num_workers
self.eval_num_gpus = eval_num_gpus
self.eval_num_theorems = eval_num_theorems
self.max_inp_seq_len = max_inp_seq_len
self.max_oup_seq_len = max_oup_seq_len
Expand Down Expand Up @@ -244,12 +246,14 @@ def on_validation_epoch_end(self) -> None:
ckpt_path = f"{self.trainer.log_dir}/checkpoints/last.ckpt"
self.trainer.save_checkpoint(ckpt_path)
logger.info(f"Saved checkpoint to {ckpt_path}. Evaluating...")
torch.cuda.empty_cache()

data_path = self.trainer.datamodule.data_path
if self.retriever is None:
acc = evaluate(
data_path=data_path,
num_cpus=self.eval_num_cpus,
num_workers=self.eval_num_workers,
num_gpus=self.eval_num_gpus,
num_theorems=self.eval_num_theorems,
ckpt_path=ckpt_path,
)
Expand All @@ -264,7 +268,8 @@ def on_validation_epoch_end(self) -> None:
)
acc = evaluate(
data_path=data_path,
num_cpus=self.eval_num_cpus,
num_workers=self.eval_num_workers,
num_gpus=self.eval_num_gpus,
num_theorems=self.eval_num_theorems,
ckpt_path=ckpt_path,
indexed_corpus_path=corpus_path,
Expand Down
1 change: 0 additions & 1 deletion prover/proof_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,6 @@ def __init__(
)
return

ray.init()
if num_gpus >= 1:
logger.info(f"Launching {num_workers} workers with {num_gpus} GPUs.")
num_gpus_per_worker = num_gpus / num_workers
Expand Down

0 comments on commit a8fe037

Please sign in to comment.