-
Notifications
You must be signed in to change notification settings - Fork 6
/
MCTS.py
211 lines (178 loc) · 7.38 KB
/
MCTS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import math
import time
import random
import numpy as np
import copy
import loss
# Exploration constants read by MCTS.PUCT.
C_PUCT: float = math.sqrt(2)  # used inside the log term of the exploration weight and as its divisor
C_INIT: int = 1  # added to the log term as a baseline exploration weight
# NOTE: once the game is over, the algorithm expects that no player is left to move.
class Node:
    """One state in the MCTS tree.

    A non-root node stands for playing ``action`` from ``parent``'s state. Its
    constructor plays that move on the shared ``game``, snapshots the resulting
    board and turn, then undoes the move so ``game`` is left as it was.

    Args:
        game: shared game object; must provide execute_move, undo_move,
            get_board and get_turn.
        parent: parent Node, or None for the root.
        action: move leading from ``parent``'s state to this node (None for the root).
        probability: prior probability of choosing this node from its parent.
        t: initial sum of backed-up values.
        n: initial visit count.
    """

    def __init__(self, game, parent, action, probability=0, t=0, n=0):
        self.parent = parent  # parent Node, None for the root
        self.game = game  # the game being played (shared between nodes)
        self.t = t  # sum of values from all tree searches through this node
        self.n = n  # number of visits from all tree searches through this node
        self.last_action = action  # action from the parent's state to this state
        self.children = []  # child Nodes, in creation order
        self.probability = probability  # prior for choosing this node from its parent
        if parent is not None:
            game.execute_move(action)
            try:
                self.board_state = np.copy(game.get_board())
                self.turn = game.get_turn()
            finally:
                # Restore the shared game even if one of the getters raises.
                game.undo_move()
            # Register with the parent only once fully initialised, so a failure
            # above cannot leave a half-built child in parent.children.
            parent.add_child(self)
        else:
            self.turn = game.get_turn()
            # Snapshot the current board so get_board_state() also works on a
            # root; callers may still overwrite board_state afterwards.
            self.board_state = np.copy(game.get_board())

    def get_parent(self):
        """Return the parent Node (None for the root)."""
        return self.parent

    def add_child(self, child):
        """Append ``child`` to this node's children."""
        self.children.append(child)

    def is_leaf_node(self):
        """Return True if this node has no children."""
        return not self.children

    def get_board_state(self):
        """Return a copy of this node's board snapshot."""
        return np.copy(self.board_state)

    def get_last_action(self):
        """Return the action that led from the parent to this node."""
        return self.last_action

    def get_times_visited(self):
        """Return the visit count ``n``."""
        return self.n

    def get_total_values(self):
        """Return the accumulated value sum ``t``."""
        return self.t
