-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmuzero_cli.py
213 lines (178 loc) · 15.6 KB
/
muzero_cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import json
import sys

import gymnasium as gym
import numpy as np
import torch

from monte_carlo_tree_search import *
from game import *
from replay_buffer import *
from muzero_model import *
from self_play import *
# Usage examples appended to every CLI validation error.
_USAGE_EXAMPLES = (
    "Example :\n"
    "    python muzero_cli.py train config/config_file.json\n"
    "    python muzero_cli.py train report config/config_file.json\n"
    "    python muzero_cli.py train report play config/config_file.json\n"
    "    python muzero_cli.py train play config/config_file.json\n"
    "    python muzero_cli.py play config/config_file.json\n"
    "    python muzero_cli.py benchmark config/config_file.json"
)

# Mode keywords recognised on the command line.
_MODES = ("train", "play", "report", "benchmark")

# Number of games played from the checkpoint in benchmark mode.
_BENCHMARK_TRIALS = 100


def _parse_cli(cli_input):
    """Split raw CLI arguments into the config file path and the requested modes.

    The first argument containing ``config`` is taken as the JSON config path.
    Every other argument is lower-cased and searched for the mode keywords
    ``train``, ``play``, ``report`` and ``benchmark``.

    Args:
        cli_input: list of argument strings (typically ``sys.argv``).

    Returns:
        ``(config_path, modes)`` where ``config_path`` is a str or ``None`` and
        ``modes`` is the set of mode keywords found.
    """
    config_candidates = [arg for arg in cli_input if "config" in arg]
    config_path = config_candidates[0] if config_candidates else None
    # Modes are matched only on non-config arguments, so a path such as
    # "config/play_settings.json" cannot switch a mode on by accident.
    mode_args = [arg.lower() for arg in cli_input if "config" not in arg]
    modes = {mode for mode in _MODES if any(mode in arg for arg in mode_args)}
    return config_path, modes


def _run_training(config, compute_type):
    """Build env, model, replay buffer, MCTS and gameplay from ``config``, then train.

    Returns:
        ``(muzero, replay_buffer, epoch_pr, loss, reward)``, which the report
        step consumes.
    """
    # render_mode should be None if no rgb observation is wanted,
    # else "human" or "rgb_array" ("human" for atari games).
    env = gym.make(config["game"]["env"], render_mode=config["game"]["render"])

    # Fixed seeds for reproducibility, see
    # https://pytorch.org/docs/stable/notes/randomness.html
    np.random.seed(config["random_seed"]["np_random_seed"])
    torch.manual_seed(config["random_seed"]["torch_manual_seed"])

    muzero_config = config["muzero"]
    muzero = Muzero(
        model_structure=muzero_config["model_structure"],  # 'vision_model': rgb observation, 'mlp_model': game-state observation
        observation_space_dimensions=env.observation_space,
        action_space_dimensions=env.action_space,  # gym Box / Discrete action space
        state_space_dimensions=muzero_config["state_space_dimensions"],  # support size / encoding space
        hidden_layer_dimensions=muzero_config["hidden_layer_dimensions"],  # width of the mlp hidden layers
        number_of_hidden_layer=muzero_config["number_of_hidden_layer"],
        k_hypothetical_steps=muzero_config["k_hypothetical_steps"],  # future steps simulated during training
        optimizer=muzero_config["optimizer"],
        lr_scheduler=muzero_config["lr_scheduler"],
        learning_rate=muzero_config["learning_rate"],
        loss_type=muzero_config["loss_type"],
        num_of_epoch=muzero_config["num_of_epoch"],
        device=muzero_config["device"],  # "cpu", "cuda:0", "cuda:1", ...
        type_format=compute_type,
        load=muzero_config["load"],  # load a saved model
        use_amp=muzero_config["use_amp"],  # mixed precision for gpu (not implemented yet)
        scaler_on=muzero_config["scaler_on"],  # gradient scaling
        bin_method=muzero_config["bin_method"],  # "linear_bin" (regular steps) or "uniform_bin" (random sampling in bounds)
        bin_decomposition_number=muzero_config["bin_decomposition_number"],  # number of actions sampled from a gym Box
    )

    buffer_config = config["replaybuffer"]
    replay_buffer = ReplayBuffer(
        window_size=buffer_config["window_size"],  # number of games kept in the buffer
        batch_size=buffer_config["batch_size"],
        num_unroll=muzero.k_hypothetical_steps,
        td_steps=buffer_config["td_steps"],
        game_sampling=buffer_config["game_sampling"],  # 'uniform' or 'priority'
        position_sampling=buffer_config["position_sampling"],  # 'uniform' or 'priority'
    )

    mcts_config = config["monte_carlo_tree_search"]
    mcts = Monte_carlo_tree_search(
        pb_c_base=mcts_config["pb_c_base"],
        pb_c_init=mcts_config["pb_c_init"],
        discount=mcts_config["discount"],
        root_dirichlet_alpha=mcts_config["root_dirichlet_alpha"],
        root_exploration_fraction=mcts_config["root_exploration_fraction"],
    )

    gameplay = Game(
        gym_env=env,
        discount=config["gameplay"]["discount"],  # should match the mcts discount
        limit_of_game_play=config["gameplay"]["limit_of_game_play"],  # maximum number of moves
        observation_dimension=muzero.observation_dimension,
        action_dimension=muzero.action_dimension,
        rgb_observation=muzero.is_RGB,
        action_map=muzero.action_dictionnary,
    )

    cycle_config = config["learning_cycle"]
    epoch_pr, loss, reward = learning_cycle(
        number_of_iteration=cycle_config["number_of_iteration"],
        number_of_self_play_before_training=cycle_config["number_of_self_play_before_training"],
        number_of_training_before_self_play=cycle_config["number_of_training_before_self_play"],
        number_of_mcts_simulation=cycle_config["number_of_mcts_simulation"],
        model_tag_number=cycle_config["model_tag_number"],  # tag used to name checkpoints
        number_of_worker_selfplay=cycle_config["number_of_worker_selfplay"],
        # The misspelt key/keyword is the learning_cycle interface; keep it as is.
        tempererature_type=cycle_config["tempererature_type"],
        verbose=cycle_config["verbose"],
        muzero_model=muzero,
        gameplay=gameplay,
        monte_carlo_tree_search=mcts,
        replay_buffer=replay_buffer,
    )
    return muzero, replay_buffer, epoch_pr, loss, reward


