train.py
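Training entry point for E2FGVI: it loads a JSON config, sets up single- or multi-GPU (NCCL) distributed training, and hands off to core.trainer.Trainer.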
import os
import json
import argparse
from shutil import copyfile

import torch
import torch.multiprocessing as mp

from core.trainer import Trainer
from core.dist import (
    get_world_size,
    get_local_rank,
    get_global_rank,
    get_master_ip,
)
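# NOTE (assumption): the core.dist helpers are expected to read the
# launcher's environment (e.g. OMPI_COMM_WORLD_* variables when started
# by Open MPI) and to fall back to single-node defaults otherwise; the
# __main__ block below relies on get_master_ip() returning "127.0.0.1"
# in the no-external-launcher case.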
parser = argparse.ArgumentParser(description='E2FGVI')
parser.add_argument('-c',
                    '--config',
                    default='configs/train_e2fgvi.json',
                    type=str)
parser.add_argument('-p', '--port', default='23455', type=str)
args = parser.parse_args()
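# Example invocation (single node; the port is only used to build the
# TCP init_method for torch.distributed):
#   python train.py -c configs/train_e2fgvi.json -p 23455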
def main_worker(rank, config):
    # `rank` is the per-process index from mp.spawn; when an external
    # launcher set the ranks beforehand, they are already in `config`.
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank
    if config['distributed']:
        torch.cuda.set_device(int(config['local_rank']))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=config['init_method'],
                                             world_size=config['world_size'],
                                             rank=config['global_rank'],
                                             group_name='mtorch')
    print('using GPU {}-{} for training'.format(int(config['global_rank']),
                                                int(config['local_rank'])))

    # Output folders are named '<net>_<config basename>', e.g.
    # '<net>_train_e2fgvi' when the default config is used.
    config['save_dir'] = os.path.join(
        config['save_dir'],
        '{}_{}'.format(config['model']['net'],
                       os.path.basename(args.config).split('.')[0]))
    config['save_metric_dir'] = os.path.join(
        './scores',
        '{}_{}'.format(config['model']['net'],
                       os.path.basename(args.config).split('.')[0]))

    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # Only rank 0 creates the folders and archives a copy of the config.
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        os.makedirs(config['save_metric_dir'], exist_ok=True)
        config_path = os.path.join(config['save_dir'],
                                   args.config.split('/')[-1])
        if not os.path.isfile(config_path):
            copyfile(args.config, config_path)
        print('[**] create folder {}'.format(config['save_dir']))

    trainer = Trainer(config)
    trainer.train()
if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    mp.set_sharing_strategy('file_system')

    # loading configs
    with open(args.config) as f:
        config = json.load(f)

    # setting distributed configurations
    config['world_size'] = get_world_size()
    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
    config['distributed'] = config['world_size'] > 1
    print('world_size:', config['world_size'])

    # setup distributed parallel training environments
    if get_master_ip() == "127.0.0.1":
        # manually launch one distributed process per GPU
        mp.spawn(main_worker, nprocs=config['world_size'], args=(config, ))
    else:
        # multiple processes have already been launched by Open MPI,
        # so the ranks come from the environment and `rank` goes unused
        config['local_rank'] = get_local_rank()
        config['global_rank'] = get_global_rank()
        main_worker(-1, config)
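The script itself only dereferences a few config keys (model.net and save_dir); everything else is consumed by core.trainer.Trainer. A minimal sketch of the JSON shape those keys imply, assuming the real configs/train_e2fgvi.json additionally carries the trainer, data, and loss settings not shown here (the "net" value is a placeholder, not a confirmed name):

    {
        "model": {"net": "<net-name>"},
        "save_dir": "<checkpoint-root>"
    }

With the default arguments, checkpoints would land under <checkpoint-root>/<net-name>_train_e2fgvi/ and metric files under ./scores/<net-name>_train_e2fgvi/.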