
Commit

cleaning GAIL + doc fix
hill-a committed Jul 27, 2018
1 parent a187fe6 commit 3a4dcbd
Showing 7 changed files with 41 additions and 67 deletions.
2 changes: 1 addition & 1 deletion baselines/acktr/acktr_cont.py
@@ -61,7 +61,7 @@ def rollout(env, policy, max_pathlength, animate=False, obfilter=None):
def learn(env, policy, value_fn, gamma, lam, timesteps_per_batch, num_timesteps,
animate=False, callback=None, desired_kl=0.002):
"""
Learns a Kfac model
Trains an ACKTR model.
:param env: (Gym environment) The environment to learn from
:param policy: (Object) The policy model to use (MLP, CNN, LSTM, ...)
12 changes: 8 additions & 4 deletions baselines/common/tf_util.py
@@ -434,28 +434,32 @@ def get_available_gpus():
# Saving variables
# ================================================================

def load_state(fname, sess=None):
def load_state(fname, sess=None, var_list=None):
"""
Load a TensorFlow saved model
:param fname: (str) the graph name
:param sess: (TensorFlow Session) the session, if None: get_default_session()
:param var_list: ([TensorFlow Tensor] or {str: TensorFlow Tensor}) A list of Variable/SaveableObject,
or a dictionary mapping names to `SaveableObject`s. If `None`, defaults to the list of all saveable objects.
"""
if sess is None:
sess = tf.get_default_session()
saver = tf.train.Saver()
saver = tf.train.Saver(var_list=var_list)
saver.restore(sess, fname)


def save_state(fname, sess=None):
def save_state(fname, sess=None, var_list=None):
"""
Save a TensorFlow model
:param fname: (str) the graph name
:param sess: (TensorFlow Session) the session, if None: get_default_session()
:param var_list: ([TensorFlow Tensor] or {str: TensorFlow Tensor}) A list of Variable/SaveableObject,
or a dictionary mapping names to `SaveableObject`s. If `None`, defaults to the list of all saveable objects.
"""
if sess is None:
sess = tf.get_default_session()
os.makedirs(os.path.dirname(fname), exist_ok=True)
saver = tf.train.Saver()
saver = tf.train.Saver(var_list=var_list)
saver.save(sess, fname)
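
As a quick illustration of the new var_list argument, here is a minimal sketch (the checkpoint path and the "pi" scope name are placeholders, not part of this commit) of saving and restoring only a subset of the graph's variables:

import tensorflow as tf
from baselines.common import tf_util

# Assumes a default session whose variables are already initialized.
# Collect only the policy's variables ("pi/...") and checkpoint just those.
pi_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="pi")
tf_util.save_state("/tmp/gail/policy_ckpt", var_list=pi_vars)
tf_util.load_state("/tmp/gail/policy_ckpt", var_list=pi_vars)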
8 changes: 3 additions & 5 deletions baselines/gail/behavior_clone.py
@@ -89,9 +89,7 @@ def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4, adam_e
savedir_fname = tempfile.TemporaryDirectory().name
else:
savedir_fname = os.path.join(ckpt_dir, task_name)
# FIXME: Incorrect call argument...
# commented for now
# tf_util.save_state(savedir_fname, var_list=pi.get_variables())
tf_util.save_state(savedir_fname, var_list=policy.get_variables())
return savedir_fname
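
With the save_state call restored above, behavior cloning once again writes the pretrained policy to disk and returns the checkpoint path. A hedged usage sketch (env, policy_fn and dataset are assumed to be set up as in this script; remaining keyword arguments are omitted):

from baselines.gail import behavior_clone

# Illustrative only: pretrain with behavior cloning and keep the checkpoint
# path so it can later be passed to trpo_mpi.learn as pretrained_weight
# (see baselines/gail/trpo_mpi.py below).
pretrained_weight = behavior_clone.learn(env, policy_fn, dataset)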


@@ -119,8 +117,8 @@ def main(args):
set_global_seeds(args.seed)
env = gym.make(args.env_id)

def policy_fn(name, ob_space, ac_space, reuse=False):
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
def policy_fn(name, ob_space, ac_space, reuse=False, sess=None):
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, sess=sess,
reuse=reuse, hid_size=args.policy_hidden_size, num_hid_layers=2)
env = bench.Monitor(env, logger.get_dir() and
os.path.join(logger.get_dir(), "monitor.json"))
4 changes: 2 additions & 2 deletions baselines/gail/gail_eval.py
@@ -69,8 +69,8 @@ def _get_checkpoint_dir(checkpoint_list, limit, prefix):
return checkpoint
return None

def _policy_fn(name, ob_space, ac_space, reuse=False):
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space,
def _policy_fn(name, ob_space, ac_space, reuse=False, sess=None):
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, sess=sess,
reuse=reuse, hid_size=policy_hidden_size, num_hid_layers=2)

data_path = os.path.join('data', 'deterministic.trpo.' + env_name + '.0.00.npz')
15 changes: 9 additions & 6 deletions baselines/gail/mlp_policy.py
@@ -15,19 +15,22 @@
class MlpPolicy(BasePolicy):
recurrent = False

def __init__(self, name, reuse=False, *args, **kwargs):
def __init__(self, name, *args, sess=None, reuse=False, placeholders=None, **kwargs):
"""
MLP policy for Gail
:param name: (str) the variable scope name
:param reuse: (bool) allow reuse of the graph
:param ob_space: (Gym Space) the observation space
:param ac_space: (Gym Space) the action space
:param hid_size: (int) the number of hidden neurons for every hidden layer
:param ob_space: (Gym Space) The observation space of the environment
:param ac_space: (Gym Space) The action space of the environment
:param hid_size: (int) the size of the hidden layers
:param num_hid_layers: (int) the number of hidden layers
:param sess: (TensorFlow session) The current TensorFlow session containing the variables.
:param reuse: (bool) allow reuse of the graph
:param placeholders: (dict) To feed existing placeholders if needed
:param gaussian_fixed_var: (bool) fix the gaussian variance
"""
super(MlpPolicy, self).__init__()
super(MlpPolicy, self).__init__(placeholders=placeholders)
self.sess = sess
with tf.variable_scope(name):
if reuse:
tf.get_variable_scope().reuse_variables()
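To make the new constructor arguments concrete, the sketch below mirrors how trpo_mpi.py builds its two policies: "oldpi" is fed the observation and stochastic placeholders created by "pi", and both are bound to an explicit session instead of the default one (ob_space, ac_space and hid_size=100 are assumed placeholders here, not values from this commit):

import tensorflow as tf
from baselines.gail import mlp_policy

sess = tf.Session()
# "pi" creates its own placeholders; "oldpi" reuses them so both policies
# read the exact same inputs during the TRPO update.
pi = mlp_policy.MlpPolicy("pi", ob_space=ob_space, ac_space=ac_space,
                          sess=sess, hid_size=100, num_hid_layers=2)
oldpi = mlp_policy.MlpPolicy("oldpi", ob_space=ob_space, ac_space=ac_space,
                             sess=sess, hid_size=100, num_hid_layers=2,
                             placeholders={"obs": pi.obs_ph,
                                           "stochastic": pi.stochastic_ph})
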
4 changes: 2 additions & 2 deletions baselines/gail/run_mujoco.py
@@ -88,8 +88,8 @@ def main(args):
set_global_seeds(args.seed)
env = gym.make(args.env_id)

def policy_fn(name, ob_space, ac_space, reuse=False, placeholders=None):
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, reuse=reuse,
def policy_fn(name, ob_space, ac_space, reuse=False, placeholders=None, sess=None):
return mlp_policy.MlpPolicy(name=name, ob_space=ob_space, ac_space=ac_space, reuse=reuse, sess=sess,
hid_size=args.policy_hidden_size, num_hid_layers=2, placeholders=placeholders)
env = bench.Monitor(env, logger.get_dir() and
os.path.join(logger.get_dir(), "monitor.json"))
63 changes: 16 additions & 47 deletions baselines/gail/trpo_mpi.py
@@ -177,14 +177,9 @@ def learn(env, policy_func, *, timesteps_per_batch, max_kl, cg_iters, gamma, lam
# ----------------------------------------
ob_space = env.observation_space
ac_space = env.action_space
if using_gail:
policy = policy_func("pi", ob_space, ac_space, reuse=(pretrained_weight is not None))
old_policy = policy_func("oldpi", ob_space, ac_space,
placeholders={"obs": policy.obs_ph, "stochastic": policy.stochastic_ph})
else:
policy = policy_func("pi", ob_space, ac_space, sess=sess)
old_policy = policy_func("oldpi", ob_space, ac_space, sess=sess,
placeholders={"obs": policy.obs_ph, "stochastic": policy.stochastic_ph})
policy = policy_func("pi", ob_space, ac_space, sess=sess)
old_policy = policy_func("oldpi", ob_space, ac_space, sess=sess,
placeholders={"obs": policy.obs_ph, "stochastic": policy.stochastic_ph})

atarg = tf.placeholder(dtype=tf.float32, shape=[None]) # Target advantage function (if applicable)
ret = tf.placeholder(dtype=tf.float32, shape=[None]) # Empirical return
@@ -216,15 +211,13 @@ def learn(env, policy_func, *, timesteps_per_batch, max_kl, cg_iters, gamma, lam
vf_var_list = [v for v in all_var_list if v.name.startswith("pi/vff")]
assert len(var_list) == len(vf_var_list) + 1
d_adam = MpiAdam(reward_giver.get_trainable_variables())
vfadam = MpiAdam(vf_var_list)
get_flat = tf_util.GetFlat(var_list)
set_from_flat = tf_util.SetFromFlat(var_list)
else:
var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("pol")]
vf_var_list = [v for v in all_var_list if v.name.split("/")[1].startswith("vf")]
vfadam = MpiAdam(vf_var_list, sess=sess)
get_flat = tf_util.GetFlat(var_list, sess=sess)
set_from_flat = tf_util.SetFromFlat(var_list, sess=sess)

vfadam = MpiAdam(vf_var_list, sess=sess)
get_flat = tf_util.GetFlat(var_list, sess=sess)
set_from_flat = tf_util.SetFromFlat(var_list, sess=sess)

klgrads = tf.gradients(dist, var_list)
flat_tangent = tf.placeholder(dtype=tf.float32, shape=[None], name="flat_tan")
@@ -264,10 +257,7 @@ def allmean(arr):
out /= nworkers
return out

if using_gail:
tf_util.initialize()
else:
tf_util.initialize(sess=sess)
tf_util.initialize(sess=sess)

th_init = get_flat()
MPI.COMM_WORLD.Bcast(th_init, root=0)
@@ -306,10 +296,7 @@ def allmean(arr):

# if provide pretrained weight
if pretrained_weight is not None:
raise NotImplementedError
# FIXME: Incorrect call argument...
# commented for now
# tf_util.load_state(pretrained_weight, var_list=policy.get_variables())
tf_util.load_state(pretrained_weight, var_list=policy.get_variables())

while True:
if callback:
@@ -326,18 +313,12 @@ def allmean(arr):
fname = os.path.join(ckpt_dir, task_name)
os.makedirs(os.path.dirname(fname), exist_ok=True)
saver = tf.train.Saver()
saver.save(tf.get_default_session(), fname)
saver.save(sess, fname)

logger.log("********** Iteration %i ************" % iters_so_far)

# TODO: Add session everywhere for GAIL
# so we can remove duplicated code
if using_gail:
def fisher_vector_product(vec):
return allmean(compute_fvp(vec, *fvpargs)) + cg_damping * vec
else:
def fisher_vector_product(vec):
return allmean(compute_fvp(vec, *fvpargs, sess=sess)) + cg_damping * vec
def fisher_vector_product(vec):
return allmean(compute_fvp(vec, *fvpargs, sess=sess)) + cg_damping * vec
# ------------------ Update G ------------------
logger.log("Optimizing Policy...")
# g_step = 1 when not using GAIL
@@ -358,16 +339,10 @@ def fisher_vector_product(vec):
args = seg["ob"], seg["ac"], atarg
fvpargs = [arr[::5] for arr in args]

if using_gail:
assign_old_eq_new() # set old parameter values to new parameter values
else:
assign_old_eq_new(sess=sess)
assign_old_eq_new(sess=sess)

with timed("computegrad"):
if using_gail:
*lossbefore, grad = compute_lossandgrad(*args)
else:
*lossbefore, grad = compute_lossandgrad(*args, sess=sess)
*lossbefore, grad = compute_lossandgrad(*args, sess=sess)
lossbefore = allmean(np.array(lossbefore))
grad = allmean(grad)
if np.allclose(grad, 0):
@@ -388,10 +363,7 @@ def fisher_vector_product(vec):
for _ in range(10):
thnew = thbefore + fullstep * stepsize
set_from_flat(thnew)
if using_gail:
mean_losses = surr, kl_loss, *_ = allmean(np.array(compute_losses(*args)))
else:
mean_losses = surr, kl_loss, *_ = allmean(np.array(compute_losses(*args, sess=sess)))
mean_losses = surr, kl_loss, *_ = allmean(np.array(compute_losses(*args, sess=sess)))
improve = surr - surrbefore
logger.log("Expected: %.3f Actual: %.3f" % (expectedimprove, improve))
if not np.isfinite(mean_losses).all():
@@ -417,10 +389,7 @@ def fisher_vector_product(vec):
include_final_partial_batch=False, batch_size=128):
if hasattr(policy, "ob_rms"):
policy.ob_rms.update(mbob) # update running mean/std for policy
if using_gail:
grad = allmean(compute_vflossandgrad(mbob, mbret))
else:
grad = allmean(compute_vflossandgrad(mbob, mbret, sess=sess))
grad = allmean(compute_vflossandgrad(mbob, mbret, sess=sess))
vfadam.update(grad, vf_stepsize)

for (loss_name, loss_val) in zip(loss_names, mean_losses):
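Taken together, removing the using_gail branches leaves one code path in which the session, the policy variables and every compiled TensorFlow function are threaded explicitly. A short consolidated sketch (the lines mirror this diff; the surrounding learn() body is assumed):

# Restore the behavior-cloning checkpoint into the policy's own variables
# instead of raising NotImplementedError.
if pretrained_weight is not None:
    tf_util.load_state(pretrained_weight, var_list=policy.get_variables())

# The Fisher-vector product is defined once, against the same explicit
# session that the GAIL and plain-TRPO paths now share.
def fisher_vector_product(vec):
    return allmean(compute_fvp(vec, *fvpargs, sess=sess)) + cg_damping * vec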
