-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsweep_slurm.py
59 lines (46 loc) · 1.45 KB
/
sweep_slurm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import json
import os
import subprocess
from pathlib import Path
import yaml
import wandb
# Export the W&B API key from ./keys.json (relative to the current working
# directory) when that file exists; otherwise leave the environment untouched.
if Path("keys.json").is_file():
    with open("keys.json", encoding="utf-8") as file:
        api_key = json.load(file)["wandb_key"]
    os.environ["WANDB_API_KEY"] = api_key


def _allocated_hostnames():
    """Return the hostnames of the nodes allocated to the current Slurm job.

    Runs ``scontrol show hostnames`` with no hostlist argument and returns one
    entry per non-blank output line, with surrounding whitespace stripped.
    Using ``splitlines`` (rather than dropping the last ``split`` element)
    keeps the final hostname even when the output has no trailing newline.

    Raises:
        subprocess.CalledProcessError: if ``scontrol`` exits non-zero
            (presumably e.g. when run outside a Slurm allocation).
        FileNotFoundError: if ``scontrol`` cannot be executed.
    """
    result = subprocess.run(
        ["scontrol", "show", "hostnames"],
        stdout=subprocess.PIPE,
        check=True,
    )
    lines = result.stdout.decode("utf-8").splitlines()
    return [name for name in (line.strip() for line in lines) if name]


# Gather nodes allocated to current slurm job (read by main()).
node_list = _allocated_hostnames()
def main():
    """Create a W&B sweep and start one sweep agent on every allocated node.

    Positional command-line arguments:
        sweep_config: path to a YAML sweep configuration (must be a mapping).
        train_script: program the agents run; stored under the config's
            ``program`` key.
        project: W&B project used for ``wandb.init`` and ``wandb.sweep``.

    For each hostname in the module-level ``node_list``, launches
    ``srun --nodes=1 --ntasks=1 -w <node> start-agent.sh <sweep_id> <project>``.
    All launches are started before any is waited on, so they run concurrently.

    Returns:
        list[int]: the exit code of each ``srun`` process, in ``node_list`` order.

    Raises:
        SystemExit: on invalid command-line arguments, when ``node_list`` is
            empty, or when the sweep config is not a YAML mapping.
    """
    parser = argparse.ArgumentParser(
        description="Create a W&B sweep and run one agent per node of the current Slurm job."
    )
    parser.add_argument("sweep_config", type=str, help="Path to the YAML sweep configuration.")
    parser.add_argument("train_script", type=str, help="Training program the agents will run.")
    parser.add_argument("project", type=str, help="W&B project name.")
    args = parser.parse_args()

    # Validate inputs before any W&B side effects, so a bad invocation does not
    # leave behind a run or a sweep that no agent will ever execute.
    if not node_list:
        raise SystemExit("No Slurm nodes found for this job; cannot start any sweep agents.")

    with open(args.sweep_config, encoding="utf-8") as file:
        # A sweep config is plain YAML, so the safe loader is sufficient.
        config_dict = yaml.safe_load(file)
    if not isinstance(config_dict, dict):
        # An empty file loads as None; a list or scalar cannot take a "program" key.
        raise SystemExit(f"Sweep config {args.sweep_config!r} must be a YAML mapping.")
    config_dict["program"] = args.train_script

    # NOTE(review): the run created here is not used anywhere else in main;
    # confirm it is actually needed before wandb.sweep.
    wandb.init(project=args.project)
    sweep_id = wandb.sweep(config_dict, project=args.project)

    # Popen does not block: start every agent first, then wait for all of them.
    agents = [
        subprocess.Popen(
            [
                "srun",
                "--nodes=1",
                "--ntasks=1",
                "-w",
                node,
                "start-agent.sh",
                sweep_id,
                args.project,
            ]
        )
        for node in node_list
    ]
    return [agent.wait() for agent in agents]
if __name__ == "__main__":
    # Propagate failure to the caller (e.g. an sbatch script): exit 1 if any
    # srun step returned a non-zero code, including the negative codes
    # Popen.wait reports for processes killed by a signal; otherwise exit 0.
    raise SystemExit(1 if any(main()) else 0)