forked from YangRui2015/2048_env

main_dqn.py
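
"""Train and evaluate a DQN agent on the 2048 Gym environment.

Training logs to ./log/ and checkpoints the best-scoring models under ./save/.
"""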
import time

import numpy as np

import logger
from dqn_agent import DQN
from gym_2048 import Game2048Env
from utils import log2_shaping, Perfomance_Saver, Model_Saver

# Hyperparameters
train_episodes = 20000        # number of training episodes
test_episodes = 50            # episodes per evaluation run
ifrender = False              # render the board during play
eval_interval = 25            # evaluate every N training episodes
epsilon_decay_interval = 100  # decay epsilon every N episodes
log_interval = 5              # log training stats every N episodes
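
# For reference, a minimal sketch of what utils.log2_shaping plausibly does for
# 2048 boards (an assumption for illustration; the real implementation lives in
# utils and may differ). The idea: compress exponential tile values to a small
# range by taking log2, with empty cells mapped to 0.
def _log2_shaping_sketch(x, divide=16):  # hypothetical helper, not used below
    x = np.asarray(x, dtype=np.float64)
    return np.log2(np.maximum(x, 1)) / divide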


def train():
    episodes = train_episodes
    logger.configure(dir="./log/", format_strs="stdout,tensorboard,log")

    agent = DQN(num_state=16, num_action=4)
    env = Game2048Env()
    pf_saver = Perfomance_Saver()
    model_saver = Model_Saver(num=10)

    eval_max_score = 0
    for i in range(episodes):
        state, reward, done, info = env.reset()
        state = log2_shaping(state)

        start = time.time()
        loss = None
        while True:
            # Act randomly until the replay buffer has been filled once,
            # then follow the epsilon-greedy policy.
            if agent.buffer.memory_counter <= agent.memory_capacity:
                action = agent.select_action(state, random=True)
            else:
                action = agent.select_action(state)

            next_state, reward, done, info = env.step(action)
            next_state = log2_shaping(next_state)
            reward = log2_shaping(reward, divide=1)

            agent.store_transition(state, action, reward, next_state)
            state = next_state

            if ifrender:
                env.render()

            # Only start updating once the buffer is full, then train every
            # train_interval transitions.
            if (agent.buffer.memory_counter % agent.train_interval == 0
                    and agent.buffer.memory_counter > agent.memory_capacity):
                loss = agent.update()

            if done:
                if i % log_interval == 0:
                    if loss:
                        logger.logkv('loss', loss)
                    logger.logkv('training progress', (i + 1) / episodes)
                    logger.logkv('episode reward', info['score'])
                    logger.logkv('episode steps', info['steps'])
                    logger.logkv('highest', info['highest'])
                    logger.logkv('epsilon', agent.epsilon)
                    logger.dumpkvs()
                    loss = None

                if i % epsilon_decay_interval == 0:  # epsilon decay
                    agent.epsilon_decay(i, episodes)
                break

        end = time.time()
        print('episode time:{} s\n'.format(end - start))

        # Periodic evaluation; checkpoint whenever the average score improves.
        if i % eval_interval == 0 and i:
            eval_info = test(episodes=test_episodes, agent=agent)
            average_score, max_score, score_lis = eval_info['mean'], eval_info['max'], eval_info['list']
            pf_saver.save(score_lis, info=f'episode:{i}')

            if int(average_score) > eval_max_score:
                eval_max_score = int(average_score)
                name = 'dqn_{}.pkl'.format(eval_max_score)
                agent.save(name=name)
                model_saver.save("./save/" + name)

            logger.logkv('eval average score', average_score)
            logger.logkv('eval max score', max_score)
            logger.dumpkvs()
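
# For reference: agent.epsilon_decay(i, episodes) comes from dqn_agent. A linear
# anneal like the sketch below is a common choice; this is an assumption for
# illustration, not necessarily what dqn_agent implements.
def _linear_epsilon_sketch(i, episodes, eps_start=0.9, eps_end=0.05):  # hypothetical
    # Anneal the exploration rate linearly from eps_start down to eps_end.
    return eps_end + (eps_start - eps_end) * max(0.0, 1.0 - i / episodes)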


def test(episodes=20, agent=None, load_path=None, ifrender=False, log=False):
    if log:
        logger.configure(dir="./log/", format_strs="stdout")

    # Build an agent from a checkpoint when none is passed in.
    if agent is None:
        agent = DQN(num_state=16, num_action=4)
        if load_path:
            agent.load(load_path)
        else:
            agent.load()

    env = Game2048Env()
    score_list = []
    highest_list = []

    for i in range(episodes):
        state, _, done, info = env.reset()
        state = log2_shaping(state)

        start = time.time()
        while True:
            action = agent.select_action(state, deterministic=True)
            next_state, _, done, info = env.step(action)
            next_state = log2_shaping(next_state)
            state = next_state

            if ifrender:
                env.render()

            if done:
                if log:
                    logger.logkv('episode number', i + 1)
                    logger.logkv('episode reward', info['score'])
                    logger.logkv('episode steps', info['steps'])
                    logger.logkv('highest', info['highest'])
                    logger.dumpkvs()
                break

        end = time.time()
        if log:
            print('episode time:{} s\n'.format(end - start))

        score_list.append(info['score'])
        highest_list.append(info['highest'])

    print('mean score:{}, mean highest:{}'.format(np.mean(score_list), np.mean(highest_list)))
    print('max score:{}, max highest:{}'.format(np.max(score_list), np.max(highest_list)))

    result_info = {'mean': np.mean(score_list), 'max': np.max(score_list), 'list': score_list}
    return result_info


if __name__ == "__main__":
    # test(episodes=test_episodes, ifrender=ifrender)
    train()
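
# Example usage: evaluate a saved checkpoint instead of training (the path below
# is illustrative; actual filenames depend on the scores reached during training):
# test(episodes=test_episodes, load_path='./save/dqn_1024.pkl', ifrender=True, log=True)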