main.py (forked from dgriff777/rl_a3c_pytorch)
import os
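# Keep linear-algebra libraries single-threaded; parallelism comes from the
# A3C worker processes started below.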
os.environ["OMP_NUM_THREADS"] = "1"
import copy
import time
from types import SimpleNamespace
import torch
import torch.multiprocessing as mp
from environment import atari_env
from utils import read_config, weights_init
from model import A3Clstm
from train import train
from test import test
from shared_optim import SharedRMSprop, SharedAdam
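# Run settings; SimpleNamespace(**config) below gives this dict the same
# attribute-style access an argparse Namespace would have.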
config = dict(
    lr=float(1e-4),
    gamma=float(0.99),
    tau=float(1.0),
    seed=int(1),
    workers=int(10),
    num_steps=int(20),
    max_episode_length=int(1e4),
    env='Breakout-v0',
    env_config='config.json',
    shared_optimizer=True,
    load=False,
    save_max=True,
    optimizer='Adam',
    load_model_dir='trained_models',
    save_model_dir='trained_models',
    log_dir='logs',
    gpu_ids=[0, 1],
    amsgrad=True,
    skip_rate=int(4),
    exp_name='exp_3',
    interact_steps=int(1.5e5),
)
# Based on https://github.com/pytorch/examples/tree/master/mnist_hogwild
# Multiprocessing with locks was also tried but was not beneficial; Hogwild
# training was far superior.
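# Illustrative sketch only (the actual update lives in train.py; the names
# local_model and the gradient-copy loop below are assumptions, not this
# repo's exact code): in the Hogwild pattern each worker backpropagates
# through its local copy and writes the gradients into the shared model
# without any lock before stepping the shared optimizer.
#
#   loss.backward()
#   for local_p, shared_p in zip(local_model.parameters(),
#                                shared_model.parameters()):
#       if shared_p.grad is None:
#           shared_p._grad = local_p.grad
#   optimizer.step()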
if __name__ == '__main__':
    args = SimpleNamespace(**config)
    torch.manual_seed(args.seed)
    if args.gpu_ids == -1:
        args.gpu_ids = [-1]
    else:
        torch.cuda.manual_seed(args.seed)
        mp.set_start_method('spawn')
    setup_json = read_config(args.env_config)
    env_conf = setup_json["Default"]
    for i in setup_json.keys():
        if i in args.env:
            env_conf = setup_json[i]
    env = atari_env(args.env, env_conf, args)
    shared_model = A3Clstm(env.observation_space.shape[0], env.action_space)
    shared_model.apply(weights_init)
    if args.load:
        saved_state = torch.load(os.path.join(
            args.load_model_dir, '{}-{}.dat'.format(args.env, args.exp_name)),
            map_location=lambda storage, loc: storage)
        shared_model.load_state_dict(saved_state)
    shared_model.share_memory()
    # Shared deep copy of the model, handed to every training process as
    # targ_shared alongside the live shared_model.
    targ_shared = copy.deepcopy(shared_model)
    targ_shared.share_memory()
    if args.shared_optimizer:
        if args.optimizer == 'RMSprop':
            optimizer = SharedRMSprop(shared_model.parameters(), lr=args.lr)
        if args.optimizer == 'Adam':
            optimizer = SharedAdam(shared_model.parameters(),
                                   lr=args.lr,
                                   amsgrad=args.amsgrad)
        optimizer.share_memory()
    else:
        optimizer = None
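    # Integer step counter in shared memory, passed to the test process and
    # every trainer.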
    shared_counter = mp.Value('i', 0)
    processes = []
    p = mp.Process(target=test,
                   args=(args, shared_model, env_conf, shared_counter))
    p.start()
    processes.append(p)
    time.sleep(0.1)
    for rank in range(args.workers):
        p = mp.Process(target=train,
                       args=(rank, args, shared_model, optimizer, env_conf,
                             shared_counter, targ_shared))
        p.start()
        processes.append(p)
        time.sleep(0.1)
    for p in processes:
        time.sleep(0.1)
        p.join()
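# All run settings live in the config dict above; start a run with
# `python main.py`.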