class MCTS:
    """Monte Carlo Tree Search guided by a policy/value network (``agent``).

    ``search`` plays moves on the shared ``game`` while descending the tree and
    ``back_propagate`` undoes each of them on the way back up, so after every
    search ``game`` is left at the state it started from.
    """

    def __init__(self, game, start_state, agent, Config):
        self.root = Node(game, None, None)
        self.game = game
        self.Config = Config
        self.root.board_state = np.copy(start_state)
        self.agent = agent
        self.T = 1  # temperature exponent used by get_temperature_probabilities
        self.level = 0  # depth reached during the current selection phase

    def reset_search(self):
        """Discard the tree and start a new one from the game's current board."""
        self.root = Node(self.game, None, None)
        # Copy (as __init__ does) so later moves on the game cannot alias the root's state.
        self.root.board_state = np.copy(self.game.get_board())

    @staticmethod
    def search_nodechildren_for_state(node, state):
        """Return the child of ``node`` whose board equals ``state``, or None."""
        for child in node.children:
            if np.array_equal(child.get_board_state(), state):
                return child
        return None

    def find_node_given_state(self, state):
        """Return the root's direct child whose board equals ``state``, or None.

        Only one level below the root is searched.
        """
        return MCTS.search_nodechildren_for_state(self.root, state)

    def get_most_searched_child_node(self, node):
        """Return the most-visited child of ``node``.

        Ties go to the earliest child; returns None if no child has been visited.
        """
        visited = [c for c in node.children if c.get_times_visited() > 0]
        return max(visited, key=lambda c: c.get_times_visited(), default=None)

    def get_action_numbers(self, node):
        """Map each action index in ``range(policy_output_dim)`` to its visit count.

        Actions with no child under ``node`` map to 0.
        """
        action_numbers = dict.fromkeys(range(self.Config.policy_output_dim), 0)
        for child in node.children:
            action_numbers[child.last_action] = child.get_times_visited()
        return action_numbers

    def get_prior_probabilities(self, board_state):
        """Return ``(priors, value)`` from the raw network prediction for ``board_state``.

        ``priors`` is ``loss.softmax`` of the network's first output together with
        the game's legal-move vector (presumably masking illegal moves -- confirm
        in ``loss.softmax``); ``value`` is the network's second output unchanged.
        """
        pred = self.agent.predict(board_state)
        return loss.softmax(np.array(self.game.get_legal_NN_output()), pred[0]), pred[1]

    def get_posterior_probabilities(self):
        """Return the root's visit counts normalised to a distribution (visits / total).

        Entries are placed via ``Config.move_to_number``; all zeros if the root
        has no visited children.
        """
        actions = self.get_action_numbers(self.root)
        total = max(1, sum(actions.values()))
        post_prob = np.zeros(self.Config.policy_output_dim)
        for action, visits in actions.items():
            post_prob[self.Config.move_to_number(action)] = visits / total
        return post_prob

    def get_temperature_probabilities(self, node):
        """Return the unnormalised ``{action: visits ** (1 / T)}`` for ``node``."""
        inv_temp = 1 / self.T
        return {
            action: visits ** inv_temp
            for action, visits in self.get_action_numbers(node).items()
        }

    def get_temperature_move(self, node):
        """Sample an action with probability proportional to ``visits ** (1 / T)``.

        Raises:
            ValueError: if no action under ``node`` has been visited.
        """
        pi = self.get_temperature_probabilities(node)
        moves = list(pi)
        probs = np.array([pi[move] for move in moves])
        total = sum(probs)
        if total == 0:
            # Avoid 0/0, which would hand NaN probabilities to np.random.choice.
            raise ValueError("cannot sample a move: no visits recorded for this node")
        return np.random.choice(moves, p=probs / total)

    def get_most_searched_move(self, node):
        """Return the action with the most visits under ``node`` (0 if there are none).

        Ties go to the action that comes first in get_action_numbers' order.
        """
        actions = self.get_action_numbers(node)
        # max() keeps the first maximal key, matching a strict ">" scan.
        return max(actions, key=actions.get, default=0)

    def search_series(self, number):
        """Run ``number`` consecutive searches."""
        for _ in range(number):
            self.search()

    def search(self):
        """Run one MCTS iteration: selection, evaluation, expansion, backup."""
        game = self.game
        node = self.root
        # Selection: descend by maximal PUCT, playing each chosen move on the game.
        while not node.is_leaf_node():
            best_child = None
            best_puct = None
            for child in node.children:
                puct = self.PUCT(node, child)
                # ">=" keeps the tie-break of picking the last best child.
                if best_puct is None or puct >= best_puct:
                    best_child, best_puct = child, puct
            self.level += 1
            node = best_child
            game.execute_move(best_child.last_action)
        if game.is_final():
            # Terminal: back_propagate scores it from game.get_outcome(), so the
            # network is not consulted (its value would be discarded).
            self.back_propagate(node, 0)
        else:
            # Evaluation and expansion with the network's priors.
            raw_pred = self.agent.predict(np.array([game.get_board()]))
            priors = loss.softmax(np.array(game.get_legal_NN_output()), raw_pred[0])
            for move in game.get_moves():
                Node(game, node, move, priors[0][move])
            self.back_propagate(node, raw_pred[1][0][0])
        self.level = 0

    def back_propagate(self, node, t):
        """Back up a search result from ``node`` to the root, undoing moves on the way.

        Args:
            node: the node the search ended at; ``self.game`` must be in its state.
            t: value credited to ``node`` when the game is not over there. If the
                game is over and ``node`` has a parent, ``t`` is replaced by
                ``game.get_outcome()[node.parent.turn]``.

        The value is negated at each step up, and each move between ``node``
        and the root is undone on ``self.game``.
        """
        game = self.game
        while True:
            parent = node.parent
            if parent is not None and game.is_final():
                t = game.get_outcome()[parent.turn]
            # NOTE(review): a terminal node is credited from its parent's turn's
            # point of view, a non-terminal one with the raw value passed in;
            # confirm both follow the same perspective convention.
            node.t += t
            node.n += 1
            if parent is None:
                # Reached the root (possibly a terminal root): nothing to undo.
                return
            game.undo_move()
            node = parent
            t = -t

    def PUCT(self, node, child):
        """Return the PUCT score Q + U for selecting ``child`` from ``node``."""
        visits = child.n
        # node.n - 1 counts the backups that passed through node's children (the
        # first visit to node was its own expansion); floored at 1 so the
        # exploration term is non-zero on the first selection.
        children_visits = max(node.n - 1, 1)
        # NOTE(review): AlphaZero's published pseudocode uses
        # log((1 + N + c_base) / c_base) + c_init; here the log itself is divided
        # by C_PUCT. Confirm which is intended before retuning the constants.
        c_explore = math.log(1 + children_visits + C_PUCT) / C_PUCT + C_INIT
        u = c_explore * child.probability * math.sqrt(children_visits) / (1 + visits)
        q = child.t / max(visits, 1)
        return q + u