# dqnTest.py
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Conv2D, MaxPooling2D, Activation, Flatten
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.optimizers import Adam
import numpy as np
from collections import deque
import time

REPLAY_MEMORY_SIZE = 50_000  # how many recent transitions the replay memory keeps
MODEL_NAME = "256x2"
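import random  # needed only by the illustrative train() sketch at the end of DQNAgent

# NOTE: the constants below are not part of the original file; they are assumed
# hyperparameter values referenced by the train() sketch at the end of DQNAgent.
DISCOUNT = 0.99                 # assumed Bellman discount factor (gamma)
MINIBATCH_SIZE = 64             # assumed number of transitions sampled per training step
MIN_REPLAY_MEMORY_SIZE = 1_000  # assumed minimum memory size before training starts
UPDATE_TARGET_EVERY = 5         # assumed number of terminal episodes between target-network syncs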

# Own TensorBoard class
class ModifiedTensorBoard(TensorBoard):

    # Overriding init to set initial step and writer (we want one log file for all .fit() calls)
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.step = 1
        # tf.summary.FileWriter is a TF 1.x API; under TF 2.x the equivalent is:
        self.writer = tf.summary.create_file_writer(self.log_dir)

    # Overriding this method to stop creation of the default log writer
    def set_model(self, model):
        pass

    # Overridden, saves logs with our step number
    # (otherwise every .fit() will start writing from 0th step)
    def on_epoch_end(self, epoch, logs=None):
        self.update_stats(**logs)

    # Overridden
    # We train for one batch only, no need to save anything at batch end
    def on_batch_end(self, batch, logs=None):
        pass

    # Overridden, so won't close writer
    def on_train_end(self, _):
        pass

    # Custom method for saving own metrics
    # Writes custom metrics at our own step counter and flushes the writer
    def update_stats(self, **stats):
        # The private _write_logs() helper was removed from newer Keras;
        # use the TF 2.x summary API instead
        with self.writer.as_default():
            for key, value in stats.items():
                tf.summary.scalar(key, value, step=self.step)
            self.writer.flush()
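
# Illustrative usage (not part of the original file): the agent below creates
# one ModifiedTensorBoard and reuses it across every .fit() call, e.g.
#
#     agent.tensorboard.step = episode          # advance the shared step counter
#     agent.model.fit(X, y, batch_size=MINIBATCH_SIZE, verbose=0, shuffle=False,
#                     callbacks=[agent.tensorboard])
#
# so all episodes log into a single run directory instead of one per .fit().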

class DQNAgent:
    def __init__(self):
        # Main model, gets trained every step
        self.model = self.create_model()

        # Target model, this is what we .predict against every step
        self.target_model = self.create_model()
        self.target_model.set_weights(self.model.get_weights())

        # A deque with a maximum length: old transitions fall off the far end
        self.replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)

        self.tensorboard = ModifiedTensorBoard(log_dir=f"logs/{MODEL_NAME}-{int(time.time())}")

        # Counts terminal episodes since the target network was last synced
        self.target_update_counter = 0

    def create_model(self):
        # NOTE: `env` is assumed to be defined elsewhere (the environment this
        # agent trains on), exposing OBSERVATION_SPACE_VALUES and ACTION_SPACE_SIZE
        model = Sequential()

        model.add(Conv2D(256, (3, 3), input_shape=env.OBSERVATION_SPACE_VALUES))
        model.add(Activation("relu"))
        model.add(MaxPooling2D(2, 2))
        model.add(Dropout(0.2))

        model.add(Conv2D(256, (3, 3)))
        model.add(Activation("relu"))
        model.add(MaxPooling2D(2, 2))
        model.add(Dropout(0.2))

        model.add(Flatten())
        model.add(Dense(64))

        model.add(Dense(env.ACTION_SPACE_SIZE, activation="linear"))
        # `lr` is a deprecated alias in newer Keras; use `learning_rate`
        model.compile(loss="mse", optimizer=Adam(learning_rate=0.001), metrics=["accuracy"])
        return model

    # Adds a step's transition to replay memory
    # (assumed layout: (current_state, action, reward, new_state, done))
    def update_replay_memory(self, transition):
        self.replay_memory.append(transition)

    # Queries the main network for Q-values given a current state
    def get_qs(self, state):
        return self.model.predict(np.array(state).reshape(-1, *state.shape) / 255)[0]
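
    # --- Illustrative sketch, not in the original file ---
    # A DQN training step in the standard pattern this class is set up for:
    # sample a minibatch from replay memory, compute Bellman targets with the
    # target network, fit the main network, and periodically sync the two.
    # DISCOUNT, MINIBATCH_SIZE, MIN_REPLAY_MEMORY_SIZE and UPDATE_TARGET_EVERY
    # are the assumed constants defined near the top of the file; the transition
    # layout (current_state, action, reward, new_state, done) is also an assumption.
    def train(self, terminal_state):
        # Don't train until the replay memory holds enough transitions to sample
        if len(self.replay_memory) < MIN_REPLAY_MEMORY_SIZE:
            return

        # Random minibatch of past transitions
        minibatch = random.sample(self.replay_memory, MINIBATCH_SIZE)

        # Q-values for the current states come from the main model...
        current_states = np.array([transition[0] for transition in minibatch]) / 255
        current_qs_list = self.model.predict(current_states)

        # ...Q-values for the successor states come from the slow-moving target model
        new_states = np.array([transition[3] for transition in minibatch]) / 255
        future_qs_list = self.target_model.predict(new_states)

        X, y = [], []
        for index, (current_state, action, reward, new_state, done) in enumerate(minibatch):
            # Bellman update: bootstrap from the target network unless terminal
            if not done:
                new_q = reward + DISCOUNT * np.max(future_qs_list[index])
            else:
                new_q = reward

            # Overwrite only the taken action's Q-value; others keep the model's output
            current_qs = current_qs_list[index]
            current_qs[action] = new_q

            X.append(current_state)
            y.append(current_qs)

        # Fit on the whole minibatch as a single batch; log only at episode end
        self.model.fit(np.array(X) / 255, np.array(y), batch_size=MINIBATCH_SIZE,
                       verbose=0, shuffle=False,
                       callbacks=[self.tensorboard] if terminal_state else None)

        # Count terminal episodes and periodically copy main weights to the target
        if terminal_state:
            self.target_update_counter += 1
        if self.target_update_counter > UPDATE_TARGET_EVERY:
            self.target_model.set_weights(self.model.get_weights())
            self.target_update_counter = 0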