-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathlaunch_xla.py
68 lines (50 loc) · 2.25 KB
/
launch_xla.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Adaptation of (pre-elastic) torch.distributed.launch for PyTorch XLA.
`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.
From Ross Wightman's PyTorch Image Models, which can be found at
https://github.com/rwightman/pytorch-image-models/.
The original license can be found at this link:
https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
"""
import importlib
import os
import sys
from argparse import REMAINDER, ArgumentParser
import torch_xla.distributed.xla_multiprocessing as xmp
def parse_args():
    """Parse the launcher's command-line options.

    Returns:
        argparse.Namespace with:
            num_devices (int): number of XLA devices to spawn processes on.
            script (str): path to the single-device training script.
            script_args (list[str]): remaining args, forwarded verbatim to
                the training script via ``argparse.REMAINDER``.
    """
    parser = ArgumentParser(description="PyTorch distributed training launch helper utility "
                                        "that will spawn up multiple distributed processes")
    # Optional arguments for the launch helper
    parser.add_argument("--num-devices",
                        type=int,
                        default=1,
                        help="The number of XLA devices to use for distributed training")
    # positional
    # NOTE: adjacent string literals need a trailing space, otherwise the
    # help text renders as "launchedin parallel" (original bug, fixed here).
    parser.add_argument("script",
                        type=str,
                        help="The full path to the single device training script to be launched "
                             "in parallel, followed by all the arguments for the training script")
    # rest from the training program: REMAINDER captures everything after
    # the script path without the launcher trying to interpret it
    parser.add_argument('script_args', nargs=REMAINDER)
    return parser.parse_args()
def main():
    """Parse launcher args, import the training script as a module, and
    spawn one process per requested XLA device via ``xmp.spawn``.

    Removed the block of commented-out environment-variable setup that
    referenced parser options which do not exist (``args.master_addr``,
    ``args.nproc_per_node``) — dead code from the torch.distributed.launch
    original this was adapted from.
    """
    args = parse_args()

    # Make the script's directory importable, then import the script as a
    # module so its entry point can be handed to xmp.spawn.
    script_abs = os.path.abspath(args.script)
    script_base, script_rel = os.path.split(script_abs)
    sys.path.append(script_base)
    mod = importlib.import_module(os.path.splitext(script_rel)[0])

    # Rewrite argv so the training script sees its own arguments exactly
    # as if it had been invoked directly.
    sys.argv = [args.script] + args.script_args

    # NOTE(review): the imported script must expose a module-level
    # `_mp_entry` callable — confirm against the training scripts this
    # launcher is meant to drive.
    xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices)
# Standard script entry guard: only launch when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()