forked from seungeunrho/minimalRL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdqn.py
111 lines (90 loc) · 3.26 KB
/
dqn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gym
import collections
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Hyperparameters
learning_rate = 5e-4    # step size passed to the Adam optimizer built in main()
gamma = 0.98            # discount factor applied to the bootstrapped value in train()
buffer_limit = 50000    # maximum number of transitions ReplayBuffer keeps
batch_size = 32         # transitions sampled per gradient step in train()
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    The capacity is the module-level ``buffer_limit``; once it is reached,
    each newly added transition evicts the oldest one.
    """

    def __init__(self):
        # deque(iterable, maxlen): start empty, cap the length at buffer_limit.
        self.buffer = collections.deque([], buffer_limit)

    def put(self, transition):
        """Store ``transition`` at the newest end of the buffer."""
        self.buffer.append(transition)

    def sample(self, n):
        """Return a list of ``n`` stored transitions drawn without replacement.

        ``random.sample`` raises ``ValueError`` if ``n`` is larger than the
        number of transitions currently stored.
        """
        drawn = random.sample(self.buffer, n)
        return drawn

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
class Qnet(nn.Module):
    """Two-layer fully connected Q-network producing one value per action.

    Args:
        state_dim: Length of the observation vector fed to the first layer.
            Defaults to 4, the value the original code hard-coded.
        n_actions: Number of discrete actions, i.e. the width of the output
            layer. Defaults to 2, the value the original code hard-coded.
    """

    def __init__(self, state_dim=4, n_actions=2):
        super(Qnet, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, n_actions)

    def forward(self, x):
        """Return the Q-values for ``x`` (last dimension must be ``state_dim``)."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def sample_action(self, obs, epsilon):
        """Choose an action epsilon-greedily for a single observation.

        Args:
            obs: Observation tensor for ONE state. ``argmax`` is taken over
                the flattened network output, so a batch must not be passed.
            epsilon: Probability of returning a uniformly random action.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() < epsilon:
            # Exploration: no forward pass is needed. The action range follows
            # the output layer so it cannot drift from the network's shape.
            return random.randint(0, self.fc2.out_features - 1)
        # Exploitation: only the argmax is used, so skip autograd bookkeeping.
        with torch.no_grad():
            return self.forward(obs).argmax().item()
def train(q, q_target, memory, optimizer, n_updates=10):
    """Run several DQN gradient steps on minibatches drawn from ``memory``.

    Each step samples ``batch_size`` transitions, forms the one-step TD target
    ``r + gamma * max_a' Q_target(s', a') * done_mask`` and minimises the
    smooth-L1 loss between that target and ``Q(s, a)``.

    Args:
        q: Online network being optimised.
        q_target: Network used only to evaluate the bootstrap target.
        memory: Object whose ``sample(n)`` returns ``n`` transitions, each a
            ``(s, a, r, s_prime, done_mask)`` tuple.
        optimizer: Optimizer that updates the parameters of ``q``.
        n_updates: Number of minibatch updates to perform (default 10, the
            count the original code hard-coded).

    Raises:
        ValueError: Propagated from ``memory.sample`` when fewer than
            ``batch_size`` transitions are stored.
    """
    for _ in range(n_updates):
        mini_batch = memory.sample(batch_size)
        # Transpose the list of transitions into one tuple per field.
        states, actions, rewards, next_states, done_masks = zip(*mini_batch)

        # Stack per-state tensors instead of calling torch.tensor on a list of
        # arrays, which is a slow path; dtypes are fixed explicitly so they do
        # not depend on how the stored values happen to be typed.
        s = torch.stack([torch.as_tensor(x, dtype=torch.float) for x in states])
        s_prime = torch.stack(
            [torch.as_tensor(x, dtype=torch.float) for x in next_states])
        # Column vectors of shape (batch, 1) so they line up with gather().
        a = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
        r = torch.tensor(rewards, dtype=torch.float).unsqueeze(1)
        done_mask = torch.tensor(done_masks, dtype=torch.float).unsqueeze(1)

        # Q-value of the action that was actually taken in each transition.
        q_a = q(s).gather(1, a)

        # The target is a constant for this step: without no_grad, backward()
        # would also accumulate never-used gradients into q_target.
        with torch.no_grad():
            max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
            # done_mask multiplies the bootstrap term, so a mask of 0 leaves
            # only the immediate reward in the target.
            target = r + gamma * max_q_prime * done_mask

        loss = F.smooth_l1_loss(q_a, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def _reset_env(env):
    """Reset ``env`` and return only the initial observation."""
    out = env.reset()
    # Gym >= 0.26 / Gymnasium return (obs, info); older Gym returns obs alone.
    return out[0] if isinstance(out, tuple) else out


def _step_env(env, action):
    """Step ``env`` and return ``(s_prime, r, done, terminal)``.

    ``done`` says the episode is over; ``terminal`` says the next state should
    not be bootstrapped from. With the 4-tuple step API the two are the same
    flag. With the 5-tuple API a time-limit truncation ends the episode
    without being terminal.
    """
    result = env.step(action)
    if len(result) == 5:
        s_prime, r, terminated, truncated, _ = result
        return s_prime, r, bool(terminated or truncated), bool(terminated)
    s_prime, r, done, _ = result
    return s_prime, r, done, done


def main():
    """Train a DQN agent on CartPole-v1 and print periodic progress.

    Runs 10000 episodes of at most 600 steps each. Every transition is stored
    in a replay buffer with the reward scaled by 1/100. Once the buffer holds
    more than 2000 transitions, ``train`` is called after every episode.
    Every ``print_interval`` episodes the target network is synchronised with
    the online network and the average raw score is printed.
    """
    env = gym.make('CartPole-v1')
    q = Qnet()
    q_target = Qnet()
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()

    print_interval = 20
    score = 0.0
    # Episodes accumulated into `score` since the last report. Episode 0 also
    # counts, so the first report covers 21 episodes rather than 20.
    episodes_in_window = 0
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)

    # try/finally so the environment is closed even if training raises or is
    # interrupted.
    try:
        for n_epi in range(10000):
            # Linear annealing from 8% to 1%
            epsilon = max(0.01, 0.08 - 0.01*(n_epi/200))
            s = _reset_env(env)

            for _ in range(600):
                a = q.sample_action(torch.from_numpy(s).float(), epsilon)
                s_prime, r, done, terminal = _step_env(env, a)
                # A mask of 0 removes the bootstrap term in train().
                done_mask = 0.0 if terminal else 1.0
                memory.put((s, a, r/100.0, s_prime, done_mask))
                s = s_prime

                score += r
                if done:
                    break
            episodes_in_window += 1

            # Start learning only after the buffer has more than 2000 samples.
            if memory.size() > 2000:
                train(q, q_target, memory, optimizer)

            if n_epi % print_interval == 0 and n_epi != 0:
                # The target network is refreshed on the reporting schedule.
                q_target.load_state_dict(q.state_dict())
                print("# of episode :{}, avg score : {:.1f}, buffer size : {}, epsilon : {:.1f}%".format(
                    n_epi, score/episodes_in_window, memory.size(), epsilon*100))
                score = 0.0
                episodes_in_window = 0
    finally:
        env.close()
# Start training only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()