def _checkpoint_kwargs(config, compute_type):
    """Return the play_game_from_checkpoint keyword arguments shared by play and benchmark."""
    checkpoint_config = config["play_game_from_checkpoint"]
    mcts_config = config["monte_carlo_tree_search"]
    return {
        "game_to_play": config["game"]["env"],
        "model_tag": checkpoint_config["model_tag"],
        "model_device": checkpoint_config["model_device"],
        "model_type": compute_type,
        "mcts_pb_c_base": mcts_config["pb_c_base"],
        "mcts_pb_c_init": mcts_config["pb_c_init"],
        "mcts_discount": mcts_config["discount"],
        "mcts_root_dirichlet_alpha": mcts_config["root_dirichlet_alpha"],
        "mcts_root_exploration_fraction": mcts_config["root_exploration_fraction"],
        "mcts_with_or_without_dirichlet_noise": checkpoint_config["mcts_with_or_without_dirichlet_noise"],
        "number_of_monte_carlo_tree_search_simulation": checkpoint_config["number_of_monte_carlo_tree_search_simulation"],
        "gameplay_discount": config["gameplay"]["discount"],
        "temperature": checkpoint_config["temperature"],
        "game_iter": checkpoint_config["game_iter"],
    }


def _run_play(config, compute_type):
    """Play from a saved checkpoint using the render/verbosity settings in ``config``."""
    checkpoint_config = config["play_game_from_checkpoint"]
    play_game_from_checkpoint(
        **_checkpoint_kwargs(config, compute_type),
        slow_mo_in_second=checkpoint_config["slow_mo_in_second"],
        render=checkpoint_config["render"],
        verbose=checkpoint_config["verbose"],
    )


def _run_benchmark(config, compute_type, number_of_trial=_BENCHMARK_TRIALS):
    """Play ``number_of_trial`` silent games from the checkpoint and pass the results to ``benchmark``."""
    shared_kwargs = _checkpoint_kwargs(config, compute_type)
    cache_t, cache_r, cache_a, cache_p = [], [], [], []
    for _ in range(number_of_trial):
        # benchmark=True is needed for play_game_from_checkpoint to return its outputs.
        tag, reward, action, policy = play_game_from_checkpoint(
            **shared_kwargs,
            slow_mo_in_second=0,
            render=False,
            verbose=False,
            benchmark=True,
        )
        cache_t.append(tag)
        cache_r.append(reward)
        cache_a.append(action)
        cache_p.append(policy)
    benchmark(cache_t, cache_r, cache_a, cache_p, folder="report", verbose=True)


def main(cli_input):
    """Command-line entry point: parse arguments, load the JSON config, run the modes.

    Args:
        cli_input: raw argument list, typically ``sys.argv`` (script name
            included). One argument containing ``config`` is the JSON config
            path; the others select any combination of ``train``, ``report``,
            ``play`` and ``benchmark`` (case-insensitive).

    Raises:
        ValueError: if no config path or no mode is given.
        OSError: if the config file cannot be opened.
        json.JSONDecodeError: if the config file is not valid JSON.
    """
    config_path, modes = _parse_cli(cli_input)
    if config_path is None:
        raise ValueError(
            "Specify a config directory and file such as: config/config_file.json\n"
            + _USAGE_EXAMPLES
        )
    if not modes:
        raise ValueError(
            "Specify a mode such as: train, train report, play, benchmark "
            "or any combination of these\n" + _USAGE_EXAMPLES
        )

    with open(config_path, "r", encoding="utf-8") as openfile:
        config = json.load(openfile)

    # dtype used for training / inference / benchmark.
    # TODO: make it selectable from the config as a string option.
    compute_type = torch.float32

    if "train" in modes:
        print("Start the training cycle...")
        muzero, replay_buffer, epoch_pr, loss, reward = _run_training(config, compute_type)
        print("Training end.")
        if "report" in modes:
            print("Creating report...")
            report(muzero, replay_buffer, epoch_pr, loss, reward)
            print("Report created")
    elif "report" in modes:
        # The report is built from the training outputs, so it cannot run alone.
        print("Report skipped: the report mode needs the train mode in the same command.")

    if "play" in modes:
        print("Start play...")
        _run_play(config, compute_type)
        print("End play")

    if "benchmark" in modes:
        print("Start benchmark...")
        _run_benchmark(config, compute_type)
        print("End benchmark")
if __name__ == "__main__":
    # Hand a copy of the full argv (script name included) to the CLI dispatcher.
    main(list(sys.argv))