diff --git a/lzero/agent/__init__.py b/lzero/agent/__init__.py
new file mode 100644
index 000000000..a6cbb38df
--- /dev/null
+++ b/lzero/agent/__init__.py
@@ -0,0 +1 @@
+from .muzero import MuZeroAgent
diff --git a/lzero/agent/config/__init__.py b/lzero/agent/config/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/lzero/agent/config/muzero/__init__.py b/lzero/agent/config/muzero/__init__.py
new file mode 100644
index 000000000..e8b937645
--- /dev/null
+++ b/lzero/agent/config/muzero/__init__.py
@@ -0,0 +1,8 @@
+from easydict import EasyDict
+from . import gym_cartpole_v0
+
+supported_env_cfg = {
+    gym_cartpole_v0.cfg.main_config.env.env_id: gym_cartpole_v0.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
diff --git a/lzero/agent/config/muzero/gym_cartpole_v0.py b/lzero/agent/config/muzero/gym_cartpole_v0.py
new file mode 100644
index 000000000..c46a67a0d
--- /dev/null
+++ b/lzero/agent/config/muzero/gym_cartpole_v0.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 25
+update_per_collect = 100
+batch_size = 256
+max_env_step = int(1e5)
+reanalyze_ratio = 0
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+cfg = dict(
+    main_config=dict(
+        exp_name='CartPole-v0-MuZero',
+        seed=0,
+        env=dict(
+            env_id='CartPole-v0',
+            continuous=False,
+            manually_discretization=False,
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+            n_evaluator_episode=evaluator_env_num,
+            manager=dict(shared_memory=False, ),
+        ),
+        policy=dict(
+            model=dict(
+                observation_shape=4,
+                action_space_size=2,
+                model_type='mlp',
+                lstm_hidden_size=128,
+                latent_state_dim=128,
+                self_supervised_learning_loss=True,  # NOTE: default is False.
+                discrete_action_encoding_type='one_hot',
+                norm_type='BN',
+            ),
+            cuda=True,
+            env_type='not_board_games',
+            game_segment_length=50,
+            update_per_collect=update_per_collect,
+            batch_size=batch_size,
+            optim_type='Adam',
+            lr_piecewise_constant_decay=False,
+            learning_rate=0.003,
+            ssl_loss_weight=2,  # NOTE: default is 0.
+            num_simulations=num_simulations,
+            reanalyze_ratio=reanalyze_ratio,
+            n_episode=n_episode,
+            eval_freq=int(2e2),
+            replay_buffer_size=int(1e6),  # The size/capacity of the replay buffer, in terms of transitions.
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+        ),
+        wandb_logger=dict(
+            gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False
+        ),
+    ),
+    create_config=dict(
+        env=dict(
+            type='cartpole_lightzero',
+            import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'],
+        ),
+        env_manager=dict(type='subprocess'),
+        policy=dict(
+            type='muzero',
+            import_names=['lzero.policy.muzero'],
+        ),
+    ),
+)
+
+cfg = EasyDict(cfg)
diff --git a/lzero/agent/muzero.py b/lzero/agent/muzero.py
new file mode 100644
index 000000000..4ddb436ae
--- /dev/null
+++ b/lzero/agent/muzero.py
@@ -0,0 +1,422 @@
+import os
+from functools import partial
+from typing import Optional, Union, List
+
+import numpy as np
+import torch
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config import save_config_py, compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import set_pkg_seed, get_rank
+from ding.worker import BaseLearner
+from ditk import logging
+from easydict import EasyDict
+from tensorboardX import SummaryWriter
+
+from lzero.agent.config.muzero import supported_env_cfg
+from lzero.entry.utils import log_buffer_memory_usage, random_collect
+from lzero.mcts import MuZeroGameBuffer
+from lzero.policy import visit_count_temperature
+from lzero.policy.muzero import MuZeroPolicy
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.worker import MuZeroCollector as Collector
+from lzero.worker import MuZeroEvaluator as Evaluator
+
+
+class MuZeroAgent:
+    """
+    Overview:
+        Agent class for executing MuZero algorithms, providing methods for training, deployment, and batch evaluation.
+    Interfaces:
+        __init__, train, deploy, batch_evaluate
+    Properties:
+        best
+
+    .. note::
+        This agent class is tailored for use with the HuggingFace Model Zoo for LightZero
+        (e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-MuZero),
+        and provides methods such as "train" and "deploy".
+    """
+
+    supported_env_list = list(supported_env_cfg.keys())
+
+    def __init__(
+        self,
+        env_id: str = None,
+        seed: int = 0,
+        exp_name: str = None,
+        model: Optional[torch.nn.Module] = None,
+        cfg: Optional[Union[EasyDict, dict]] = None,
+        policy_state_dict: str = None,
+    ) -> None:
+        """
+        Overview:
+            Initialize the MuZeroAgent instance with environment parameters, model, and configuration.
+        Arguments:
+            - env_id (:obj:`str`): Identifier for the environment to be used, registered in gym.
+            - seed (:obj:`int`): Random seed for reproducibility. Defaults to 0.
+            - exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None.
+            - model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None.
+            - cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, the default configuration will be used. Defaults to None.
+            - policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, the state dict will be loaded. Defaults to None.
+
+        .. note::
+            - If `env_id` is not specified, it must be included in `cfg`.
+            - The `supported_env_list` contains all the environment IDs that are supported by this agent.
+        """
+        assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+        if cfg is not None and not isinstance(cfg, EasyDict):
+            cfg = EasyDict(cfg)
+
+        if env_id is not None:
+            assert env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format(
+                MuZeroAgent.supported_env_list
+            )
+            if cfg is None:
+                cfg = supported_env_cfg[env_id]
+            else:
+                assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+        else:
+            assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg."
+            assert cfg.main_config.env.env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format(
+                MuZeroAgent.supported_env_list
+            )
+        default_policy_config = EasyDict({"policy": MuZeroPolicy.default_config()})
+        default_policy_config.policy.update(cfg.main_config.policy)
+        cfg.main_config.policy = default_policy_config.policy
+
+        if exp_name is not None:
+            cfg.main_config.exp_name = exp_name
+        self.origin_cfg = cfg
+        self.cfg = compile_config(
+            cfg.main_config, seed=seed, env=None, auto=True, policy=MuZeroPolicy, create_cfg=cfg.create_config
+        )
+        self.exp_name = self.cfg.exp_name
+
+        logging.getLogger().setLevel(logging.INFO)
+        self.seed = seed
+        set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+        if not os.path.exists(self.exp_name):
+            os.makedirs(self.exp_name)
+        save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py'))
+        if model is None:
+            if self.cfg.policy.model.model_type == 'mlp':
+                from lzero.model.muzero_model_mlp import MuZeroModelMLP
+                model = MuZeroModelMLP(**self.cfg.policy.model)
+            elif self.cfg.policy.model.model_type == 'conv':
+                from lzero.model.muzero_model import MuZeroModel
+                model = MuZeroModel(**self.cfg.policy.model)
+            else:
+                raise NotImplementedError
+        if self.cfg.policy.cuda and torch.cuda.is_available():
+            self.cfg.policy.device = 'cuda'
+        else:
+            self.cfg.policy.device = 'cpu'
+        self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+        if policy_state_dict is not None:
+            self.policy.learn_mode.load_state_dict(policy_state_dict)
+        self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+        self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env)
+
+    def train(
+        self,
+        step: int = int(1e7),
+    ) -> TrainingReturn:
+        """
+        Overview:
+            Train the agent through interactions with the environment.
+        Arguments:
+            - step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7).
+        Returns:
+            - A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard.
+        .. note::
+            The method involves interacting with the environment, collecting experience, and optimizing the model.
+        """
+
+        collector_env = create_env_manager(
+            self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg]
+        )
+        evaluator_env = create_env_manager(
+            self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg]
+        )
+
+        collector_env.seed(self.cfg.seed)
+        evaluator_env.seed(self.cfg.seed, dynamic_seed=False)
+        set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda)
+
+        # Create worker components: learner, collector, evaluator, replay buffer.
+        tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial')
+                                  ) if get_rank() == 0 else None
+        learner = BaseLearner(
+            self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name
+        )
+
+        # ==============================================================
+        # MCTS+RL algorithms related core code
+        # ==============================================================
+        policy_config = self.cfg.policy
+        batch_size = policy_config.batch_size
+        # specific game buffer for MCTS+RL algorithms
+        replay_buffer = MuZeroGameBuffer(policy_config)
+        collector = Collector(
+            env=collector_env,
+            policy=self.policy.collect_mode,
+            tb_logger=tb_logger,
+            exp_name=self.cfg.exp_name,
+            policy_config=policy_config
+        )
+        evaluator = Evaluator(
+            eval_freq=self.cfg.policy.eval_freq,
+            n_evaluator_episode=self.cfg.env.n_evaluator_episode,
+            stop_value=self.cfg.env.stop_value,
+            env=evaluator_env,
+            policy=self.policy.eval_mode,
+            tb_logger=tb_logger,
+            exp_name=self.cfg.exp_name,
+            policy_config=policy_config
+        )
+
+        # ==============================================================
+        # Main loop
+        # ==============================================================
+        # Learner's before_run hook.
+        learner.call_hook('before_run')
+
+        if self.cfg.policy.update_per_collect is not None:
+            update_per_collect = self.cfg.policy.update_per_collect
+
+        # The purpose of collecting random data before training:
+        # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
+        # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
+        if self.cfg.policy.random_collect_episode_num > 0:
+            random_collect(self.cfg.policy, self.policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+
+        while True:
+            log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+            collect_kwargs = {}
+            # Set the temperature for visit count distributions according to the train_iter;
+            # please refer to Appendix D in the MuZero paper for details.
+            collect_kwargs['temperature'] = visit_count_temperature(
+                policy_config.manual_temperature_decay,
+                policy_config.fixed_temperature_value,
+                policy_config.threshold_training_steps_for_final_temperature,
+                trained_steps=learner.train_iter
+            )
+
+            if policy_config.eps.eps_greedy_exploration_in_collect:
+                epsilon_greedy_fn = get_epsilon_greedy_fn(
+                    start=policy_config.eps.start,
+                    end=policy_config.eps.end,
+                    decay=policy_config.eps.decay,
+                    type_=policy_config.eps.type
+                )
+                collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+            else:
+                collect_kwargs['epsilon'] = 0.0
+
+            # Evaluate policy performance.
+            if evaluator.should_eval(learner.train_iter):
+                stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+                if stop:
+                    break
+
+            # Collect data by default config n_sample/n_episode.
+            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+            if self.cfg.policy.update_per_collect is None:
+                # If update_per_collect is None, it is set to the number of collected transitions multiplied by model_update_ratio.
+                collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
+                update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+            # Save the new data returned by the collector.
+            replay_buffer.push_game_segments(new_data)
+            # Remove the oldest data if the replay buffer is full.
+            replay_buffer.remove_oldest_data_to_fit()
+
+            # Learn policy from collected data.
+            for i in range(update_per_collect):
+                # Learner will train ``update_per_collect`` times in one iteration.
+                if replay_buffer.get_num_of_transitions() > batch_size:
+                    train_data = replay_buffer.sample(batch_size, self.policy)
+                else:
+                    logging.warning(
+                        f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                        f'batch_size: {batch_size}, '
+                        f'{replay_buffer} '
+                        f'continue to collect now ....'
+                    )
+                    break
+
+                # The core train steps for MCTS+RL algorithms.
+                log_vars = learner.train(train_data, collector.envstep)
+
+                if self.cfg.policy.use_priority:
+                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+            if collector.envstep >= step:
+                break
+
+        # Learner's after_run hook.
+        learner.call_hook('after_run')
+
+        return TrainingReturn(wandb_url=None)
+
+    def deploy(
+        self,
+        enable_save_replay: bool = False,
+        concatenate_all_replay: bool = False,
+        replay_save_path: str = None,
+        seed: Optional[Union[int, List]] = None,
+        debug: bool = False
+    ) -> EvalReturn:
+        """
+        Overview:
+            Deploy the agent for evaluation in the environment, with optional replay saving. The performance of the
+            agent will be evaluated, and the average return and standard deviation of the return will be returned.
+            If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`.
+        Arguments:
+            - enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False.
+            - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False.
+            - replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path.
+            - seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None.
+            - debug (:obj:`bool`): Whether to enable the debug mode. Defaults to False.
+        Returns:
+            - An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns.
+        """
+
+        deploy_configs = [self.evaluator_env_cfg[0]]
+
+        if isinstance(seed, int):
+            seed_list = [seed]
+        elif seed:
+            seed_list = seed
+        else:
+            seed_list = [0]
+
+        reward_list = []
+
+        if enable_save_replay:
+            replay_save_path = replay_save_path if replay_save_path is not None else os.path.join(
+                self.exp_name, 'videos'
+            )
+            deploy_configs[0]['replay_path'] = replay_save_path
+
+        for seed in seed_list:
+
+            evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deploy_configs[0])])
+
+            evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False)
+            set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda)
+
+            # ==============================================================
+            # MCTS+RL algorithms related core code
+            # ==============================================================
+            policy_config = self.cfg.policy
+
+            evaluator = Evaluator(
+                eval_freq=self.cfg.policy.eval_freq,
+                n_evaluator_episode=1,
+                stop_value=self.cfg.env.stop_value,
+                env=evaluator_env,
+                policy=self.policy.eval_mode,
+                exp_name=self.cfg.exp_name,
+                policy_config=policy_config
+            )
+
+            # ==============================================================
+            # Main loop
+            # ==============================================================
+
+            stop, reward = evaluator.eval()
+            reward_list.extend(reward['eval_episode_return'])
+
+        if enable_save_replay:
+            files = os.listdir(replay_save_path)
+            files = [file for file in files if file.endswith('0.mp4')]
+            files.sort()
+            if concatenate_all_replay:
+                # Create a file named 'files.txt' to store the names of all mp4 files.
+                with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f:
+                    for file in files:
+                        f.write("file '{}'\n".format(file))
+
+                # Combine all the mp4 files into one file named 'deploy.mp4'.
+                os.system(
+                    'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format(
+                        os.path.join(replay_save_path, 'files.txt'), replay_save_path
+                    )
+                )
+
+        return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list))
+
+    def batch_evaluate(
+        self,
+        n_evaluator_episode: int = None,
+    ) -> EvalReturn:
+        """
+        Overview:
+            Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``.
+        Arguments:
+            - n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation.
+                If None, uses the default value from the configuration. Defaults to None.
+        Returns:
+            - An `EvalReturn` object with evaluation results such as mean and standard deviation of returns.
+
+        .. note::
+            This method evaluates the agent's performance across multiple episodes to gauge its effectiveness.
+        """
+        evaluator_env = create_env_manager(
+            self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg]
+        )
+
+        evaluator_env.seed(self.cfg.seed, dynamic_seed=False)
+        set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda)
+
+        # ==============================================================
+        # MCTS+RL algorithms related core code
+        # ==============================================================
+        policy_config = self.cfg.policy
+
+        evaluator = Evaluator(
+            eval_freq=self.cfg.policy.eval_freq,
+            n_evaluator_episode=self.cfg.env.n_evaluator_episode
+            if n_evaluator_episode is None else n_evaluator_episode,
+            stop_value=self.cfg.env.stop_value,
+            env=evaluator_env,
+            policy=self.policy.eval_mode,
+            exp_name=self.cfg.exp_name,
+            policy_config=policy_config
+        )
+
+        # ==============================================================
+        # Main loop
+        # ==============================================================
+
+        stop, reward = evaluator.eval()
+
+        return EvalReturn(
+            eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return'])
+        )
+
+    @property
+    def best(self):
+        """
+        Overview:
+            Provides access to the best model according to evaluation metrics.
+        Returns:
+            - The agent with the best model loaded.
+
+        .. note::
+            The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`.
+            When this property is accessed, the agent instance will load the best model state.
+        """
+
+        best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar")
+        # Load best model if it exists
+        if os.path.exists(best_model_file_path):
+            policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+            self.policy.learn_mode.load_state_dict(policy_state_dict)
+        return self
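For reviewers, a minimal usage sketch of the `MuZeroAgent` API added in this PR (not part of the diff). The experiment name and step budget below are illustrative placeholders; `CartPole-v0` is the only environment registered in `supported_env_cfg` here, and the script assumes LightZero and its dependencies are installed.

```python
from lzero.agent import MuZeroAgent

# Instantiate the agent for the only currently supported environment.
agent = MuZeroAgent(env_id='CartPole-v0', seed=0, exp_name='CartPole-v0-MuZero-demo')

# Train for a limited number of environment steps (the default is 1e7).
agent.train(step=int(1e5))

# Reload the best checkpoint (./<exp_name>/ckpt/ckpt_best.pth.tar) via the `best`
# property and evaluate it, saving replay videos under ./<exp_name>/videos.
result = agent.best.deploy(enable_save_replay=True)
print(result.eval_value, result.eval_value_std)

# Alternatively, run a batch evaluation over the configured number of evaluator episodes.
batch_result = agent.batch_evaluate()
print(batch_result.eval_value, batch_result.eval_value_std)
```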