diff --git a/README.md b/README.md index 7274cffbdb..d3ee04bea8 100644 --- a/README.md +++ b/README.md @@ -145,7 +145,7 @@ All the following examples can be executed online using Google colab notebooks: | ------------------- | ---------------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- | | A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | ACER | :heavy_check_mark: | :heavy_check_mark: | :x: (5) | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | -| ACKTR | :heavy_check_mark: | :x: | :x: (5) | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | +| ACKTR | :heavy_check_mark: | :heavy_check_mark: | :x: (5) | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | DDPG | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | | DQN | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :x: | :x: | | GAIL (2) | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: (4) | diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a0ddc324d6..0135b884fe 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -9,8 +9,14 @@ Pre-Release 2.4.1a (WIP) -------------------------- - fixed computation of training metrics in TRPO and PPO1 +- added ``reset_num_timesteps`` keyword when calling ``learn()`` to continue tensorboard learning curves +- reduced the size taken by tensorboard logs (added a ``full_tensorboard_log`` parameter to enable full logging, which was the previous behavior) +- fixed image detection for tensorboard logging +- fixed ACKTR for recurrent policies +- fixed breaking changes introduced by newer gym versions (e.g. removal of ``gym.spaces.prng``) - fixed custom policy examples in the doc for DQN and DDPG - remove gym spaces patch for equality functions +- fixed tensorflow dependency: the CPU version was installed, overwriting tensorflow-gpu when present Release 2.4.0 (2019-01-17) diff --git a/setup.py b/setup.py index 8937299b56..d276f6d039 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,35 @@ -from setuptools import setup, find_packages import sys +import subprocess +from setuptools import setup, find_packages +from distutils.version import LooseVersion if sys.version_info.major != 3: print('This Python is only compatible with Python 3, but you are running ' 'Python {}. 
The installation will likely fail.'.format(sys.version_info.major)) +# Check tensorflow installation to avoid +# breaking pre-installed tf gpu +install_tf, tf_gpu = False, False +try: + import tensorflow as tf + if tf.__version__ < LooseVersion('1.5.0'): + install_tf = True + # check if a gpu version is needed + tf_gpu = tf.test.is_gpu_available() +except ImportError: + install_tf = True + # Check if a nvidia gpu is present + for command in ['nvidia-smi', '/usr/bin/nvidia-smi', 'nvidia-smi.exe']: + if subprocess.call([command]) == 0: + tf_gpu = True + break + +tf_dependency = [] +if install_tf: + tf_dependency = ['tensorflow-gpu>=1.5.0'] if tf_gpu else ['tensorflow>=1.5.0'] + if tf_gpu: + print("A GPU was detected, tensorflow-gpu will be installed") + long_description = """ [![Build Status](https://travis-ci.com/hill-a/stable-baselines.svg?branch=master)](https://travis-ci.com/hill-a/stable-baselines) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines.readthedocs.io/en/master/?badge=master) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/3bcb4cd6d76a4270acb16b5fe6dd9efa)](https://www.codacy.com/app/baselines_janitors/stable-baselines?utm_source=github.com&utm_medium=referral&utm_content=hill-a/stable-baselines&utm_campaign=Badge_Grade) [![Codacy Badge](https://api.codacy.com/project/badge/Coverage/3bcb4cd6d76a4270acb16b5fe6dd9efa)](https://www.codacy.com/app/baselines_janitors/stable-baselines?utm_source=github.com&utm_medium=referral&utm_content=hill-a/stable-baselines&utm_campaign=Badge_Coverage) @@ -83,7 +108,6 @@ 'progressbar2', 'mpi4py', 'cloudpickle>=0.5.5', - 'tensorflow>=1.5.0', 'click', 'opencv-python', 'numpy', @@ -91,7 +115,7 @@ 'matplotlib', 'seaborn', 'glob2' - ], + ] + tf_dependency, extras_require={ 'tests': [ 'pytest==3.5.1', @@ -112,7 +136,7 @@ license="MIT", long_description=long_description, long_description_content_type='text/markdown', - version="2.4.1a", + version="2.4.1a0", ) # python setup.py sdist diff --git a/stable_baselines/__init__.py b/stable_baselines/__init__.py index 00d6700556..5cbf3d7a3e 100644 --- a/stable_baselines/__init__.py +++ b/stable_baselines/__init__.py @@ -9,4 +9,4 @@ from stable_baselines.trpo_mpi import TRPO from stable_baselines.sac import SAC -__version__ = "2.4.1a" +__version__ = "2.4.1a0" diff --git a/stable_baselines/a2c/a2c.py b/stable_baselines/a2c/a2c.py index e54a157c00..fe41af4040 100644 --- a/stable_baselines/a2c/a2c.py +++ b/stable_baselines/a2c/a2c.py @@ -35,11 +35,13 @@ class A2C(ActorCriticRLModel): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance (used only for loading) :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly """ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.01, max_grad_norm=0.5, learning_rate=7e-4, alpha=0.99, epsilon=1e-5, lr_schedule='linear', verbose=0, tensorboard_log=None, - _init_setup_model=True, policy_kwargs=None): + _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): super(A2C, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True, _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) @@ -54,6 +56,7 @@ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0. 
self.lr_schedule = lr_schedule self.learning_rate = learning_rate self.tensorboard_log = tensorboard_log + self.full_tensorboard_log = full_tensorboard_log self.graph = None self.sess = None @@ -132,15 +135,16 @@ def setup_model(self): with tf.variable_scope("input_info", reuse=False): tf.summary.scalar('discounted_rewards', tf.reduce_mean(self.rewards_ph)) - tf.summary.histogram('discounted_rewards', self.rewards_ph) tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate)) - tf.summary.histogram('learning_rate', self.learning_rate) tf.summary.scalar('advantage', tf.reduce_mean(self.advs_ph)) - tf.summary.histogram('advantage', self.advs_ph) - if len(self.observation_space.shape) == 3: - tf.summary.image('observation', train_model.obs_ph) - else: - tf.summary.histogram('observation', train_model.obs_ph) + if self.full_tensorboard_log: + tf.summary.histogram('discounted_rewards', self.rewards_ph) + tf.summary.histogram('learning_rate', self.learning_rate) + tf.summary.histogram('advantage', self.advs_ph) + if tf_util.is_image(self.observation_space): + tf.summary.image('observation', train_model.obs_ph) + else: + tf.summary.histogram('observation', train_model.obs_ph) trainer = tf.train.RMSPropOptimizer(learning_rate=self.learning_rate_ph, decay=self.alpha, epsilon=self.epsilon) @@ -184,7 +188,7 @@ def _train_step(self, obs, states, rewards, masks, actions, values, update, writ if writer is not None: # run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...) - if (1 + update) % 10 == 0: + if self.full_tensorboard_log and (1 + update) % 10 == 0: run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() summary, policy_loss, value_loss, policy_entropy, _ = self.sess.run( @@ -202,8 +206,13 @@ def _train_step(self, obs, states, rewards, masks, actions, values, update, writ return policy_loss, value_loss, policy_entropy - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="A2C"): - with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: + def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="A2C", + reset_num_timesteps=True): + + new_tb_log = self._init_num_timesteps(reset_num_timesteps) + + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ + as writer: self._setup_learn(seed) self.learning_rate_schedule = Scheduler(initial_value=self.learning_rate, n_values=total_timesteps, @@ -216,8 +225,8 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ for update in range(1, total_timesteps // self.n_batch + 1): # true_reward is the reward without discount obs, states, rewards, masks, actions, values, true_reward = runner.run() - _, value_loss, policy_entropy = self._train_step(obs, states, rewards, masks, actions, values, update, - writer) + _, value_loss, policy_entropy = self._train_step(obs, states, rewards, masks, actions, values, + self.num_timesteps // (self.n_batch + 1), writer) n_seconds = time.time() - t_start fps = int((update * self.n_batch) / n_seconds) @@ -225,18 +234,20 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ self.episode_reward = total_episode_reward_logger(self.episode_reward, true_reward.reshape((self.n_envs, self.n_steps)), masks.reshape((self.n_envs, self.n_steps)), - writer, update * (self.n_batch + 1)) + writer, 
self.num_timesteps) + + self.num_timesteps += self.n_batch + 1 if callback is not None: # Only stop training if return value is False, not when it is None. This is for backwards # compatibility with callbacks that have no return statement. - if callback(locals(), globals()) == False: + if callback(locals(), globals()) is False: break if self.verbose >= 1 and (update % log_interval == 0 or update == 1): explained_var = explained_variance(values, rewards) logger.record_tabular("nupdates", update) - logger.record_tabular("total_timesteps", update * self.n_batch) + logger.record_tabular("total_timesteps", self.num_timesteps) logger.record_tabular("fps", fps) logger.record_tabular("policy_entropy", float(policy_entropy)) logger.record_tabular("value_loss", float(value_loss)) diff --git a/stable_baselines/acer/acer_simple.py b/stable_baselines/acer/acer_simple.py index 04e4c1d964..0f469d7c40 100644 --- a/stable_baselines/acer/acer_simple.py +++ b/stable_baselines/acer/acer_simple.py @@ -91,12 +91,15 @@ class ACER(ActorCriticRLModel): :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly """ def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5, ent_coef=0.01, max_grad_norm=10, learning_rate=7e-4, lr_schedule='linear', rprop_alpha=0.99, rprop_epsilon=1e-5, buffer_size=5000, - replay_ratio=4, replay_start=1000, correction_term=10.0, trust_region=True, alpha=0.99, delta=1, - verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None): + replay_ratio=4, replay_start=1000, correction_term=10.0, trust_region=True, + alpha=0.99, delta=1, verbose=0, tensorboard_log=None, + _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): super(ACER, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True, _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) @@ -119,6 +122,7 @@ def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5, self.lr_schedule = lr_schedule self.num_procs = num_procs self.tensorboard_log = tensorboard_log + self.full_tensorboard_log = full_tensorboard_log self.graph = None self.sess = None @@ -361,17 +365,19 @@ def custom_getter(getter, name, *args, **kwargs): with tf.variable_scope("input_info", reuse=False): tf.summary.scalar('rewards', tf.reduce_mean(self.reward_ph)) - tf.summary.histogram('rewards', self.reward_ph) tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate)) - tf.summary.histogram('learning_rate', self.learning_rate) tf.summary.scalar('advantage', tf.reduce_mean(adv)) - tf.summary.histogram('advantage', adv) tf.summary.scalar('action_probabilty', tf.reduce_mean(self.mu_ph)) - tf.summary.histogram('action_probabilty', self.mu_ph) - if len(self.observation_space.shape) == 3: - tf.summary.image('observation', train_model.obs_ph) - else: - tf.summary.histogram('observation', train_model.obs_ph) + + if self.full_tensorboard_log: + tf.summary.histogram('rewards', self.reward_ph) + tf.summary.histogram('learning_rate', self.learning_rate) + tf.summary.histogram('advantage', adv) + tf.summary.histogram('action_probabilty', self.mu_ph) + if tf_util.is_image(self.observation_space): + 
tf.summary.image('observation', train_model.obs_ph) + else: + tf.summary.histogram('observation', train_model.obs_ph) trainer = tf.train.RMSPropOptimizer(learning_rate=self.learning_rate_ph, decay=self.rprop_alpha, epsilon=self.rprop_epsilon) @@ -429,7 +435,7 @@ def _train_step(self, obs, actions, rewards, dones, mus, states, masks, steps, w if writer is not None: # run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...) - if (1 + (steps / self.n_batch)) % 10 == 0: + if self.full_tensorboard_log and (1 + (steps / self.n_batch)) % 10 == 0: run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() step_return = self.sess.run([self.summary] + self.run_ops, td_map, options=run_options, @@ -444,8 +450,13 @@ def _train_step(self, obs, actions, rewards, dones, mus, states, masks, steps, w return self.names_ops, step_return[1:] # strip off _train - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="ACER"): - with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: + def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="ACER", + reset_num_timesteps=True): + + new_tb_log = self._init_num_timesteps(reset_num_timesteps) + + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ + as writer: self._setup_learn(seed) self.learning_rate_schedule = Scheduler(initial_value=self.learning_rate, n_values=total_timesteps, @@ -474,7 +485,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ self.episode_reward = total_episode_reward_logger(self.episode_reward, rewards.reshape((self.n_envs, self.n_steps)), dones.reshape((self.n_envs, self.n_steps)), - writer, steps) + writer, self.num_timesteps) # reshape stuff correctly obs = obs.reshape(runner.batch_ob_shape) @@ -485,16 +496,16 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ masks = masks.reshape([runner.batch_ob_shape[0]]) names_ops, values_ops = self._train_step(obs, actions, rewards, dones, mus, self.initial_state, masks, - steps, writer) + self.num_timesteps, writer) if callback is not None: # Only stop training if return value is False, not when it is None. This is for backwards # compatibility with callbacks that have no return statement. - if callback(locals(), globals()) == False: + if callback(locals(), globals()) is False: break if self.verbose >= 1 and (int(steps / runner.n_batch) % log_interval == 0): - logger.record_tabular("total_timesteps", steps) + logger.record_tabular("total_timesteps", self.num_timesteps) logger.record_tabular("fps", int(steps / (time.time() - t_start))) # IMP: In EpisodicLife env, during training, we get done=True at each loss of life, # not just at the terminal state. Thus, this is mean until end of life, not end of episode. 
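The ``reset_num_timesteps`` keyword added to the ``learn()`` signatures above lets a second training call continue the same tensorboard learning curve (and the same run folder) instead of starting a new run at step 0. A minimal usage sketch, where the environment name, log directory and timestep counts are illustrative only:

    from stable_baselines import A2C

    # First call: a new tensorboard run (e.g. A2C_1) is created and num_timesteps starts at 0
    model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log='./a2c_tensorboard/')
    model.learn(total_timesteps=10000)

    # Second call: keep the previous timestep count and reuse the latest run folder,
    # so the learning curve continues instead of restarting from step 0
    model.learn(total_timesteps=10000, reset_num_timesteps=False)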
@@ -519,7 +530,10 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ dones = dones.reshape([runner.n_batch]) masks = masks.reshape([runner.batch_ob_shape[0]]) - self._train_step(obs, actions, rewards, dones, mus, self.initial_state, masks, steps) + self._train_step(obs, actions, rewards, dones, mus, self.initial_state, masks, + self.num_timesteps) + + self.num_timesteps += self.n_batch return self diff --git a/stable_baselines/acktr/acktr_disc.py b/stable_baselines/acktr/acktr_disc.py index de9134f135..e289491fa4 100644 --- a/stable_baselines/acktr/acktr_disc.py +++ b/stable_baselines/acktr/acktr_disc.py @@ -39,11 +39,14 @@ class ACKTR(ActorCriticRLModel): :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance :param async_eigen_decomp: (bool) Use async eigen decomposition :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly """ def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01, vf_coef=0.25, vf_fisher_coef=1.0, learning_rate=0.25, max_grad_norm=0.5, kfac_clip=0.001, lr_schedule='linear', verbose=0, - tensorboard_log=None, _init_setup_model=True, async_eigen_decomp=False, policy_kwargs=None): + tensorboard_log=None, _init_setup_model=True, async_eigen_decomp=False, + policy_kwargs=None, full_tensorboard_log=False): super(ACKTR, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True, _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) @@ -60,6 +63,7 @@ def __init__(self, policy, env, gamma=0.99, nprocs=1, n_steps=20, ent_coef=0.01, self.nprocs = nprocs self.tensorboard_log = tensorboard_log self.async_eigen_decomp = async_eigen_decomp + self.full_tensorboard_log = full_tensorboard_log self.graph = None self.sess = None @@ -160,15 +164,17 @@ def setup_model(self): with tf.variable_scope("input_info", reuse=False): tf.summary.scalar('discounted_rewards', tf.reduce_mean(self.rewards_ph)) - tf.summary.histogram('discounted_rewards', self.rewards_ph) tf.summary.scalar('learning_rate', tf.reduce_mean(self.pg_lr_ph)) - tf.summary.histogram('learning_rate', self.pg_lr_ph) tf.summary.scalar('advantage', tf.reduce_mean(self.advs_ph)) - tf.summary.histogram('advantage', self.advs_ph) - if len(self.observation_space.shape) == 3: - tf.summary.image('observation', train_model.obs_ph) - else: - tf.summary.histogram('observation', train_model.obs_ph) + + if self.full_tensorboard_log: + tf.summary.histogram('discounted_rewards', self.rewards_ph) + tf.summary.histogram('learning_rate', self.pg_lr_ph) + tf.summary.histogram('advantage', self.advs_ph) + if tf_util.is_image(self.observation_space): + tf.summary.image('observation', train_model.obs_ph) + else: + tf.summary.histogram('observation', train_model.obs_ph) with tf.variable_scope("kfac", reuse=False, custom_getter=tf_util.outer_scope_getter("kfac")): with tf.device('/gpu:0'): @@ -219,7 +225,7 @@ def _train_step(self, obs, states, rewards, masks, actions, values, update, writ if writer is not None: # run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...) 
- if (1 + update) % 10 == 0: + if self.full_tensorboard_log and (1 + update) % 10 == 0: run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() summary, policy_loss, value_loss, policy_entropy, _ = self.sess.run( @@ -236,8 +242,13 @@ def _train_step(self, obs, states, rewards, masks, actions, values, update, writ return policy_loss, value_loss, policy_entropy - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="ACKTR"): - with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: + def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="ACKTR", + reset_num_timesteps=True): + + new_tb_log = self._init_num_timesteps(reset_num_timesteps) + + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ + as writer: self._setup_learn(seed) self.n_batch = self.n_envs * self.n_steps @@ -282,7 +293,8 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ # true_reward is the reward without discount obs, states, rewards, masks, actions, values, true_reward = runner.run() policy_loss, value_loss, policy_entropy = self._train_step(obs, states, rewards, masks, actions, values, - update, writer) + self.num_timesteps // (self.n_batch + 1), + writer) n_seconds = time.time() - t_start fps = int((update * self.n_batch) / n_seconds) @@ -290,18 +302,18 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ self.episode_reward = total_episode_reward_logger(self.episode_reward, true_reward.reshape((self.n_envs, self.n_steps)), masks.reshape((self.n_envs, self.n_steps)), - writer, update * (self.n_batch + 1)) + writer, self.num_timesteps) if callback is not None: # Only stop training if return value is False, not when it is None. This is for backwards # compatibility with callbacks that have no return statement. 
- if callback(locals(), globals()) == False: + if callback(locals(), globals()) is False: break if self.verbose >= 1 and (update % log_interval == 0 or update == 1): explained_var = explained_variance(values, rewards) logger.record_tabular("nupdates", update) - logger.record_tabular("total_timesteps", update * self.n_batch) + logger.record_tabular("total_timesteps", self.num_timesteps) logger.record_tabular("fps", fps) logger.record_tabular("policy_entropy", float(policy_entropy)) logger.record_tabular("policy_loss", float(policy_loss)) @@ -309,6 +321,8 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ logger.record_tabular("explained_variance", float(explained_var)) logger.dump_tabular() + self.num_timesteps += self.n_batch + 1 + coord.request_stop() coord.join(enqueue_threads) diff --git a/stable_baselines/acktr/kfac.py b/stable_baselines/acktr/kfac.py index 8f70143285..607423e6a3 100644 --- a/stable_baselines/acktr/kfac.py +++ b/stable_baselines/acktr/kfac.py @@ -158,7 +158,8 @@ def _search_factors(gradient, graph): if len(b_tensor.get_shape()) > 0 and b_tensor.get_shape()[0].value is None: b_tensor.set_shape(b_tensor_shape) b_tensors.append(b_tensor) - fprop_op_name = op_types.append('UNK-' + fprop_op.op_def.name) + fprop_op_name = 'UNK-' + fprop_op.op_def.name + op_types.append(fprop_op_name) return {'opName': fprop_op_name, 'op': fprop_op, 'fpropFactors': f_tensors, 'bpropFactors': b_tensors} diff --git a/stable_baselines/common/base_class.py b/stable_baselines/common/base_class.py index ed01fcae8d..a54a00825d 100644 --- a/stable_baselines/common/base_class.py +++ b/stable_baselines/common/base_class.py @@ -39,6 +39,7 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base, pol self.action_space = None self.n_envs = None self._vectorize_action = False + self.num_timesteps = 0 if env is not None: if isinstance(env, str): @@ -115,6 +116,21 @@ def set_env(self, env): self.env = env + def _init_num_timesteps(self, reset_num_timesteps=True): + """ + Initialize and reset num_timesteps (total timesteps since beginning of training) + if needed. Mainly used for logging and plotting (tensorboard). + + :param reset_num_timesteps: (bool) Set it to False when continuing training + to avoid creating new plotting curves in tensorboard. + :return: (bool) Whether a new tensorboard log needs to be created + """ + if reset_num_timesteps: + self.num_timesteps = 0 + + new_tb_log = self.num_timesteps == 0 + return new_tb_log + @abstractmethod def setup_model(self): """ @@ -135,7 +151,8 @@ def _setup_learn(self, seed): set_global_seeds(seed) @abstractmethod - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run"): + def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run", + reset_num_timesteps=True): """ Return a trained model. @@ -145,6 +162,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ It takes the local and global variables. If it returns False, training is aborted. :param log_interval: (int) The number of timesteps before logging. 
:param tb_log_name: (str) the name of the run for tensorboard log + :param reset_num_timesteps: (bool) whether or not to reset the current timestep number (used in logging) :return: (BaseRLModel) the trained model """ pass @@ -333,7 +351,8 @@ def setup_model(self): pass @abstractmethod - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run"): + def learn(self, total_timesteps, callback=None, seed=None, + log_interval=100, tb_log_name="run", reset_num_timesteps=True): pass def predict(self, observation, state=None, mask=None, deterministic=False): @@ -468,7 +487,8 @@ def setup_model(self): pass @abstractmethod - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run"): + def learn(self, total_timesteps, callback=None, seed=None, + log_interval=100, tb_log_name="run", reset_num_timesteps=True): pass @abstractmethod @@ -544,23 +564,27 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TensorboardWriter: - def __init__(self, graph, tensorboard_log_path, tb_log_name): + def __init__(self, graph, tensorboard_log_path, tb_log_name, new_tb_log=True): """ Create a Tensorboard writer for a code segment, and saves it to the log directory as its own run :param graph: (Tensorflow Graph) the model graph :param tensorboard_log_path: (str) the save path for the log (can be None for no logging) :param tb_log_name: (str) the name of the run for tensorboard log + :param new_tb_log: (bool) whether or not to create a new logging folder for tensorboard """ self.graph = graph self.tensorboard_log_path = tensorboard_log_path self.tb_log_name = tb_log_name self.writer = None + self.new_tb_log = new_tb_log def __enter__(self): if self.tensorboard_log_path is not None: - save_path = os.path.join(self.tensorboard_log_path, - "{}_{}".format(self.tb_log_name, self._get_latest_run_id() + 1)) + latest_run_id = self._get_latest_run_id() + if self.new_tb_log: + latest_run_id = latest_run_id + 1 + save_path = os.path.join(self.tensorboard_log_path, "{}_{}".format(self.tb_log_name, latest_run_id)) self.writer = tf.summary.FileWriter(save_path, graph=self.graph) return self.writer diff --git a/stable_baselines/common/distributions.py b/stable_baselines/common/distributions.py index 9b14591ba9..354102cd3d 100644 --- a/stable_baselines/common/distributions.py +++ b/stable_baselines/common/distributions.py @@ -186,7 +186,10 @@ def __init__(self, n_vec): :param n_vec: ([int]) the vectors """ - self.n_vec = n_vec + # Cast the variable because tf does not allow uint32 + self.n_vec = n_vec.astype(np.int32) + # Check that the cast was valid + assert (self.n_vec > 0).all(), "Casting uint32 to int32 was invalid" def probability_distribution_class(self): return MultiCategoricalProbabilityDistribution diff --git a/stable_baselines/common/misc_util.py b/stable_baselines/common/misc_util.py index 5532805168..e3ef88c4f8 100644 --- a/stable_baselines/common/misc_util.py +++ b/stable_baselines/common/misc_util.py @@ -89,7 +89,9 @@ def set_global_seeds(seed): tf.set_random_seed(seed) np.random.seed(seed) random.seed(seed) - gym.spaces.prng.seed(seed) + # prng was removed in the latest gym version + if hasattr(gym.spaces, 'prng'): + gym.spaces.prng.seed(seed) def pretty_eta(seconds_left): diff --git a/stable_baselines/common/tf_util.py b/stable_baselines/common/tf_util.py index d7a3d95f06..fd64a5147d 100644 --- a/stable_baselines/common/tf_util.py +++ b/stable_baselines/common/tf_util.py @@ -11,6 +11,19 @@ from stable_baselines import logger +def is_image(tensor): 
+ """ + Check if a tensor has the shape of + a valid image for tensorboard logging. + Valid image: RGB, RGBD, GrayScale + + :param tensor: (np.ndarray or tf.placeholder) + :return: (bool) + """ + + return len(tensor.shape) == 3 and tensor.shape[-1] in [1, 3, 4] + + def switch(condition, then_expression, else_expression): """ Switches between two operations depending on a scalar value (int or bool). @@ -210,7 +223,7 @@ def function(inputs, outputs, updates=None, givens=None): Take a bunch of tensorflow placeholders and expressions computed based on those placeholders and produces f(inputs) -> outputs. Function f takes values to be fed to the input's placeholders and produces the values of the expressions - in outputs. Just like a Theano function. + in outputs. Just like a Theano function. Input values can be passed in the same order as inputs or can be provided as kwargs based on placeholder name (passed to constructor or accessible via placeholder.op.name). @@ -225,13 +238,13 @@ def function(inputs, outputs, updates=None, givens=None): >>> assert lin(2) == 6 >>> assert lin(x=3) == 9 >>> assert lin(2, 2) == 10 - + :param inputs: (TensorFlow Tensor or Object with make_feed_dict) list of input arguments :param outputs: (TensorFlow Tensor) list of outputs or a single output to be returned from function. Returned value will also have the same shape. :param updates: ([tf.Operation] or tf.Operation) list of update functions or single update function that will be run whenever - the function is called. The return is ignored. + the function is called. The return is ignored. :param givens: (dict) the values known for the output """ if isinstance(outputs, list): @@ -254,7 +267,7 @@ def __init__(self, inputs, outputs, updates, givens): value will also have the same shape. :param updates: ([tf.Operation] or tf.Operation) list of update functions or single update function that will be run whenever - the function is called. The return is ignored. + the function is called. The return is ignored. 
:param givens: (dict) the values known for the output """ for inpt in inputs: diff --git a/stable_baselines/ddpg/ddpg.py b/stable_baselines/ddpg/ddpg.py index ef64d4404e..f69dce5f40 100644 --- a/stable_baselines/ddpg/ddpg.py +++ b/stable_baselines/ddpg/ddpg.py @@ -168,6 +168,8 @@ class DDPG(OffPolicyRLModel): :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly """ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, nb_train_steps=50, @@ -176,7 +178,7 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n normalize_returns=False, enable_popart=False, observation_range=(-5., 5.), critic_l2_reg=0., return_range=(-np.inf, np.inf), actor_lr=1e-4, critic_lr=1e-3, clip_norm=None, reward_scale=1., render=False, render_eval=False, memory_limit=100, verbose=0, tensorboard_log=None, - _init_setup_model=True, policy_kwargs=None): + _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): # TODO: replay_buffer refactoring super(DDPG, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DDPGPolicy, @@ -208,6 +210,7 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n self.nb_rollout_steps = nb_rollout_steps self.memory_limit = memory_limit self.tensorboard_log = tensorboard_log + self.full_tensorboard_log = full_tensorboard_log # init self.graph = None @@ -260,6 +263,7 @@ def __init__(self, policy, env, gamma=0.99, memory_policy=None, eval_env=None, n self.summary = None self.episode_reward = None self.tb_seen_steps = None + self.target_params = None if _init_setup_model: @@ -361,7 +365,8 @@ def setup_model(self): self.target_q = self.rewards + (1. - self.terminals1) * self.gamma * q_obs1 tf.summary.scalar('critic_target', tf.reduce_mean(self.critic_target)) - tf.summary.histogram('critic_target', self.critic_target) + if self.full_tensorboard_log: + tf.summary.histogram('critic_target', self.critic_target) # Set up parts. 
if self.normalize_returns and self.enable_popart: @@ -371,13 +376,15 @@ def setup_model(self): with tf.variable_scope("input_info", reuse=False): tf.summary.scalar('rewards', tf.reduce_mean(self.rewards)) - tf.summary.histogram('rewards', self.rewards) tf.summary.scalar('param_noise_stddev', tf.reduce_mean(self.param_noise_stddev)) - tf.summary.histogram('param_noise_stddev', self.param_noise_stddev) - if len(self.observation_space.shape) == 3 and self.observation_space.shape[0] in [1, 3, 4]: - tf.summary.image('observation', self.obs_train) - else: - tf.summary.histogram('observation', self.obs_train) + + if self.full_tensorboard_log: + tf.summary.histogram('rewards', self.rewards) + tf.summary.histogram('param_noise_stddev', self.param_noise_stddev) + if len(self.observation_space.shape) == 3 and self.observation_space.shape[0] in [1, 3, 4]: + tf.summary.image('observation', self.obs_train) + else: + tf.summary.histogram('observation', self.obs_train) with tf.variable_scope("Adam_mpi", reuse=False): self._setup_actor_optimizer() @@ -631,7 +638,7 @@ def _train_step(self, step, writer, log=False): if writer is not None: # run loss backprop with summary if the step_id was not already logged (can happen with the right # parameters as the step value is only an estimate) - if log and step not in self.tb_seen_steps: + if self.full_tensorboard_log and log and step not in self.tb_seen_steps: run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() summary, actor_grads, actor_loss, critic_grads, critic_loss = \ @@ -737,8 +744,13 @@ def _reset(self): self.param_noise_stddev: self.param_noise.current_stddev, }) - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DDPG"): - with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: + def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DDPG", + reset_num_timesteps=True): + + new_tb_log = self._init_num_timesteps(reset_num_timesteps) + + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ + as writer: self._setup_learn(seed) # a list for tensorboard logging, to prevent logging with the same step number, if it already occured @@ -800,9 +812,10 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ ep_rew = np.array([reward]).reshape((1, -1)) ep_done = np.array([done]).reshape((1, -1)) self.episode_reward = total_episode_reward_logger(self.episode_reward, ep_rew, ep_done, - writer, total_steps) + writer, self.num_timesteps) step += 1 total_steps += 1 + self.num_timesteps += 1 if rank == 0 and self.render: self.env.render() episode_reward += reward @@ -814,9 +827,9 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ self._store_transition(obs, action, reward, new_obs, done) obs = new_obs if callback is not None: - # Only stop training if return value is False, not when it is None. This is for backwards - # compatibility with callbacks that have no return statement. - if callback(locals(), globals()) == False: + # Only stop training if return value is False, not when it is None. + # This is for backwards compatibility with callbacks that have no return statement. 
+ if callback(locals(), globals()) is False: return self if done: @@ -847,7 +860,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ # weird equation to deal with the fact the nb_train_steps will be different # to nb_rollout_steps step = (int(t_train * (self.nb_rollout_steps / self.nb_train_steps)) + - total_steps - self.nb_rollout_steps) + self.num_timesteps - self.nb_rollout_steps) critic_loss, actor_loss = self._train_step(step, writer, log=t_train == 0) epoch_critic_losses.append(critic_loss) diff --git a/stable_baselines/deepq/build_graph.py b/stable_baselines/deepq/build_graph.py index 02542b568f..b6a9d39589 100644 --- a/stable_baselines/deepq/build_graph.py +++ b/stable_baselines/deepq/build_graph.py @@ -319,8 +319,9 @@ def act(obs, reset=None, update_param_noise_threshold=None, update_param_noise_s return act, obs_phs -def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=None, gamma=1.0, double_q=True, - scope="deepq", reuse=None, param_noise=False, param_noise_filter_func=None): +def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping=None, + gamma=1.0, double_q=True, scope="deepq", reuse=None, + param_noise=False, param_noise_filter_func=None, full_tensorboard_log=False): """ Creates the train function: @@ -340,6 +341,8 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping= :param param_noise_filter_func: (function (TensorFlow Tensor): bool) function that decides whether or not a variable should be perturbed. Only applicable if param_noise is True. If set to None, default_param_noise_filter is used by default. + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly :return: (tuple) @@ -410,9 +413,11 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping= weighted_error = tf.reduce_mean(importance_weights_ph * errors) tf.summary.scalar("td_error", tf.reduce_mean(td_error)) - tf.summary.histogram("td_error", td_error) tf.summary.scalar("loss", weighted_error) + if full_tensorboard_log: + tf.summary.histogram("td_error", td_error) + # update_target_fn will be called periodically to copy Q network to target Q network update_target_expr = [] for var, var_target in zip(sorted(q_func_vars, key=lambda v: v.name), @@ -429,15 +434,15 @@ def build_train(q_func, ob_space, ac_space, optimizer, sess, grad_norm_clipping= with tf.variable_scope("input_info", reuse=False): tf.summary.scalar('rewards', tf.reduce_mean(rew_t_ph)) - tf.summary.histogram('rewards', rew_t_ph) tf.summary.scalar('importance_weights', tf.reduce_mean(importance_weights_ph)) - tf.summary.histogram('importance_weights', importance_weights_ph) - # Valid image: RGB, RGBD, GrayScale - is_image = len(obs_phs[0].shape) == 3 and obs_phs[0].shape[-1] in [1, 3, 4] - if is_image: - tf.summary.image('observation', obs_phs[0]) - elif len(obs_phs[0].shape) == 1: - tf.summary.histogram('observation', obs_phs[0]) + + if full_tensorboard_log: + tf.summary.histogram('rewards', rew_t_ph) + tf.summary.histogram('importance_weights', importance_weights_ph) + if tf_util.is_image(obs_phs[0]): + tf.summary.image('observation', obs_phs[0]) + elif len(obs_phs[0].shape) == 1: + tf.summary.histogram('observation', obs_phs[0]) optimize_expr = optimizer.apply_gradients(gradients) diff --git a/stable_baselines/deepq/dqn.py b/stable_baselines/deepq/dqn.py index edee7f5416..6247fd288b 100644 --- a/stable_baselines/deepq/dqn.py +++ 
b/stable_baselines/deepq/dqn.py @@ -45,6 +45,8 @@ class DQN(OffPolicyRLModel): :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly """ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=50000, exploration_fraction=0.1, @@ -52,7 +54,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000 learning_starts=1000, target_network_update_freq=500, prioritized_replay=False, prioritized_replay_alpha=0.6, prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None, prioritized_replay_eps=1e-6, param_noise=False, verbose=0, tensorboard_log=None, - _init_setup_model=True, policy_kwargs=None): + _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): # TODO: replay_buffer refactoring super(DQN, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=DQNPolicy, @@ -76,6 +78,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000 self.learning_rate = learning_rate self.gamma = gamma self.tensorboard_log = tensorboard_log + self.full_tensorboard_log = full_tensorboard_log self.graph = None self.sess = None @@ -95,6 +98,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=5e-4, buffer_size=5000 self.setup_model() def setup_model(self): + with SetVerbosity(self.verbose): assert not isinstance(self.action_space, gym.spaces.Box), \ "Error: DQN cannot output a gym.spaces.Box action space." @@ -122,7 +126,8 @@ def setup_model(self): gamma=self.gamma, grad_norm_clipping=10, param_noise=self.param_noise, - sess=self.sess + sess=self.sess, + full_tensorboard_log=self.full_tensorboard_log ) self.proba_step = self.step_model.proba_step self.params = find_trainable_variables("deepq") @@ -133,8 +138,13 @@ def setup_model(self): self.summary = tf.summary.merge_all() - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DQN"): - with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: + def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="DQN", + reset_num_timesteps=True): + + new_tb_log = self._init_num_timesteps(reset_num_timesteps) + + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ + as writer: self._setup_learn(seed) # Create the replay buffer @@ -160,16 +170,16 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ reset = True self.episode_reward = np.zeros((1,)) - for step in range(total_timesteps): + for _ in range(total_timesteps): if callback is not None: # Only stop training if return value is False, not when it is None. This is for backwards # compatibility with callbacks that have no return statement. - if callback(locals(), globals()) == False: + if callback(locals(), globals()) is False: break # Take action and update exploration to the newest value kwargs = {} if not self.param_noise: - update_eps = self.exploration.value(step) + update_eps = self.exploration.value(self.num_timesteps) update_param_noise_threshold = 0. else: update_eps = 0. 
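The ``full_tensorboard_log`` flag threaded through ``build_train()`` and ``DQN`` above defaults to False, so the histogram and image summaries that previously inflated the log size are skipped unless explicitly re-enabled; the scalar summaries are always written. A minimal usage sketch, where the environment name, log path and timestep count are illustrative only:

    from stable_baselines import DQN

    # Default behavior after this patch: only scalar summaries are written,
    # which keeps the tensorboard log small
    model = DQN('MlpPolicy', 'CartPole-v1', tensorboard_log='./dqn_tensorboard/')
    model.learn(total_timesteps=25000)

    # Opt back into the previous, verbose behavior (td_error/reward histograms,
    # observation images) at the cost of much larger log files
    model = DQN('MlpPolicy', 'CartPole-v1', tensorboard_log='./dqn_tensorboard/',
                full_tensorboard_log=True)
    model.learn(total_timesteps=25000)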
@@ -178,8 +188,8 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ # See Appendix C.1 in Parameter Space Noise for Exploration, Plappert et al., 2017 # for detailed explanation. update_param_noise_threshold = \ - -np.log(1. - self.exploration.value(step) + - self.exploration.value(step) / float(self.env.action_space.n)) + -np.log(1. - self.exploration.value(self.num_timesteps) + + self.exploration.value(self.num_timesteps) / float(self.env.action_space.n)) kwargs['reset'] = reset kwargs['update_param_noise_threshold'] = update_param_noise_threshold kwargs['update_param_noise_scale'] = True @@ -196,7 +206,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ ep_rew = np.array([rew]).reshape((1, -1)) ep_done = np.array([done]).reshape((1, -1)) self.episode_reward = total_episode_reward_logger(self.episode_reward, ep_rew, ep_done, writer, - step) + self.num_timesteps) episode_rewards[-1] += rew if done: @@ -205,10 +215,11 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ episode_rewards.append(0.0) reset = True - if step > self.learning_starts and step % self.train_freq == 0: + if self.num_timesteps > self.learning_starts and self.num_timesteps % self.train_freq == 0: # Minimize the error in Bellman's equation on a batch sampled from replay buffer. if self.prioritized_replay: - experience = self.replay_buffer.sample(self.batch_size, beta=self.beta_schedule.value(step)) + experience = self.replay_buffer.sample(self.batch_size, + beta=self.beta_schedule.value(self.num_timesteps)) (obses_t, actions, rewards, obses_tp1, dones, weights, batch_idxes) = experience else: obses_t, actions, rewards, obses_tp1, dones = self.replay_buffer.sample(self.batch_size) @@ -217,17 +228,17 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ if writer is not None: # run loss backprop with summary, but once every 100 steps save the metadata # (memory, compute time, ...) - if (1 + step) % 100 == 0: + if (1 + self.num_timesteps) % 100 == 0: run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() summary, td_errors = self._train_step(obses_t, actions, rewards, obses_tp1, obses_tp1, dones, weights, sess=self.sess, options=run_options, run_metadata=run_metadata) - writer.add_run_metadata(run_metadata, 'step%d' % step) + writer.add_run_metadata(run_metadata, 'step%d' % self.num_timesteps) else: summary, td_errors = self._train_step(obses_t, actions, rewards, obses_tp1, obses_tp1, dones, weights, sess=self.sess) - writer.add_summary(summary, step) + writer.add_summary(summary, self.num_timesteps) else: _, td_errors = self._train_step(obses_t, actions, rewards, obses_tp1, obses_tp1, dones, weights, sess=self.sess) @@ -236,7 +247,8 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ new_priorities = np.abs(td_errors) + self.prioritized_replay_eps self.replay_buffer.update_priorities(batch_idxes, new_priorities) - if step > self.learning_starts and step % self.target_network_update_freq == 0: + if self.num_timesteps > self.learning_starts and \ + self.num_timesteps % self.target_network_update_freq == 0: # Update target network periodically. 
self.update_target(sess=self.sess) @@ -247,12 +259,15 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ num_episodes = len(episode_rewards) if self.verbose >= 1 and done and log_interval is not None and len(episode_rewards) % log_interval == 0: - logger.record_tabular("steps", step) + logger.record_tabular("steps", self.num_timesteps) logger.record_tabular("episodes", num_episodes) logger.record_tabular("mean 100 episode reward", mean_100ep_reward) - logger.record_tabular("% time spent exploring", int(100 * self.exploration.value(step))) + logger.record_tabular("% time spent exploring", + int(100 * self.exploration.value(self.num_timesteps))) logger.dump_tabular() + self.num_timesteps += 1 + return self def predict(self, observation, state=None, mask=None, deterministic=True): diff --git a/stable_baselines/deepq/experiments/train_cartpole.py b/stable_baselines/deepq/experiments/train_cartpole.py index 9cef8df718..9c1aee0e9d 100644 --- a/stable_baselines/deepq/experiments/train_cartpole.py +++ b/stable_baselines/deepq/experiments/train_cartpole.py @@ -19,7 +19,7 @@ def callback(lcl, _glb): mean_100ep_reward = -np.inf else: mean_100ep_reward = round(float(np.mean(lcl['episode_rewards'][-101:-1])), 1) - is_solved = lcl['step'] > 100 and mean_100ep_reward >= 199 + is_solved = lcl['self'].num_timesteps > 100 and mean_100ep_reward >= 199 return not is_solved diff --git a/stable_baselines/gail/model.py b/stable_baselines/gail/model.py index 4af84d7ac1..720802b524 100644 --- a/stable_baselines/gail/model.py +++ b/stable_baselines/gail/model.py @@ -31,6 +31,8 @@ class GAIL(ActorCriticRLModel): :param d_stepsize: (float) the reward giver stepsize :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly """ def __init__(self, policy, env, pretrained_weight=False, hidden_size_adversary=100, adversary_entcoeff=1e-3, @@ -66,15 +68,16 @@ def setup_model(self): self.trpo.setup_model() - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="GAIL"): - self.trpo.learn(total_timesteps, callback, seed, log_interval, tb_log_name) + def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="GAIL", + reset_num_timesteps=True): + self.trpo.learn(total_timesteps, callback, seed, log_interval, tb_log_name, reset_num_timesteps) return self def predict(self, observation, state=None, mask=None, deterministic=False): return self.trpo.predict(observation, state, mask, deterministic=deterministic) - def action_probability(self, observation, state=None, mask=None): - return self.trpo.action_probability(observation, state, mask) + def action_probability(self, observation, state=None, mask=None, actions=None): + return self.trpo.action_probability(observation, state, mask, actions) def save(self, save_path): self.trpo.save(save_path) diff --git a/stable_baselines/her/her.py b/stable_baselines/her/her.py index 87d5b3e1b6..b1b2589b8e 100644 --- a/stable_baselines/her/her.py +++ b/stable_baselines/her/her.py @@ -93,7 +93,8 @@ def setup_model(self): with self.graph.as_default(): pass - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="HER"): + def learn(self, total_timesteps, callback=None, seed=None, 
log_interval=100, tb_log_name="HER", + reset_num_timesteps=True): with SetVerbosity(self.verbose): self._setup_learn(seed) @@ -102,7 +103,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ def predict(self, observation, state=None, mask=None, deterministic=False): pass - def action_probability(self, observation, state=None, mask=None): + def action_probability(self, observation, state=None, mask=None, actions=None): pass def save(self, save_path): diff --git a/stable_baselines/ppo1/pposgd_simple.py b/stable_baselines/ppo1/pposgd_simple.py index 1da9cf3f37..36c5547cbc 100644 --- a/stable_baselines/ppo1/pposgd_simple.py +++ b/stable_baselines/ppo1/pposgd_simple.py @@ -38,11 +38,14 @@ class PPO1(ActorCriticRLModel): :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly """ def __init__(self, policy, env, gamma=0.99, timesteps_per_actorbatch=256, clip_param=0.2, entcoeff=0.01, optim_epochs=4, optim_stepsize=1e-3, optim_batchsize=64, lam=0.95, adam_epsilon=1e-5, - schedule='linear', verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None): + schedule='linear', verbose=0, tensorboard_log=None, + _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): super().__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=False, _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) @@ -58,6 +61,7 @@ def __init__(self, policy, env, gamma=0.99, timesteps_per_actorbatch=256, clip_p self.adam_epsilon = adam_epsilon self.schedule = schedule self.tensorboard_log = tensorboard_log + self.full_tensorboard_log = full_tensorboard_log self.graph = None self.sess = None @@ -148,17 +152,19 @@ def setup_model(self): with tf.variable_scope("input_info", reuse=False): tf.summary.scalar('discounted_rewards', tf.reduce_mean(ret)) - tf.summary.histogram('discounted_rewards', ret) tf.summary.scalar('learning_rate', tf.reduce_mean(self.optim_stepsize)) - tf.summary.histogram('learning_rate', self.optim_stepsize) tf.summary.scalar('advantage', tf.reduce_mean(atarg)) - tf.summary.histogram('advantage', atarg) tf.summary.scalar('clip_range', tf.reduce_mean(self.clip_param)) - tf.summary.histogram('clip_range', self.clip_param) - if len(self.observation_space.shape) == 3: - tf.summary.image('observation', obs_ph) - else: - tf.summary.histogram('observation', obs_ph) + + if self.full_tensorboard_log: + tf.summary.histogram('discounted_rewards', ret) + tf.summary.histogram('learning_rate', self.optim_stepsize) + tf.summary.histogram('advantage', atarg) + tf.summary.histogram('clip_range', self.clip_param) + if tf_util.is_image(self.observation_space): + tf.summary.image('observation', obs_ph) + else: + tf.summary.histogram('observation', obs_ph) self.step = self.policy_pi.step self.proba_step = self.policy_pi.proba_step @@ -173,8 +179,13 @@ def setup_model(self): self.compute_losses = tf_util.function([obs_ph, old_pi.obs_ph, action_ph, atarg, ret, lrmult], losses) - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="PPO1"): - with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: + def 
learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="PPO1", + reset_num_timesteps=True): + + new_tb_log = self._init_num_timesteps(reset_num_timesteps) + + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ + as writer: self._setup_learn(seed) assert issubclass(self.policy, ActorCriticPolicy), "Error: the input policy for the PPO1 model must be " \ @@ -202,7 +213,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ if callback is not None: # Only stop training if return value is False, not when it is None. This is for backwards # compatibility with callbacks that have no return statement. - if callback(locals(), globals()) == False: + if callback(locals(), globals()) is False: break if total_timesteps and timesteps_so_far >= total_timesteps: break @@ -227,7 +238,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ self.episode_reward = total_episode_reward_logger(self.episode_reward, seg["true_rew"].reshape((self.n_envs, -1)), seg["dones"].reshape((self.n_envs, -1)), - writer, timesteps_so_far) + writer, self.num_timesteps) # predicted value function before udpate vpredbefore = seg["vpred"] @@ -248,13 +259,13 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ # list of tuples, each of which gives the loss for a minibatch losses = [] for i, batch in enumerate(dataset.iterate_once(optim_batchsize)): - steps = (timesteps_so_far + + steps = (self.num_timesteps + k * optim_batchsize + int(i * (optim_batchsize / len(dataset.data_map)))) if writer is not None: # run loss backprop with summary, but once every 10 runs save the metadata # (memory, compute time, ...) - if (1 + k) % 10 == 0: + if self.full_tensorboard_log and (1 + k) % 10 == 0: run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() summary, grad, *newlosses = self.lossandgrad(batch["ob"], batch["ob"], batch["ac"], @@ -302,10 +313,12 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ logger.record_tabular("EpRewMean", np.mean(rewbuffer)) logger.record_tabular("EpThisIter", len(lens)) episodes_so_far += len(lens) - timesteps_so_far += MPI.COMM_WORLD.allreduce(seg["total_timestep"]) + current_it_timesteps = MPI.COMM_WORLD.allreduce(seg["total_timestep"]) + timesteps_so_far += current_it_timesteps + self.num_timesteps += current_it_timesteps iters_so_far += 1 logger.record_tabular("EpisodesSoFar", episodes_so_far) - logger.record_tabular("TimestepsSoFar", timesteps_so_far) + logger.record_tabular("TimestepsSoFar", self.num_timesteps) logger.record_tabular("TimeElapsed", time.time() - t_start) if self.verbose >= 1 and MPI.COMM_WORLD.Get_rank() == 0: logger.dump_tabular() diff --git a/stable_baselines/ppo2/ppo2.py b/stable_baselines/ppo2/ppo2.py index bdf264e914..432fac2c33 100644 --- a/stable_baselines/ppo2/ppo2.py +++ b/stable_baselines/ppo2/ppo2.py @@ -37,11 +37,14 @@ class PPO2(ActorCriticRLModel): :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly """ def __init__(self, policy, env, gamma=0.99, n_steps=128, 
ent_coef=0.01, learning_rate=2.5e-4, vf_coef=0.5, max_grad_norm=0.5, lam=0.95, nminibatches=4, noptepochs=4, cliprange=0.2, verbose=0, - tensorboard_log=None, _init_setup_model=True, policy_kwargs=None): + tensorboard_log=None, _init_setup_model=True, policy_kwargs=None, + full_tensorboard_log=False): super(PPO2, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True, _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) @@ -57,6 +60,7 @@ def __init__(self, policy, env, gamma=0.99, n_steps=128, ent_coef=0.01, learning self.nminibatches = nminibatches self.noptepochs = noptepochs self.tensorboard_log = tensorboard_log + self.full_tensorboard_log = full_tensorboard_log self.graph = None self.sess = None @@ -156,8 +160,9 @@ def setup_model(self): with tf.variable_scope('model'): self.params = tf.trainable_variables() - for var in self.params: - tf.summary.histogram(var.name, var) + if self.full_tensorboard_log: + for var in self.params: + tf.summary.histogram(var.name, var) grads = tf.gradients(loss, self.params) if self.max_grad_norm is not None: grads, _grad_norm = tf.clip_by_global_norm(grads, self.max_grad_norm) @@ -169,21 +174,23 @@ def setup_model(self): with tf.variable_scope("input_info", reuse=False): tf.summary.scalar('discounted_rewards', tf.reduce_mean(self.rewards_ph)) - tf.summary.histogram('discounted_rewards', self.rewards_ph) tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate_ph)) - tf.summary.histogram('learning_rate', self.learning_rate_ph) tf.summary.scalar('advantage', tf.reduce_mean(self.advs_ph)) - tf.summary.histogram('advantage', self.advs_ph) tf.summary.scalar('clip_range', tf.reduce_mean(self.clip_range_ph)) - tf.summary.histogram('clip_range', self.clip_range_ph) tf.summary.scalar('old_neglog_action_probabilty', tf.reduce_mean(self.old_neglog_pac_ph)) - tf.summary.histogram('old_neglog_action_probabilty', self.old_neglog_pac_ph) tf.summary.scalar('old_value_pred', tf.reduce_mean(self.old_vpred_ph)) - tf.summary.histogram('old_value_pred', self.old_vpred_ph) - if len(self.observation_space.shape) == 3: - tf.summary.image('observation', train_model.obs_ph) - else: - tf.summary.histogram('observation', train_model.obs_ph) + + if self.full_tensorboard_log: + tf.summary.histogram('discounted_rewards', self.rewards_ph) + tf.summary.histogram('learning_rate', self.learning_rate_ph) + tf.summary.histogram('advantage', self.advs_ph) + tf.summary.histogram('clip_range', self.clip_range_ph) + tf.summary.histogram('old_neglog_action_probabilty', self.old_neglog_pac_ph) + tf.summary.histogram('old_value_pred', self.old_vpred_ph) + if tf_util.is_image(self.observation_space): + tf.summary.image('observation', train_model.obs_ph) + else: + tf.summary.histogram('observation', train_model.obs_ph) self.train_model = train_model self.act_model = act_model @@ -230,7 +237,7 @@ def _train_step(self, learning_rate, cliprange, obs, returns, masks, actions, va if writer is not None: # run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...) 
- if (1 + update) % 10 == 0: + if self.full_tensorboard_log and (1 + update) % 10 == 0: run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) run_metadata = tf.RunMetadata() summary, policy_loss, value_loss, policy_entropy, approxkl, clipfrac, _ = self.sess.run( @@ -248,12 +255,16 @@ def _train_step(self, learning_rate, cliprange, obs, returns, masks, actions, va return policy_loss, value_loss, policy_entropy, approxkl, clipfrac - def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_log_name="PPO2"): + def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_log_name="PPO2", + reset_num_timesteps=True): # Transform to callable if needed self.learning_rate = get_schedule_fn(self.learning_rate) self.cliprange = get_schedule_fn(self.cliprange) - with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: + new_tb_log = self._init_num_timesteps(reset_num_timesteps) + + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ + as writer: self._setup_learn(seed) runner = Runner(env=self.env, model=self, n_steps=self.n_steps, gamma=self.gamma, lam=self.lam) @@ -275,18 +286,21 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_lo ep_info_buf.extend(ep_infos) mb_loss_vals = [] if states is None: # nonrecurrent version + update_fac = self.n_batch // self.nminibatches // self.noptepochs + 1 inds = np.arange(self.n_batch) for epoch_num in range(self.noptepochs): np.random.shuffle(inds) for start in range(0, self.n_batch, batch_size): - timestep = ((update * self.noptepochs * self.n_batch + epoch_num * self.n_batch + start) // - batch_size) + timestep = self.num_timesteps // update_fac + ((self.noptepochs * self.n_batch + epoch_num * + self.n_batch + start) // batch_size) end = start + batch_size mbinds = inds[start:end] slices = (arr[mbinds] for arr in (obs, returns, masks, actions, values, neglogpacs)) mb_loss_vals.append(self._train_step(lr_now, cliprangenow, *slices, writer=writer, update=timestep)) + self.num_timesteps += (self.n_batch * self.noptepochs) // batch_size * update_fac else: # recurrent version + update_fac = self.n_batch // self.nminibatches // self.noptepochs // self.n_steps + 1 assert self.n_envs % self.nminibatches == 0 env_indices = np.arange(self.n_envs) flat_indices = np.arange(self.n_envs * self.n_steps).reshape(self.n_envs, self.n_steps) @@ -294,8 +308,8 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_lo for epoch_num in range(self.noptepochs): np.random.shuffle(env_indices) for start in range(0, self.n_envs, envs_per_batch): - timestep = ((update * self.noptepochs * self.n_envs + epoch_num * self.n_envs + start) // - envs_per_batch) + timestep = self.num_timesteps // update_fac + ((self.noptepochs * self.n_envs + epoch_num * + self.n_envs + start) // envs_per_batch) end = start + envs_per_batch mb_env_inds = env_indices[start:end] mb_flat_inds = flat_indices[mb_env_inds].ravel() @@ -303,6 +317,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_lo mb_states = states[mb_env_inds] mb_loss_vals.append(self._train_step(lr_now, cliprangenow, *slices, update=timestep, writer=writer, states=mb_states)) + self.num_timesteps += (self.n_envs * self.noptepochs) // envs_per_batch * update_fac loss_vals = np.mean(mb_loss_vals, axis=0) t_now = time.time() @@ -312,13 +327,13 @@ def learn(self, total_timesteps, callback=None, seed=None, 
log_interval=1, tb_lo self.episode_reward = total_episode_reward_logger(self.episode_reward, true_reward.reshape((self.n_envs, self.n_steps)), masks.reshape((self.n_envs, self.n_steps)), - writer, update * (self.n_batch + 1)) + writer, self.num_timesteps) if self.verbose >= 1 and (update % log_interval == 0 or update == 1): explained_var = explained_variance(values, returns) logger.logkv("serial_timesteps", update * self.n_steps) logger.logkv("nupdates", update) - logger.logkv("total_timesteps", update * self.n_batch) + logger.logkv("total_timesteps", self.num_timesteps) logger.logkv("fps", fps) logger.logkv("explained_variance", float(explained_var)) logger.logkv('ep_rewmean', safe_mean([ep_info['r'] for ep_info in ep_info_buf])) @@ -331,7 +346,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=1, tb_lo if callback is not None: # Only stop training if return value is False, not when it is None. This is for backwards # compatibility with callbacks that have no return statement. - if callback(locals(), globals()) == False: + if callback(locals(), globals()) is False: break return self diff --git a/stable_baselines/sac/sac.py b/stable_baselines/sac/sac.py index 10f6858030..10b1f30199 100644 --- a/stable_baselines/sac/sac.py +++ b/stable_baselines/sac/sac.py @@ -57,13 +57,16 @@ class SAC(OffPolicyRLModel): :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + Note: this has no effect on SAC logging for now """ def __init__(self, policy, env, gamma=0.99, learning_rate=3e-4, buffer_size=50000, learning_starts=100, train_freq=1, batch_size=64, tau=0.005, ent_coef='auto', target_update_interval=1, - gradient_steps=1, target_entropy='auto', - verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None): + gradient_steps=1, target_entropy='auto', verbose=0, tensorboard_log=None, + _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): + super(SAC, self).__init__(policy=policy, env=env, replay_buffer=None, verbose=verbose, policy_base=SACPolicy, requires_vec_env=False, policy_kwargs=policy_kwargs) @@ -95,6 +98,7 @@ def __init__(self, policy, env, gamma=0.99, learning_rate=3e-4, buffer_size=5000 self.summary = None self.policy_tf = None self.target_entropy = target_entropy + self.full_tensorboard_log = full_tensorboard_log self.obs_target = None self.target_policy = None @@ -341,8 +345,14 @@ def _train_step(self, step, writer, learning_rate): return policy_loss, qf1_loss, qf2_loss, value_loss, entropy - def learn(self, total_timesteps, callback=None, seed=None, log_interval=4, tb_log_name="SAC"): - with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: + def learn(self, total_timesteps, callback=None, seed=None, + log_interval=4, tb_log_name="SAC", reset_num_timesteps=True): + + new_tb_log = self._init_num_timesteps(reset_num_timesteps) + + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ + as writer: + self._setup_learn(seed) # Transform to callable if needed @@ -368,7 +378,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=4, tb_lo # Before training starts, randomly sample actions # from a 
uniform distribution for better exploration. # Afterwards, use the learned policy. - if step < self.learning_starts: + if self.num_timesteps < self.learning_starts: action = self.env.action_space.sample() # No need to rescale when sampling random action rescaled_action = action @@ -395,13 +405,13 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=4, tb_lo ep_reward = np.array([reward]).reshape((1, -1)) ep_done = np.array([done]).reshape((1, -1)) self.episode_reward = total_episode_reward_logger(self.episode_reward, ep_reward, - ep_done, writer, step) + ep_done, writer, self.num_timesteps) if step % self.train_freq == 0: mb_infos_vals = [] # Update policy, critics and target networks for grad_step in range(self.gradient_steps): - if step < self.batch_size or step < self.learning_starts: + if self.num_timesteps < self.batch_size or self.num_timesteps < self.learning_starts: break n_updates += 1 # Compute current learning_rate @@ -429,6 +439,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=4, tb_lo mean_reward = round(float(np.mean(episode_rewards[-101:-1])), 1) num_episodes = len(episode_rewards) + self.num_timesteps += 1 # Display training infos if self.verbose >= 1 and done and log_interval is not None and len(episode_rewards) % log_interval == 0: fps = int(step / (time.time() - start_time)) @@ -443,7 +454,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=4, tb_lo if len(infos_values) > 0: for (name, val) in zip(self.infos_names, infos_values): logger.logkv(name, val) - logger.logkv("total timesteps", step) + logger.logkv("total timesteps", self.num_timesteps) logger.dumpkvs() # Reset infos: infos_values = [] diff --git a/stable_baselines/trpo_mpi/trpo_mpi.py b/stable_baselines/trpo_mpi/trpo_mpi.py index d1ed648b7a..55176baa0b 100644 --- a/stable_baselines/trpo_mpi/trpo_mpi.py +++ b/stable_baselines/trpo_mpi/trpo_mpi.py @@ -15,32 +15,36 @@ from stable_baselines.common.policies import ActorCriticPolicy from stable_baselines.a2c.utils import find_trainable_variables, total_episode_reward_logger from stable_baselines.trpo_mpi.utils import traj_segment_generator, add_vtarg_and_adv, flatten_lists + + # from stable_baselines.gail.statistics import Stats class TRPO(ActorCriticRLModel): + """ + Trust Region Policy Optimization (https://arxiv.org/abs/1502.05477) + + :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...) 
+ :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) + :param gamma: (float) the discount value + :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon) + :param max_kl: (float) the Kullback-Leibler loss threshold + :param cg_iters: (int) the number of iterations for the conjugate gradient calculation + :param lam: (float) GAE factor + :param entcoeff: (float) the weight for the entropy loss + :param cg_damping: (float) the conjugate gradient damping factor + :param vf_stepsize: (float) the value function stepsize + :param vf_iters: (int) the value function's number of iterations for learning + :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug + :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) + :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance + :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation + :param full_tensorboard_log: (bool) enable additional logging when using tensorboard + WARNING: this logging can take a lot of space quickly + """ def __init__(self, policy, env, gamma=0.99, timesteps_per_batch=1024, max_kl=0.01, cg_iters=10, lam=0.98, entcoeff=0.0, cg_damping=1e-2, vf_stepsize=3e-4, vf_iters=3, verbose=0, tensorboard_log=None, - _init_setup_model=True, policy_kwargs=None): - """ - learns a TRPO policy using the given environment - - :param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, CnnLstmPolicy, ...) - :param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str) - :param gamma: (float) the discount value - :param timesteps_per_batch: (int) the number of timesteps to run per batch (horizon) - :param max_kl: (float) the kullback leiber loss threshold - :param cg_iters: (int) the number of iterations for the conjugate gradient calculation - :param lam: (float) GAE factor - :param entcoeff: (float) the weight for the entropy loss - :param cg_damping: (float) the compute gradient dampening factor - :param vf_stepsize: (float) the value function stepsize - :param vf_iters: (int) the value function's number iterations for learning - :param verbose: (int) the verbosity level: 0 none, 1 training information, 2 tensorflow debug - :param tensorboard_log: (str) the log location for tensorboard (if None, no logging) - :param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance - :param policy_kwargs: (dict) additional arguments to be passed to the policy on creation - """ + _init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False): super(TRPO, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=False, _init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs) @@ -55,6 +59,7 @@ def __init__(self, policy, env, gamma=0.99, timesteps_per_batch=1024, max_kl=0.0 self.vf_stepsize = vf_stepsize self.entcoeff = entcoeff self.tensorboard_log = tensorboard_log + self.full_tensorboard_log = full_tensorboard_log # GAIL Params self.pretrained_weight = None @@ -181,7 +186,7 @@ def setup_model(self): self.assign_old_eq_new = \ tf_util.function([], [], updates=[tf.assign(oldv, newv) for (oldv, newv) in zipsame(tf_util.get_globals_vars("oldpi"), - tf_util.get_globals_vars("model"))]) + tf_util.get_globals_vars("model"))]) self.compute_losses = tf_util.function([observation, old_policy.obs_ph, action,
atarg], losses) self.compute_fvp = tf_util.function([flat_tangent, observation, old_policy.obs_ph, action, atarg], fvp) @@ -220,17 +225,19 @@ def allmean(arr): with tf.variable_scope("input_info", reuse=False): tf.summary.scalar('discounted_rewards', tf.reduce_mean(ret)) - tf.summary.histogram('discounted_rewards', ret) tf.summary.scalar('learning_rate', tf.reduce_mean(self.vf_stepsize)) - tf.summary.histogram('learning_rate', self.vf_stepsize) tf.summary.scalar('advantage', tf.reduce_mean(atarg)) - tf.summary.histogram('advantage', atarg) tf.summary.scalar('kl_clip_range', tf.reduce_mean(self.max_kl)) - tf.summary.histogram('kl_clip_range', self.max_kl) - if len(self.observation_space.shape) == 3: - tf.summary.image('observation', observation) - else: - tf.summary.histogram('observation', observation) + + if self.full_tensorboard_log: + tf.summary.histogram('discounted_rewards', ret) + tf.summary.histogram('learning_rate', self.vf_stepsize) + tf.summary.histogram('advantage', atarg) + tf.summary.histogram('kl_clip_range', self.max_kl) + if tf_util.is_image(self.observation_space): + tf.summary.image('observation', observation) + else: + tf.summary.histogram('observation', observation) self.timed = timed self.allmean = allmean @@ -249,8 +256,13 @@ def allmean(arr): tf_util.function([observation, old_policy.obs_ph, action, atarg, ret], [self.summary, tf_util.flatgrad(optimgain, var_list)] + losses) - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="TRPO"): - with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer: + def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="TRPO", + reset_num_timesteps=True): + + new_tb_log = self._init_num_timesteps(reset_num_timesteps) + + with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \ + as writer: self._setup_learn(seed) with self.sess.as_default(): @@ -282,7 +294,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ if callback is not None: # Only stop training if return value is False, not when it is None. This is for backwards # compatibility with callbacks that have no return statement. 
- if callback(locals(), globals()) == False: + if callback(locals(), globals()) is False: break if total_timesteps and timesteps_so_far >= total_timesteps: break @@ -291,6 +303,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_ def fisher_vector_product(vec): return self.allmean(self.compute_fvp(vec, *fvpargs, sess=self.sess)) + self.cg_damping * vec + # ------------------ Update G ------------------ logger.log("Optimizing Policy...") # g_step = 1 when not using GAIL @@ -315,7 +328,7 @@ def fisher_vector_product(vec): seg["true_rew"].reshape( (self.n_envs, -1)), seg["dones"].reshape((self.n_envs, -1)), - writer, timesteps_so_far) + writer, self.num_timesteps) args = seg["ob"], seg["ob"], seg["ac"], atarg fvpargs = [arr[::5] for arr in args] @@ -323,15 +336,16 @@ def fisher_vector_product(vec): self.assign_old_eq_new(sess=self.sess) with self.timed("computegrad"): - steps = timesteps_so_far + (k + 1) * (seg["total_timestep"] / self.g_step) + steps = self.num_timesteps + (k + 1) * (seg["total_timestep"] / self.g_step) run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) - run_metadata = tf.RunMetadata() + run_metadata = tf.RunMetadata() if self.full_tensorboard_log else None # run loss backprop with summary, and save the metadata (memory, compute time, ...) if writer is not None: summary, grad, *lossbefore = self.compute_lossandgrad(*args, tdlamret, sess=self.sess, options=run_options, run_metadata=run_metadata) - writer.add_run_metadata(run_metadata, 'step%d' % steps) + if self.full_tensorboard_log: + writer.add_run_metadata(run_metadata, 'step%d' % steps) writer.add_summary(summary, steps) else: _, grad, *lossbefore = self.compute_lossandgrad(*args, tdlamret, sess=self.sess, @@ -432,11 +446,13 @@ def fisher_vector_product(vec): logger.record_tabular("EpTrueRewMean", np.mean(true_rewbuffer)) logger.record_tabular("EpThisIter", len(lens)) episodes_so_far += len(lens) - timesteps_so_far += seg["total_timestep"] + current_it_timesteps = MPI.COMM_WORLD.allreduce(seg["total_timestep"]) + timesteps_so_far += current_it_timesteps + self.num_timesteps += current_it_timesteps iters_so_far += 1 logger.record_tabular("EpisodesSoFar", episodes_so_far) - logger.record_tabular("TimestepsSoFar", timesteps_so_far) + logger.record_tabular("TimestepsSoFar", self.num_timesteps) logger.record_tabular("TimeElapsed", time.time() - t_start) if self.verbose >= 1 and self.rank == 0: diff --git a/tests/test_distri.py b/tests/test_distri.py index d33b14ecf1..d3be362617 100644 --- a/tests/test_distri.py +++ b/tests/test_distri.py @@ -22,7 +22,7 @@ def test_probtypes(): categorical = CategoricalProbabilityDistributionType(pdparam_categorical.size) validate_probtype(categorical, pdparam_categorical) - nvec = [1, 2, 3] + nvec = np.array([1, 2, 3]) pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1]) multicategorical = MultiCategoricalProbabilityDistributionType(nvec) validate_probtype(multicategorical, pdparam_multicategorical) diff --git a/tests/test_lstm_policy.py b/tests/test_lstm_policy.py index 79e0ae6875..688cf4e1a8 100644 --- a/tests/test_lstm_policy.py +++ b/tests/test_lstm_policy.py @@ -2,7 +2,7 @@ import pytest -from stable_baselines import A2C, ACER, PPO2 +from stable_baselines import A2C, ACER, ACKTR, PPO2 from stable_baselines.common.policies import MlpLstmPolicy, LstmPolicy @@ -33,7 +33,7 @@ def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=64, N_TRIALS = 100 -MODELS = [A2C, ACER, PPO2] +MODELS = [A2C, ACER, ACKTR, PPO2] 
LSTM_POLICIES = [MlpLstmPolicy, CustomLSTMPolicy1, CustomLSTMPolicy2, CustomLSTMPolicy3, CustomLSTMPolicy4] diff --git a/tests/test_tensorboard.py b/tests/test_tensorboard.py new file mode 100644 index 0000000000..debd40cc80 --- /dev/null +++ b/tests/test_tensorboard.py @@ -0,0 +1,33 @@ +import os +import shutil + +import pytest + +from stable_baselines import A2C, ACER, ACKTR, DQN, DDPG, PPO1, PPO2, SAC, TRPO + +TENSORBOARD_DIR = '/tmp/tb_dir/' + +if os.path.isdir(TENSORBOARD_DIR): + shutil.rmtree(TENSORBOARD_DIR) + +MODEL_DICT = { + 'a2c': (A2C, 'CartPole-v1'), + 'acer': (ACER, 'CartPole-v1'), + 'acktr': (ACKTR, 'CartPole-v1'), + 'dqn': (DQN, 'CartPole-v1'), + 'ddpg': (DDPG, 'Pendulum-v0'), + 'ppo1': (PPO1, 'CartPole-v1'), + 'ppo2': (PPO2, 'CartPole-v1'), + 'sac': (SAC, 'Pendulum-v0'), + 'trpo': (TRPO, 'CartPole-v1'), +} + +N_STEPS = 1000 + + +@pytest.mark.parametrize("model_name", MODEL_DICT.keys()) +def test_tensorboard(model_name): + algo, env_id = MODEL_DICT[model_name] + model = algo('MlpPolicy', env_id, verbose=1, tensorboard_log=TENSORBOARD_DIR) + model.learn(N_STEPS) + model.learn(N_STEPS, reset_num_timesteps=False) diff --git a/tests/test_tf_util.py b/tests/test_tf_util.py index ef7f872139..d71374da03 100644 --- a/tests/test_tf_util.py +++ b/tests/test_tf_util.py @@ -1,7 +1,8 @@ # tests for tf_util +import numpy as np import tensorflow as tf -from stable_baselines.common.tf_util import function, initialize, single_threaded_session +from stable_baselines.common.tf_util import function, initialize, single_threaded_session, is_image def test_function(): @@ -38,6 +39,23 @@ def test_multikwargs(): assert linear_fn(2, 2) == 10 -if __name__ == '__main__': - test_function() - test_multikwargs() +def test_image_detection(): + rgb = (32, 64, 3) + gray = (43, 23, 1) + rgbd = (12, 32, 4) + invalid_1 = (32, 12) + invalid_2 = (12, 32, 6) + + # TF checks + for shape in (rgb, gray, rgbd): + assert is_image(tf.placeholder(tf.uint8, shape=shape)) + + for shape in (invalid_1, invalid_2): + assert not is_image(tf.placeholder(tf.uint8, shape=shape)) + + # Numpy checks + for shape in (rgb, gray, rgbd): + assert is_image(np.ones(shape)) + + for shape in (invalid_1, invalid_2): + assert not is_image(np.ones(shape))
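
Usage sketch (illustrative only, not part of the patch): how the reset_num_timesteps and full_tensorboard_log options introduced in this diff are meant to be used together; the environment id and log directory below are placeholders.

from stable_baselines import PPO2

# full_tensorboard_log defaults to False because the extra histograms and
# run metadata can make the tensorboard log directory grow quickly.
model = PPO2('MlpPolicy', 'CartPole-v1', verbose=1,
             tensorboard_log='/tmp/ppo2_tensorboard/',  # placeholder path
             full_tensorboard_log=False)

# First training run: num_timesteps starts at 0 and a new tensorboard run is created.
model.learn(total_timesteps=10000)

# Continue training on the same learning curve: reset_num_timesteps=False keeps
# model.num_timesteps running instead of restarting the x-axis from step 0.
model.learn(total_timesteps=10000, reset_num_timesteps=False)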