Enhancements (Tensorboard and ACKTR recurrent policies) (#185)
* initial tensorboard timestep fix (ppo2) #56

* finished Tensorboard fix

* fixed KFAC when reshapes are in the policy #70

* fixed test callback variable name

* Refactor num_timesteps reset + update changelog
+ update callback check

* Fix redeclaration of total_timesteps

* Reduce default tensorboard log size

* Fix writing metadata when tb full log disabled

* Enable LSTM tests for ACKTR + update doc

* Fixes for new gym version (breaking changes!)

* Add test for tensorboard + fix for DDPG logging

* Fix travis build

* Fix tb test style

* Convert tabs to spaces

* Attempt to fix tf dependency

* Fix typo

* Update changelog

* Update version name
araffin authored Feb 9, 2019
1 parent d997d65 commit 5acf88f
Showing 26 changed files with 468 additions and 213 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -145,7 +145,7 @@ All the following examples can be executed online using Google colab notebooks:
| ------------------- | ---------------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| ACER | :heavy_check_mark: | :heavy_check_mark: | :x: <sup>(5)</sup> | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| ACKTR | :heavy_check_mark: | :x: | :x: <sup>(5)</sup> | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| ACKTR | :heavy_check_mark: | :heavy_check_mark: | :x: <sup>(5)</sup> | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| DDPG | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| DQN | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :x: | :x: |
| GAIL <sup>(2)</sup> | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: <sup>(4)</sup> |
6 changes: 6 additions & 0 deletions docs/misc/changelog.rst
@@ -9,8 +9,14 @@ Pre-Release 2.4.1a (WIP)
--------------------------

- fixed computation of training metrics in TRPO and PPO1
- added ``reset_num_timesteps`` keyword when calling learn() to continue tensorboard learning curves (see the usage sketch below)
- reduced the size taken by tensorboard logs (added a ``full_tensorboard_log`` parameter to enable full logging, which was the previous behavior)
- fixed image detection for tensorboard logging
- fixed ACKTR for recurrent policies
- fixed compatibility with breaking changes in the new gym version
- fixed custom policy examples in the doc for DQN and DDPG
- removed gym spaces patch for equality functions
- fixed tensorflow dependency: the CPU version was installed, overwriting tensorflow-gpu when present


Release 2.4.0 (2019-01-17)
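The two tensorboard-related changelog entries above introduce ``reset_num_timesteps`` (on ``learn()``) and ``full_tensorboard_log`` (on the model constructor). The snippet below is an editorial usage sketch, not part of the commit; it assumes a standard Gym environment and the A2C defaults.

```python
import gym

from stable_baselines import A2C
from stable_baselines.common.vec_env import DummyVecEnv

env = DummyVecEnv([lambda: gym.make('CartPole-v1')])

# full_tensorboard_log=True restores the previous, verbose logging behaviour;
# the new default (False) keeps tensorboard logs small.
model = A2C('MlpPolicy', env, tensorboard_log='./a2c_tensorboard/', full_tensorboard_log=False)

model.learn(total_timesteps=10000)
# Continue the same tensorboard learning curve instead of starting a new one:
model.learn(total_timesteps=10000, reset_num_timesteps=False)
```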
32 changes: 28 additions & 4 deletions setup.py
@@ -1,10 +1,35 @@
from setuptools import setup, find_packages
import sys
import subprocess
from setuptools import setup, find_packages
from distutils.version import LooseVersion

if sys.version_info.major != 3:
    print('This Python is only compatible with Python 3, but you are running '
          'Python {}. The installation will likely fail.'.format(sys.version_info.major))

# Check tensorflow installation to avoid
# breaking pre-installed tf gpu
install_tf, tf_gpu = False, False
try:
    import tensorflow as tf
    if tf.__version__ < LooseVersion('1.5.0'):
        install_tf = True
        # check if a gpu version is needed
        tf_gpu = tf.test.is_gpu_available()
except ImportError:
    install_tf = True
    # Check if a nvidia gpu is present
    for command in ['nvidia-smi', '/usr/bin/nvidia-smi', 'nvidia-smi.exe']:
        if subprocess.call([command]) == 0:
            tf_gpu = True
            break

tf_dependency = []
if install_tf:
    tf_dependency = ['tensorflow-gpu>=1.5.0'] if tf_gpu else ['tensorflow>=1.5.0']
    if tf_gpu:
        print("A GPU was detected, tensorflow-gpu will be installed")


long_description = """
[![Build Status](https://travis-ci.com/hill-a/stable-baselines.svg?branch=master)](https://travis-ci.com/hill-a/stable-baselines) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines.readthedocs.io/en/master/?badge=master) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/3bcb4cd6d76a4270acb16b5fe6dd9efa)](https://www.codacy.com/app/baselines_janitors/stable-baselines?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=hill-a/stable-baselines&amp;utm_campaign=Badge_Grade) [![Codacy Badge](https://api.codacy.com/project/badge/Coverage/3bcb4cd6d76a4270acb16b5fe6dd9efa)](https://www.codacy.com/app/baselines_janitors/stable-baselines?utm_source=github.com&utm_medium=referral&utm_content=hill-a/stable-baselines&utm_campaign=Badge_Coverage)
@@ -83,15 +108,14 @@
'progressbar2',
'mpi4py',
'cloudpickle>=0.5.5',
'tensorflow>=1.5.0',
'click',
'opencv-python',
'numpy',
'pandas',
'matplotlib',
'seaborn',
'glob2'
],
] + tf_dependency,
extras_require={
'tests': [
'pytest==3.5.1',
@@ -112,7 +136,7 @@
license="MIT",
long_description=long_description,
long_description_content_type='text/markdown',
version="2.4.1a",
version="2.4.1a0",
)

# python setup.py sdist
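For readers who want to check what the detection above would decide on their machine, here is a standalone sketch of the same logic. It is illustrative only: the helper name is made up, and the ``try/except OSError`` around the ``nvidia-smi`` probe is an extra safeguard added in this sketch (the diff calls ``subprocess.call`` directly).

```python
import subprocess
from distutils.version import LooseVersion


def tf_install_plan(min_version='1.5.0'):
    """Hypothetical helper mirroring the setup.py logic: returns (install_tf, tf_gpu)."""
    install_tf, tf_gpu = False, False
    try:
        import tensorflow as tf
        if tf.__version__ < LooseVersion(min_version):
            install_tf = True
            # an outdated tf is installed: pick the GPU flavour if a GPU is visible
            tf_gpu = tf.test.is_gpu_available()
    except ImportError:
        install_tf = True
        # no tf at all: probe for an NVIDIA GPU via nvidia-smi
        for command in ['nvidia-smi', '/usr/bin/nvidia-smi', 'nvidia-smi.exe']:
            try:
                if subprocess.call([command]) == 0:
                    tf_gpu = True
                    break
            except OSError:
                continue  # binary not found at this path, try the next one
    return install_tf, tf_gpu


if __name__ == '__main__':
    print(tf_install_plan())
```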
2 changes: 1 addition & 1 deletion stable_baselines/__init__.py
@@ -9,4 +9,4 @@
from stable_baselines.trpo_mpi import TRPO
from stable_baselines.sac import SAC

__version__ = "2.4.1a"
__version__ = "2.4.1a0"
43 changes: 27 additions & 16 deletions stable_baselines/a2c/a2c.py
@@ -35,11 +35,13 @@ class A2C(ActorCriticRLModel):
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
(used only for loading)
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param full_tensorboard_log: (bool) enable additional logging when using tensorboard
WARNING: this logging can take a lot of space quickly
"""

def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.01, max_grad_norm=0.5,
learning_rate=7e-4, alpha=0.99, epsilon=1e-5, lr_schedule='linear', verbose=0, tensorboard_log=None,
_init_setup_model=True, policy_kwargs=None):
_init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False):

super(A2C, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
_init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs)
@@ -54,6 +56,7 @@ def __init__(self, policy, env, gamma=0.99, n_steps=5, vf_coef=0.25, ent_coef=0.
self.lr_schedule = lr_schedule
self.learning_rate = learning_rate
self.tensorboard_log = tensorboard_log
self.full_tensorboard_log = full_tensorboard_log

self.graph = None
self.sess = None
@@ -132,15 +135,16 @@ def setup_model(self):

with tf.variable_scope("input_info", reuse=False):
tf.summary.scalar('discounted_rewards', tf.reduce_mean(self.rewards_ph))
tf.summary.histogram('discounted_rewards', self.rewards_ph)
tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate))
tf.summary.histogram('learning_rate', self.learning_rate)
tf.summary.scalar('advantage', tf.reduce_mean(self.advs_ph))
tf.summary.histogram('advantage', self.advs_ph)
if len(self.observation_space.shape) == 3:
tf.summary.image('observation', train_model.obs_ph)
else:
tf.summary.histogram('observation', train_model.obs_ph)
if self.full_tensorboard_log:
tf.summary.histogram('discounted_rewards', self.rewards_ph)
tf.summary.histogram('learning_rate', self.learning_rate)
tf.summary.histogram('advantage', self.advs_ph)
if tf_util.is_image(self.observation_space):
tf.summary.image('observation', train_model.obs_ph)
else:
tf.summary.histogram('observation', train_model.obs_ph)

trainer = tf.train.RMSPropOptimizer(learning_rate=self.learning_rate_ph, decay=self.alpha,
epsilon=self.epsilon)
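The summary code above now relies on ``tf_util.is_image`` instead of the previous ``len(self.observation_space.shape) == 3`` test; its implementation lives in the library's ``tf_util`` module and is not part of the excerpt shown here. A plausible sketch of such a check, for orientation only (treat the body as an assumption, not the actual implementation):

```python
# Hypothetical sketch of an is_image-style check; not the actual tf_util code.
def is_image(space):
    """Treat a 3D observation space (height, width, channels) with 1, 3 or 4
    channels (grayscale, RGB, RGBD) as an image for tensorboard logging."""
    return len(space.shape) == 3 and space.shape[-1] in (1, 3, 4)
```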
@@ -184,7 +188,7 @@ def _train_step(self, obs, states, rewards, masks, actions, values, update, writ

if writer is not None:
# run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...)
if (1 + update) % 10 == 0:
if self.full_tensorboard_log and (1 + update) % 10 == 0:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, policy_loss, value_loss, policy_entropy, _ = self.sess.run(
@@ -202,8 +206,13 @@ def _train_step(self, obs, states, rewards, masks, actions, values, update, writ

return policy_loss, value_loss, policy_entropy

def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="A2C"):
with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer:
def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="A2C",
reset_num_timesteps=True):

new_tb_log = self._init_num_timesteps(reset_num_timesteps)

with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
as writer:
self._setup_learn(seed)

self.learning_rate_schedule = Scheduler(initial_value=self.learning_rate, n_values=total_timesteps,
@@ -216,27 +225,29 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
for update in range(1, total_timesteps // self.n_batch + 1):
# true_reward is the reward without discount
obs, states, rewards, masks, actions, values, true_reward = runner.run()
_, value_loss, policy_entropy = self._train_step(obs, states, rewards, masks, actions, values, update,
writer)
_, value_loss, policy_entropy = self._train_step(obs, states, rewards, masks, actions, values,
self.num_timesteps // (self.n_batch + 1), writer)
n_seconds = time.time() - t_start
fps = int((update * self.n_batch) / n_seconds)

if writer is not None:
self.episode_reward = total_episode_reward_logger(self.episode_reward,
true_reward.reshape((self.n_envs, self.n_steps)),
masks.reshape((self.n_envs, self.n_steps)),
writer, update * (self.n_batch + 1))
writer, self.num_timesteps)

self.num_timesteps += self.n_batch + 1

if callback is not None:
# Only stop training if return value is False, not when it is None. This is for backwards
# compatibility with callbacks that have no return statement.
if callback(locals(), globals()) == False:
if callback(locals(), globals()) is False:
break

if self.verbose >= 1 and (update % log_interval == 0 or update == 1):
explained_var = explained_variance(values, rewards)
logger.record_tabular("nupdates", update)
logger.record_tabular("total_timesteps", update * self.n_batch)
logger.record_tabular("total_timesteps", self.num_timesteps)
logger.record_tabular("fps", fps)
logger.record_tabular("policy_entropy", float(policy_entropy))
logger.record_tabular("value_loss", float(value_loss))
50 changes: 32 additions & 18 deletions stable_baselines/acer/acer_simple.py
@@ -91,12 +91,15 @@ class ACER(ActorCriticRLModel):
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
:param policy_kwargs: (dict) additional arguments to be passed to the policy on creation
:param full_tensorboard_log: (bool) enable additional logging when using tensorboard
WARNING: this logging can take a lot of space quickly
"""

def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5, ent_coef=0.01, max_grad_norm=10,
learning_rate=7e-4, lr_schedule='linear', rprop_alpha=0.99, rprop_epsilon=1e-5, buffer_size=5000,
replay_ratio=4, replay_start=1000, correction_term=10.0, trust_region=True, alpha=0.99, delta=1,
verbose=0, tensorboard_log=None, _init_setup_model=True, policy_kwargs=None):
replay_ratio=4, replay_start=1000, correction_term=10.0, trust_region=True,
alpha=0.99, delta=1, verbose=0, tensorboard_log=None,
_init_setup_model=True, policy_kwargs=None, full_tensorboard_log=False):

super(ACER, self).__init__(policy=policy, env=env, verbose=verbose, requires_vec_env=True,
_init_setup_model=_init_setup_model, policy_kwargs=policy_kwargs)
@@ -119,6 +122,7 @@ def __init__(self, policy, env, gamma=0.99, n_steps=20, num_procs=1, q_coef=0.5,
self.lr_schedule = lr_schedule
self.num_procs = num_procs
self.tensorboard_log = tensorboard_log
self.full_tensorboard_log = full_tensorboard_log

self.graph = None
self.sess = None
@@ -361,17 +365,19 @@ def custom_getter(getter, name, *args, **kwargs):

with tf.variable_scope("input_info", reuse=False):
tf.summary.scalar('rewards', tf.reduce_mean(self.reward_ph))
tf.summary.histogram('rewards', self.reward_ph)
tf.summary.scalar('learning_rate', tf.reduce_mean(self.learning_rate))
tf.summary.histogram('learning_rate', self.learning_rate)
tf.summary.scalar('advantage', tf.reduce_mean(adv))
tf.summary.histogram('advantage', adv)
tf.summary.scalar('action_probabilty', tf.reduce_mean(self.mu_ph))
tf.summary.histogram('action_probabilty', self.mu_ph)
if len(self.observation_space.shape) == 3:
tf.summary.image('observation', train_model.obs_ph)
else:
tf.summary.histogram('observation', train_model.obs_ph)

if self.full_tensorboard_log:
tf.summary.histogram('rewards', self.reward_ph)
tf.summary.histogram('learning_rate', self.learning_rate)
tf.summary.histogram('advantage', adv)
tf.summary.histogram('action_probabilty', self.mu_ph)
if tf_util.is_image(self.observation_space):
tf.summary.image('observation', train_model.obs_ph)
else:
tf.summary.histogram('observation', train_model.obs_ph)

trainer = tf.train.RMSPropOptimizer(learning_rate=self.learning_rate_ph, decay=self.rprop_alpha,
epsilon=self.rprop_epsilon)
@@ -429,7 +435,7 @@ def _train_step(self, obs, actions, rewards, dones, mus, states, masks, steps, w

if writer is not None:
# run loss backprop with summary, but once every 10 runs save the metadata (memory, compute time, ...)
if (1 + (steps / self.n_batch)) % 10 == 0:
if self.full_tensorboard_log and (1 + (steps / self.n_batch)) % 10 == 0:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
step_return = self.sess.run([self.summary] + self.run_ops, td_map, options=run_options,
@@ -444,8 +450,13 @@ def _train_step(self, obs, actions, rewards, dones, mus, states, masks, steps, w

return self.names_ops, step_return[1:] # strip off _train

def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="ACER"):
with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name) as writer:
def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="ACER",
reset_num_timesteps=True):

new_tb_log = self._init_num_timesteps(reset_num_timesteps)

with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
as writer:
self._setup_learn(seed)

self.learning_rate_schedule = Scheduler(initial_value=self.learning_rate, n_values=total_timesteps,
@@ -474,7 +485,7 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
self.episode_reward = total_episode_reward_logger(self.episode_reward,
rewards.reshape((self.n_envs, self.n_steps)),
dones.reshape((self.n_envs, self.n_steps)),
writer, steps)
writer, self.num_timesteps)

# reshape stuff correctly
obs = obs.reshape(runner.batch_ob_shape)
@@ -485,16 +496,16 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
masks = masks.reshape([runner.batch_ob_shape[0]])

names_ops, values_ops = self._train_step(obs, actions, rewards, dones, mus, self.initial_state, masks,
steps, writer)
self.num_timesteps, writer)

if callback is not None:
# Only stop training if return value is False, not when it is None. This is for backwards
# compatibility with callbacks that have no return statement.
if callback(locals(), globals()) == False:
if callback(locals(), globals()) is False:
break

if self.verbose >= 1 and (int(steps / runner.n_batch) % log_interval == 0):
logger.record_tabular("total_timesteps", steps)
logger.record_tabular("total_timesteps", self.num_timesteps)
logger.record_tabular("fps", int(steps / (time.time() - t_start)))
# IMP: In EpisodicLife env, during training, we get done=True at each loss of life,
# not just at the terminal state. Thus, this is mean until end of life, not end of episode.
Expand All @@ -519,7 +530,10 @@ def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_
dones = dones.reshape([runner.n_batch])
masks = masks.reshape([runner.batch_ob_shape[0]])

self._train_step(obs, actions, rewards, dones, mus, self.initial_state, masks, steps)
self._train_step(obs, actions, rewards, dones, mus, self.initial_state, masks,
self.num_timesteps)

self.num_timesteps += self.n_batch

return self
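Both ``learn()`` implementations above start with ``self._init_num_timesteps(reset_num_timesteps)``, which lives in the shared base class and is not shown in this excerpt. Based only on how it is used here, it presumably does something like the sketch below (treat the body as an assumption, not the actual implementation):

```python
def _init_num_timesteps(self, reset_num_timesteps=True):
    """Optionally reset the step counter and report whether a new
    tensorboard log (a fresh learning curve) should be started."""
    if reset_num_timesteps:
        self.num_timesteps = 0
    # only open a new tensorboard log when starting from step 0
    new_tb_log = self.num_timesteps == 0
    return new_tb_log
```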
