-
Notifications
You must be signed in to change notification settings - Fork 978
/
Copy pathpolicy_value_net_keras.py
111 lines (91 loc) · 4.78 KB
/
policy_value_net_keras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# -*- coding: utf-8 -*-
"""
An implementation of the policyValueNet with Keras
Tested under Keras 2.0.5 with tensorflow-gpu 1.2.1 as backend
@author: Mingxu Zhang
"""
from __future__ import print_function
from keras.engine.topology import Input
from keras.engine.training import Model
from keras.layers.convolutional import Conv2D
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.merge import Add
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras.optimizers import Adam
import keras.backend as K
from keras.utils import np_utils
import numpy as np
import pickle
class PolicyValueNet():
    """Policy-value network implemented with Keras.

    The network takes a batch of 4 feature planes laid out channels-first as
    (4, board_width, board_height) and produces two outputs:

    * a softmax distribution over all board_width * board_height moves, and
    * a single state value squashed into [-1, 1] by ``tanh``.

    Besides the methods below, construction installs two instance callables:
    ``self.policy_value`` (built in ``create_policy_value_net``) and
    ``self.train_step`` (built in ``_loss_train_op``).
    """

    def __init__(self, board_width, board_height, model_file=None):
        """Build and compile the network, optionally loading saved weights.

        Args:
            board_width: size of the first spatial dimension of the input
                planes.
            board_height: size of the second spatial dimension of the input
                planes.
            model_file: optional path to a pickle written by ``save_model``.
                When given, its weight list replaces the freshly initialized
                weights.
        """
        self.board_width = board_width
        self.board_height = board_height
        self.l2_const = 1e-4  # coef of l2 penalty
        self.create_policy_value_net()
        self._loss_train_op()
        if model_file:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file; only load model files that come from a trusted source.
            # The context manager guarantees the handle is closed.
            with open(model_file, 'rb') as f:
                net_params = pickle.load(f)
            self.model.set_weights(net_params)

    def create_policy_value_net(self):
        """Build ``self.model`` and install the ``self.policy_value`` helper.

        Every layer uses an L2 kernel penalty of ``self.l2_const``, and every
        convolution uses channels_first data. The layers are:

        * shared trunk: three 3x3 ReLU convolutions with 32, 64 and 128
          filters ("same" padding);
        * policy head: 1x1 conv (4 filters, ReLU) -> flatten -> dense softmax
          over board_width * board_height moves;
        * value head: 1x1 conv (2 filters, ReLU) -> flatten -> dense(64)
          -> dense(1, tanh).
        """
        in_x = network = Input((4, self.board_width, self.board_height))

        # conv layers
        network = Conv2D(filters=32, kernel_size=(3, 3), padding="same",
                         data_format="channels_first", activation="relu",
                         kernel_regularizer=l2(self.l2_const))(network)
        network = Conv2D(filters=64, kernel_size=(3, 3), padding="same",
                         data_format="channels_first", activation="relu",
                         kernel_regularizer=l2(self.l2_const))(network)
        network = Conv2D(filters=128, kernel_size=(3, 3), padding="same",
                         data_format="channels_first", activation="relu",
                         kernel_regularizer=l2(self.l2_const))(network)

        # action policy layers
        policy_net = Conv2D(filters=4, kernel_size=(1, 1),
                            data_format="channels_first", activation="relu",
                            kernel_regularizer=l2(self.l2_const))(network)
        policy_net = Flatten()(policy_net)
        self.policy_net = Dense(self.board_width * self.board_height,
                                activation="softmax",
                                kernel_regularizer=l2(self.l2_const))(policy_net)

        # state value layers
        value_net = Conv2D(filters=2, kernel_size=(1, 1),
                           data_format="channels_first", activation="relu",
                           kernel_regularizer=l2(self.l2_const))(network)
        value_net = Flatten()(value_net)
        # NOTE(review): this Dense(64) has no activation, so it and the
        # following Dense(1) compose into a single affine map before the tanh;
        # the hidden layer adds no nonlinearity. Adding activation="relu" would
        # fix that, but it changes the outputs of previously saved weights
        # (the shapes stay compatible), so confirm before changing.
        value_net = Dense(64, kernel_regularizer=l2(self.l2_const))(value_net)
        self.value_net = Dense(1, activation="tanh",
                               kernel_regularizer=l2(self.l2_const))(value_net)

        self.model = Model(in_x, [self.policy_net, self.value_net])

        def policy_value(state_input):
            """Forward a batch; return the model's [move_probs, values] outputs."""
            state_input_union = np.array(state_input)
            results = self.model.predict_on_batch(state_input_union)
            return results
        self.policy_value = policy_value

    def policy_value_fn(self, board):
        """Evaluate a single board position.

        Args:
            board: game board exposing ``availables`` (indices of the
                available moves into the flattened move-probability vector)
                and ``current_state()`` (feature planes that are reshaped
                to (-1, 4, board_width, board_height) here).

        Returns:
            (act_probs, value):
            ``act_probs`` is an iterable of (action, probability) pairs, one
            per available action. It is a ``zip`` object, so it is a one-shot
            iterator on Python 3.
            ``value`` is the network's scalar (tanh) score for the board state.
        """
        legal_positions = board.availables
        current_state = board.current_state()
        act_probs, value = self.policy_value(
            current_state.reshape(-1, 4, self.board_width, self.board_height))
        act_probs = zip(legal_positions, act_probs.flatten()[legal_positions])
        return act_probs, value[0][0]

    def _loss_train_op(self):
        """Compile the model and install the ``self.train_step`` helper.

        The objective has three terms::

            loss = (z - v)^2 - pi^T * log(p) + c * ||theta||^2

        The terms are:

        * the squared error between the target value z and the predicted
          value v (``mean_squared_error`` on the value output);
        * the cross-entropy between the target move distribution pi and the
          predicted move probabilities p (``categorical_crossentropy`` on the
          policy output);
        * the L2 penalty that every layer carries through
          ``kernel_regularizer``, with c = ``self.l2_const``.
        """
        # get the train op
        opt = Adam()
        # Order matches the model outputs: [policy_net, value_net].
        losses = ['categorical_crossentropy', 'mean_squared_error']
        self.model.compile(optimizer=opt, loss=losses)

        def self_entropy(probs):
            """Mean Shannon entropy over the rows of ``probs``.

            The 1e-10 offset keeps log() finite for zero probabilities.
            """
            return -np.mean(np.sum(probs * np.log(probs + 1e-10), axis=1))

        def train_step(state_input, mcts_probs, winner, learning_rate):
            """Run one training update on a batch.

            Args:
                state_input: batch of input feature planes.
                mcts_probs: target move distributions (policy head targets).
                winner: target values (value head targets).
                learning_rate: learning rate set on the optimizer before
                    this update.

            Returns:
                (loss, entropy), both measured BEFORE the update:
                ``loss`` is the first value returned by ``Model.evaluate``
                (the overall loss), and ``entropy`` is the mean entropy of
                the predicted move distributions.
            """
            state_input_union = np.array(state_input)
            mcts_probs_union = np.array(mcts_probs)
            winner_union = np.array(winner)
            # Measure loss and policy entropy before fitting, so the reported
            # numbers describe the pre-update network.
            loss = self.model.evaluate(state_input_union,
                                       [mcts_probs_union, winner_union],
                                       batch_size=len(state_input), verbose=0)
            action_probs, _ = self.model.predict_on_batch(state_input_union)
            entropy = self_entropy(action_probs)
            K.set_value(self.model.optimizer.lr, learning_rate)
            # The whole input is fitted as a single batch.
            self.model.fit(state_input_union, [mcts_probs_union, winner_union],
                           batch_size=len(state_input), verbose=0)
            return loss[0], entropy
        self.train_step = train_step

    def get_policy_param(self):
        """Return the model weights as produced by Keras ``get_weights()``."""
        net_params = self.model.get_weights()
        return net_params

    def save_model(self, model_file):
        """Save the model weights to ``model_file`` as a pickle.

        Pickle protocol 2 is the highest protocol Python 2 can read, and was
        presumably chosen for cross-version loading.
        """
        net_params = self.get_policy_param()
        # Context manager: the file is flushed and closed even if pickling
        # raises part-way through.
        with open(model_file, 'wb') as f:
            pickle.dump(net_params, f, protocol=2)