polish(pu): polish action_type and env_type, fix test.yml, fix unittest (#160)

* polish(pu): polish action_type and env_type for legal_action preprocessing

* polish(pu): use the latest clang in test

* fix(pu): use the latest clang in test on macOS

* fix(pu): use the latest clang in test on macOS

* fix(pu): use clang as c compiler in test on macOS

* test(pu): add test case of lunarlander_disc_gumbel_muzero_config when action_type is 'varied_action_space'

* fix(pu): fix unittest

* polish(pu): undo erroneous modifications in cartpole and lunarlander

* fix(pu): fix requirements.txt and typo in ucb_score of ptree_az

* fix(pu): fix unittest in test_ding_env_wrapper
puyuan1996 authored Dec 7, 2023
1 parent b6fd371 commit bbf371e
Showing 43 changed files with 123 additions and 38 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/test.yml
@@ -49,7 +49,13 @@ jobs:
shell: bash
run: |
brew install tree cloc wget curl make zip graphviz
brew install llvm # Install llvm (which includes clang)
echo 'export PATH="/usr/local/opt/llvm/bin:$PATH"' >> $GITHUB_ENV # update PATH
dot -V
- name: Set CC and CXX variables
run: |
echo "CC=$(which clang)" >> $GITHUB_ENV
echo "CXX=$(which clang++)" >> $GITHUB_ENV
- name: Set up python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
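The two steps above make Homebrew's LLVM clang the active C/C++ compiler for the rest of the job. A rough local equivalent in Python, assuming the Intel-mac Homebrew prefix (/usr/local; Apple Silicon installs under /opt/homebrew):

```python
# Hedged sketch: replicate the CI's compiler selection locally before
# building the C++ MCTS (ctree) extensions. The LLVM path is an assumption
# about the Homebrew prefix, not something read from the workflow.
import os
import shutil

llvm_bin = "/usr/local/opt/llvm/bin"
os.environ["PATH"] = llvm_bin + os.pathsep + os.environ["PATH"]
os.environ["CC"] = shutil.which("clang") or "clang"
os.environ["CXX"] = shutil.which("clang++") or "clang++"
# then e.g. `pip install -e .` to compile the extensions with this toolchain
```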
2 changes: 1 addition & 1 deletion lzero/envs/tests/test_ding_env_wrapper.py
@@ -26,6 +26,6 @@ def test(self):

obs = ding_env.reset()

assert isinstance(obs, np.ndarray)
assert isinstance(obs[0], np.ndarray)
action = ding_env.random_action()
print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space))
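The relaxed assertion follows the gymnasium-style reset API (see the requirements.txt change below): reset() returns an (observation, info) tuple, so the observation array sits at index 0. A minimal sketch of the behavior under test, using a stock gymnasium environment as a stand-in:

```python
# Illustrative only: under the gymnasium API, reset() returns a tuple,
# so the observation array is element 0 rather than the return value itself.
import gymnasium as gym
import numpy as np

env = gym.make("CartPole-v1")
obs_and_info = env.reset()
assert isinstance(obs_and_info, tuple)          # (observation, info)
assert isinstance(obs_and_info[0], np.ndarray)  # matches the updated test
```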
2 changes: 2 additions & 0 deletions lzero/mcts/buffer/game_buffer.py
@@ -53,6 +53,8 @@ def __init__(self, cfg: dict):
self._cfg = default_config
self._cfg = cfg
assert self._cfg.env_type in ['not_board_games', 'board_games']
assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']

self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
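env_type and action_type are validated as two independent switches: whether the environment is a two-player board game, and whether the set of legal actions varies from state to state. A minimal configuration sketch; EasyDict mirrors how LightZero configs are usually declared, and all other keys are omitted:

```python
# Hedged sketch of the two config switches the buffer validates;
# abbreviated, illustrative config only.
from easydict import EasyDict

cfg = EasyDict(
    env_type='board_games',             # ['not_board_games', 'board_games']
    action_type='varied_action_space',  # ['fixed_action_space', 'varied_action_space']
)
assert cfg.env_type in ['not_board_games', 'board_games']
assert cfg.action_type in ['fixed_action_space', 'varied_action_space']
```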
5 changes: 3 additions & 2 deletions lzero/mcts/buffer/game_buffer_efficientzero.py
@@ -30,6 +30,7 @@ def __init__(self, cfg: dict):
default_config.update(cfg)
self._cfg = default_config
assert self._cfg.env_type in ['not_board_games', 'board_games']
assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
@@ -406,7 +407,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
else:
if self._cfg.mcts_ctree:
# cpp mcts_tree
if self._cfg.env_type == 'not_board_games':
if self._cfg.action_type == 'fixed_action_space':
sum_visits = sum(distributions)
policy = [visit_count / sum_visits for visit_count in distributions]
target_policies.append(policy)
@@ -421,7 +422,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
target_policies.append(policy_tmp)
else:
# python mcts_tree
if self._cfg.env_type == 'not_board_games':
if self._cfg.action_type == 'fixed_action_space':
sum_visits = sum(distributions)
policy = [visit_count / sum_visits for visit_count in distributions]
target_policies.append(policy)
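Note the retargeted branch: what matters when turning MCTS visit counts into a target policy is not whether the game is a board game but whether every action always exists. With a fixed action space the counts normalize directly; with a varied one, counts exist only for the legal actions and must be scattered into a full-size vector. An illustrative sketch (names are not the buffer's actual locals):

```python
# Illustrative only -- not the buffer's actual code path.
import numpy as np

def visit_count_policy(distributions, action_space_size, legal_actions=None):
    total = sum(distributions)
    if legal_actions is None:
        # fixed_action_space: one visit count per action, normalize directly.
        return [v / total for v in distributions]
    # varied_action_space: counts cover only the legal actions, so scatter
    # the normalized values into a full-size policy vector.
    policy = np.zeros(action_space_size)
    for a, v in zip(legal_actions, distributions):
        policy[a] = v / total
    return policy.tolist()
```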
5 changes: 3 additions & 2 deletions lzero/mcts/buffer/game_buffer_gumbel_muzero.py
@@ -1,10 +1,11 @@
from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional
from typing import Any, Tuple

import numpy as np
from ding.utils import BUFFER_REGISTRY

from lzero.mcts.utils import prepare_observation
from lzero.mcts.buffer import MuZeroGameBuffer
from lzero.mcts.utils import prepare_observation


@BUFFER_REGISTRY.register('game_buffer_gumbel_muzero')
class GumbelMuZeroGameBuffer(MuZeroGameBuffer):
7 changes: 4 additions & 3 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -33,6 +33,7 @@ def __init__(self, cfg: dict):
default_config.update(cfg)
self._cfg = default_config
assert self._cfg.env_type in ['not_board_games', 'board_games']
assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
@@ -497,7 +498,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

# for board games
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
to_play_segment = policy_re_context # noqa
to_play_segment = policy_re_context
# transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
transition_batch_size = len(policy_obs_list)
game_segment_batch_size = len(pos_in_game_segment_list)
@@ -579,7 +580,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
)
else:
if self._cfg.env_type == 'not_board_games':
if self._cfg.action_type == 'fixed_action_space':
# for atari/classic_control/box2d environments that only have one player.
sum_visits = sum(distributions)
policy = [visit_count / sum_visits for visit_count in distributions]
@@ -657,7 +658,7 @@ def _compute_target_policy_non_reanalyzed(
policy_mask.append(1)
# NOTE: child_visit is already a distribution
distributions = child_visit[current_index]
if self._cfg.env_type == 'not_board_games':
if self._cfg.action_type == 'fixed_action_space':
# for atari/classic_control/box2d environments that only have one player.
target_policies.append(distributions)
else:
3 changes: 2 additions & 1 deletion lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
@@ -30,6 +30,7 @@ def __init__(self, cfg: dict):
default_config.update(cfg)
self._cfg = default_config
assert self._cfg.env_type in ['not_board_games', 'board_games']
assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
@@ -540,7 +541,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
)
)
else:
if self._cfg.env_type == 'not_board_games':
if self._cfg.action_type == 'fixed_action_space':
sum_visits = sum(distributions)
policy = [visit_count / sum_visits for visit_count in distributions]
target_policies.append(policy)
1 change: 1 addition & 0 deletions lzero/mcts/buffer/game_buffer_stochastic_muzero.py
@@ -26,6 +26,7 @@ def __init__(self, cfg: dict):
default_config.update(cfg)
self._cfg = default_config
assert self._cfg.env_type in ['not_board_games', 'board_games']
assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
self.replay_buffer_size = self._cfg.replay_buffer_size
self.batch_size = self._cfg.batch_size
self._alpha = self._cfg.priority_prob_alpha
14 changes: 8 additions & 6 deletions lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp
@@ -380,12 +380,14 @@ namespace tree
// After sorting, the first vector is the index, and the second vector is the probability value after perturbation sorted from large to small.
for (size_t iter = 0; iter < disturbed_probs.size(); iter++)
{
// #ifdef __APPLE__
// disc_action_with_probs.__emplace_back(std::make_pair(iter, disturbed_probs[iter]));
// #else
// disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
// #endif
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));

#ifdef __GNUC__
// Use push_back for GCC
disc_action_with_probs.push_back(std::make_pair(iter, disturbed_probs[iter]));
#else
// Use emplace_back for other compilers
disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
#endif
}

std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp);
2 changes: 1 addition & 1 deletion lzero/mcts/ptree/ptree_az.py
@@ -402,7 +402,7 @@ def _ucb_score(self, parent: Node, child: Node) -> float:
Overview:
Compute UCB score. The score for a node is based on its value, plus an exploration bonus based on the prior.
For more details, please refer to this paper: http://gauss.ececs.uc.edu/Workshops/isaim2010/papers/rosin.pdf
UCB = Q(s,a) + P(s,a) \cdot \frac{N(\text{parent})}{1+N(\text{child})} \cdot \left(c_1 + \log\left(\frac{N(\text{parent})+c_2+1}{c_2}\right)\right)
UCB = Q(s,a) + P(s,a) \cdot \frac{ \sqrt{N(\text{parent})}}{1+N(\text{child})} \cdot \left(c_1 + \log\left(\frac{N(\text{parent})+c_2+1}{c_2}\right)\right)
- Q(s,a): value of a child node.
- P(s,a): The prior of a child node.
- N(parent): The number of the visiting of the parent node.
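The docstring fix restores the square root over the parent visit count, matching the standard PUCT rule. A minimal sketch of the corrected score; the constants follow common MuZero defaults and are assumptions here, not values read from ptree_az:

```python
# Hedged sketch of the corrected UCB/PUCT formula from the docstring.
import math

def ucb_score(q_value, prior, parent_visits, child_visits,
              c1=1.25, c2=19652.0):
    # exploration term: c1 + log((N(parent) + c2 + 1) / c2)
    pb_c = c1 + math.log((parent_visits + c2 + 1) / c2)
    # the fix: sqrt(N(parent)) in the numerator, not N(parent)
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return q_value + prior * pb_c
```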
@@ -32,6 +32,9 @@
channel_last=True,
scale=True,
stop_value=1,
alphazero_mcts_ctree=False,
save_replay_gif=False,
replay_path_gif='./replay_gif',
),
policy=dict(
sampled_algo=False,
1 change: 1 addition & 0 deletions lzero/mcts/tests/test_game_buffer.py
@@ -14,6 +14,7 @@
replay_buffer_size=10000,
env_type='not_board_games',
use_priority=True,
action_type='fixed_action_space',
)
)

2 changes: 2 additions & 0 deletions lzero/policy/efficientzero.py
@@ -77,6 +77,8 @@ class EfficientZeroPolicy(MuZeroPolicy):
evaluator_env_num=3,
# (str) The type of environment. The options are ['not_board_games', 'board_games'].
env_type='not_board_games',
# (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space'].
action_type='fixed_action_space',
# (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode'].
battle_mode='play_with_bot_mode',
# (bool) Whether to monitor extra statistics in tensorboard.
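The default documents the expected pairing: single-player environments with a constant action set keep 'fixed_action_space', while board games whose legal moves change each turn override it, as the board-game configs further below do. A sketch of such an override (abbreviated, illustrative keys only):

```python
# Illustrative override only -- key names mirror the defaults above,
# the surrounding config structure is abbreviated.
policy = dict(
    env_type='board_games',
    action_type='varied_action_space',  # legal moves change every turn
    battle_mode='play_with_bot_mode',
)
```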
2 changes: 2 additions & 0 deletions lzero/policy/gumbel_muzero.py
@@ -78,6 +78,8 @@ class GumbelMuZeroPolicy(MuZeroPolicy):
evaluator_env_num=3,
# (str) The type of environment. Options are ['not_board_games', 'board_games'].
env_type='not_board_games',
# (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space'].
action_type='fixed_action_space',
# (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode'].
battle_mode='play_with_bot_mode',
# (bool) Whether to monitor extra statistics in tensorboard.
Expand Down
2 changes: 2 additions & 0 deletions lzero/policy/muzero.py
@@ -80,6 +80,8 @@ class MuZeroPolicy(Policy):
evaluator_env_num=3,
# (str) The type of environment. Options are ['not_board_games', 'board_games'].
env_type='not_board_games',
# (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space'].
action_type='fixed_action_space',
# (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode'].
battle_mode='play_with_bot_mode',
# (bool) Whether to monitor extra statistics in tensorboard.
2 changes: 2 additions & 0 deletions lzero/policy/sampled_efficientzero.py
@@ -84,6 +84,8 @@ class SampledEfficientZeroPolicy(MuZeroPolicy):
evaluator_env_num=3,
# (str) The type of environment. The options are ['not_board_games', 'board_games'].
env_type='not_board_games',
# (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space'].
action_type='fixed_action_space',
# (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode'].
battle_mode='play_with_bot_mode',
# (bool) Whether to monitor extra statistics in tensorboard.
2 changes: 2 additions & 0 deletions lzero/policy/stochastic_muzero.py
@@ -71,6 +71,8 @@ class StochasticMuZeroPolicy(MuZeroPolicy):
evaluator_env_num=3,
# (str) The type of environment. Options are ['not_board_games', 'board_games'].
env_type='not_board_games',
# (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space'].
action_type='fixed_action_space',
# (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode'].
battle_mode='play_with_bot_mode',
# (bool) Whether to monitor extra statistics in tensorboard.
3 changes: 1 addition & 2 deletions requirements.txt
@@ -1,5 +1,4 @@
DI-engine[common_env]>=0.4.7
gym[accept-rom-license]==0.25.1
DI-engine>=0.4.7
gymnasium[atari]
numpy>=1.22.4
pympler
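With the pinned gym release swapped for gymnasium[atari], Atari environments resolve through the ALE namespace. A minimal smoke test under that assumption; it presumes the Atari ROMs are installed and licensed:

```python
# Hedged sketch: verify the gymnasium-based Atari dependency resolves.
# The env id follows ale-py's registration convention.
import gymnasium as gym

env = gym.make("ALE/Pong-v5")
obs, info = env.reset(seed=0)
print(obs.shape, env.action_space)
```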
@@ -34,6 +34,7 @@
scale=True,
screen_scaling=9,
render_mode=None,
replay_path=None,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
@@ -52,6 +53,7 @@
),
cuda=True,
env_type='board_games',
action_type='varied_action_space',
update_per_collect=update_per_collect,
batch_size=batch_size,
optim_type='Adam',
@@ -11,6 +11,8 @@
batch_size = 256
max_env_step = int(1e6)
model_path = None
mcts_ctree = False

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
@@ -19,12 +21,31 @@
env=dict(
battle_mode='self_play_mode',
bot_action_type='rule',
channel_last=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
# ==============================================================
# for the creation of simulation env
agent_vs_human=False,
prob_random_agent=0,
prob_expert_agent=0,
prob_random_action_in_bot=0,
scale=True,
screen_scaling=9,
render_mode=None,
replay_path=None,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
policy=dict(
mcts_ctree=mcts_ctree,
# ==============================================================
# for the creation of simulation env
simulation_env_name='connect4',
simulation_env_config_type='self_play',
# ==============================================================
model=dict(
observation_shape=(3, 6, 7),
action_space_size=7,
@@ -33,6 +54,7 @@
),
cuda=True,
env_type='board_games',
action_type='varied_action_space',
update_per_collect=update_per_collect,
batch_size=batch_size,
optim_type='Adam',
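Connect4 is a natural 'varied_action_space' case: a column is only playable while it has room, so the legal-action set shrinks as the board fills. A small illustration; the board encoding here is assumed for the example, not taken from the env implementation:

```python
# Illustrative only: 6x7 board, 0 marks an empty cell.
import numpy as np

board = np.zeros((6, 7), dtype=np.int8)
board[:, 3] = 1                                        # column 3 is full
legal_actions = [c for c in range(7) if board[0, c] == 0]
print(legal_actions)                                   # [0, 1, 2, 4, 5, 6]
```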
@@ -40,6 +40,7 @@
),
cuda=True,
env_type='board_games',
action_type='varied_action_space',
game_segment_length=int(6 * 7 / 2), # for battle_mode='play_with_bot_mode'
update_per_collect=update_per_collect,
batch_size=batch_size,
@@ -40,6 +40,7 @@
),
cuda=True,
env_type='board_games',
action_type='varied_action_space',
game_segment_length=int(6 * 7), # for battle_mode='self_play_mode'
update_per_collect=update_per_collect,
batch_size=batch_size,
@@ -12,7 +12,7 @@
batch_size = 256
max_env_step = int(5e5)
prob_random_action_in_bot = 0.5
mcts_ctree = True
mcts_ctree = False
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
@@ -37,6 +37,7 @@
scale=True,
screen_scaling=9,
render_mode=None,
replay_path=None,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
@@ -37,6 +37,7 @@
scale=True,
screen_scaling=9,
render_mode=None,
replay_path=None,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
@@ -46,6 +46,7 @@
),
cuda=True,
env_type='board_games',
action_type='varied_action_space',
game_segment_length=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode'
update_per_collect=update_per_collect,
batch_size=batch_size,
@@ -46,6 +46,7 @@
),
cuda=True,
env_type='board_games',
action_type='varied_action_space',
game_segment_length=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode'
update_per_collect=update_per_collect,
batch_size=batch_size,
@@ -85,5 +86,4 @@

if __name__ == "__main__":
from lzero.entry import train_muzero

train_muzero([main_config, create_config], seed=0, max_env_step=max_env_step)
@@ -45,6 +45,7 @@
),
cuda=True,
env_type='board_games',
action_type='varied_action_space',
game_segment_length=int(board_size * board_size), # for battle_mode='self_play_mode'
update_per_collect=update_per_collect,
batch_size=batch_size,
@@ -43,10 +43,11 @@
scale=True,
check_action_to_connect4_in_bot_v0=False,
simulation_env_name="gomoku",
# ==============================================================
mcts_ctree=mcts_ctree,
screen_scaling=9,
render_mode=None,
replay_path=None,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
policy=dict(
# ==============================================================
@@ -39,10 +39,11 @@
scale=True,
check_action_to_connect4_in_bot_v0=False,
simulation_env_name="gomoku",
# ==============================================================
mcts_ctree=mcts_ctree,
screen_scaling=9,
render_mode=None,
replay_path=None,
alphazero_mcts_ctree=mcts_ctree,
# ==============================================================
),
policy=dict(
# ==============================================================