From 4e708fdb4f975e0e3d82d30e126e75dba38cb998 Mon Sep 17 00:00:00 2001 From: Mishari Date: Tue, 27 Oct 2020 13:28:14 -0600 Subject: [PATCH] Add PER --- examples/torch/dqn_atari.py | 16 +- src/garage/envs/gym_env.py | 1 - src/garage/replay_buffer/__init__.py | 3 +- src/garage/replay_buffer/per_replay_buffer.py | 145 ++++++++++++++++++ src/garage/torch/algos/dqn.py | 39 ++++- .../replay_buffer/test_per_replay_buffer.py | 120 +++++++++++++++ 6 files changed, 311 insertions(+), 13 deletions(-) create mode 100644 src/garage/replay_buffer/per_replay_buffer.py create mode 100644 tests/garage/replay_buffer/test_per_replay_buffer.py diff --git a/examples/torch/dqn_atari.py b/examples/torch/dqn_atari.py index bc2707d5b0..cf4e740c91 100755 --- a/examples/torch/dqn_atari.py +++ b/examples/torch/dqn_atari.py @@ -23,7 +23,7 @@ from garage.envs.wrappers.stack_frames import StackFrames from garage.experiment.deterministic import set_seed from garage.np.exploration_policies import EpsilonGreedyPolicy -from garage.replay_buffer import PathBuffer +from garage.replay_buffer import PERReplayBuffer from garage.sampler import FragmentWorker, LocalSampler from garage.torch import set_gpu_mode from garage.torch.algos import DQN @@ -40,6 +40,9 @@ n_train_steps=125, target_update_freq=2, buffer_batch_size=32, + double_q=True, + per_beta_init=0.4, + per_alpha=0.6, max_epsilon=1.0, min_epsilon=0.01, decay_ratio=0.1, @@ -104,7 +107,7 @@ def main(env=None, # pylint: disable=unused-argument -@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=30) +@wrap_experiment(snapshot_mode='none') def dqn_atari(ctxt=None, env=None, seed=24, @@ -150,8 +153,12 @@ def dqn_atari(ctxt=None, steps_per_epoch = hyperparams['steps_per_epoch'] sampler_batch_size = hyperparams['sampler_batch_size'] num_timesteps = n_epochs * steps_per_epoch * sampler_batch_size - replay_buffer = PathBuffer( - capacity_in_transitions=hyperparams['buffer_size']) + + replay_buffer = PERReplayBuffer(hyperparams['buffer_size'], + num_timesteps, + env.spec, + alpha=hyperparams['per_alpha'], + beta_init=hyperparams['per_beta_init']) qf = DiscreteCNNQFunction( env_spec=env.spec, @@ -179,6 +186,7 @@ def dqn_atari(ctxt=None, replay_buffer=replay_buffer, steps_per_epoch=steps_per_epoch, qf_lr=hyperparams['lr'], + double_q=hyperparams['double_q'], clip_gradient=hyperparams['clip_gradient'], discount=hyperparams['discount'], min_buffer_size=hyperparams['min_buffer_size'], diff --git a/src/garage/envs/gym_env.py b/src/garage/envs/gym_env.py index 321fe0ecaa..4d5cf8f40a 100644 --- a/src/garage/envs/gym_env.py +++ b/src/garage/envs/gym_env.py @@ -13,7 +13,6 @@ # entry points don't close their viewer windows. 
 KNOWN_GYM_NOT_CLOSE_VIEWER = [
     # Please keep alphabetized
-    'gym.envs.atari',
     'gym.envs.box2d',
     'gym.envs.classic_control'
 ]
diff --git a/src/garage/replay_buffer/__init__.py b/src/garage/replay_buffer/__init__.py
index c69c5a3780..6d0e8daa11 100644
--- a/src/garage/replay_buffer/__init__.py
+++ b/src/garage/replay_buffer/__init__.py
@@ -4,6 +4,7 @@
 """
 from garage.replay_buffer.her_replay_buffer import HERReplayBuffer
 from garage.replay_buffer.path_buffer import PathBuffer
+from garage.replay_buffer.per_replay_buffer import PERReplayBuffer
 from garage.replay_buffer.replay_buffer import ReplayBuffer
 
-__all__ = ['ReplayBuffer', 'HERReplayBuffer', 'PathBuffer']
+__all__ = ['PERReplayBuffer', 'ReplayBuffer', 'HERReplayBuffer', 'PathBuffer']
diff --git a/src/garage/replay_buffer/per_replay_buffer.py b/src/garage/replay_buffer/per_replay_buffer.py
new file mode 100644
index 0000000000..de0aa02b55
--- /dev/null
+++ b/src/garage/replay_buffer/per_replay_buffer.py
@@ -0,0 +1,145 @@
+"""Prioritized Experience Replay."""
+
+import numpy as np
+
+from garage import StepType, TimeStepBatch
+from garage.replay_buffer.path_buffer import PathBuffer
+
+
+class PERReplayBuffer(PathBuffer):
+    """Replay buffer for PER (Prioritized Experience Replay).
+
+    PER assigns priorities to transitions in the buffer. Typically the
+    priority of each transition is proportional to the corresponding loss
+    computed at the most recent update step. The priorities are then used
+    to create a probability distribution when sampling, such that higher
+    priority transitions are sampled more frequently. For more, see
+    https://arxiv.org/abs/1511.05952.
+
+    Args:
+        capacity_in_transitions (int): Total number of transitions the buffer can store.
+        total_timesteps (int): Total timesteps the experiment will run for.
+            This is used to calculate the beta parameter when sampling.
+        env_spec (EnvSpec): Environment specification.
+        alpha (float): Hyperparameter that controls the degree of
+            prioritization. Typically in [0, 1], where 0 corresponds to
+            no prioritization (uniform sampling).
+        beta_init (float): Initial value of the beta exponent used in
+            importance sampling. Beta is linearly annealed from beta_init
+            to 1 over total_timesteps.
+    """
+
+    def __init__(self,
+                 capacity_in_transitions,
+                 total_timesteps,
+                 env_spec,
+                 alpha=0.6,
+                 beta_init=0.5):
+        self._alpha = alpha
+        self._beta_init = beta_init
+        self._total_timesteps = total_timesteps
+        self._curr_timestep = 0
+        self._priorities = np.zeros((capacity_in_transitions, ), np.float32)
+        self._rng = np.random.default_rng()
+        super().__init__(capacity_in_transitions, env_spec)
+
+    def sample_timesteps(self, batch_size):
+        """Sample a batch of timesteps from the buffer.
+
+        Args:
+            batch_size (int): Number of timesteps to sample.
+
+        Returns:
+            TimeStepBatch: The batch of timesteps.
+            np.ndarray: Weights of the timesteps.
+            np.ndarray: Indices of sampled timesteps
+                in the replay buffer.
+
+        """
+        samples, w, idx = self.sample_transitions(batch_size)
+        step_types = np.array([
+            StepType.TERMINAL if terminal else StepType.MID
+            for terminal in samples['terminals'].reshape(-1)
+        ],
+                              dtype=StepType)
+        return TimeStepBatch(env_spec=self._env_spec,
+                             observations=samples['observations'],
+                             actions=samples['actions'],
+                             rewards=samples['rewards'],
+                             next_observations=samples['next_observations'],
+                             step_types=step_types,
+                             env_infos={},
+                             agent_infos={}), w, idx
+
+    def sample_transitions(self, batch_size):
+        """Sample a batch of transitions from the buffer.
+
+        Args:
+            batch_size (int): Number of transitions to sample.
+
+        Returns:
+            dict: A dict of arrays of shape (batch_size, flat_dim).
+            np.ndarray: Weights of the timesteps.
+            np.ndarray: Indices of sampled timesteps
+                in the replay buffer.
+
+        """
+        priorities = self._priorities
+        if self._transitions_stored < self._capacity:
+            priorities = self._priorities[:self._transitions_stored]
+        probs = priorities**self._alpha
+        probs /= probs.sum()
+        idx = self._rng.choice(self._transitions_stored,
+                               size=batch_size,
+                               p=probs)
+
+        beta = self._beta_init + self._curr_timestep * (
+            1.0 - self._beta_init) / self._total_timesteps
+        beta = min(1.0, beta)
+        transitions = {
+            key: buf_arr[idx]
+            for key, buf_arr in self._buffer.items()
+        }
+
+        w = (self._transitions_stored * probs[idx])**(-beta)
+        w /= w.max()
+        w = np.array(w)
+
+        return transitions, w, idx
+
+    def update_priorities(self, indices, priorities):
+        """Update priorities of timesteps.
+
+        Args:
+            indices (np.ndarray): Array of indices corresponding to the
+                timesteps/priorities to update.
+            priorities (list[float]): New priorities to set.
+
+        """
+        for idx, priority in zip(indices, priorities):
+            self._priorities[int(idx)] = priority
+
+    def add_path(self, path):
+        """Add a path to the buffer.
+
+        This differs from the underlying buffer's add_path method
+        in that the priorities for the new samples are set to
+        the maximum of all priorities in the buffer.
+
+        Args:
+            path (dict): A dict of arrays of shape (path_len, flat_dim).
+
+        """
+        path_len = len(path['observations'])
+        self._curr_timestep += path_len
+
+        # Find the indices where the path will be stored.
+        first_seg, second_seg = self._next_path_segments(path_len)
+
+        # Set the priorities of the new timesteps to max(self._priorities),
+        # or to 1 if the buffer is empty.
+        max_priority = self._priorities.max() or 1.
+ self._priorities[list(first_seg)] = max_priority + if second_seg != range(0, 0): + self._priorities[list(second_seg)] = max_priority + super().add_path(path) diff --git a/src/garage/torch/algos/dqn.py b/src/garage/torch/algos/dqn.py index 5419c55218..d3999cd429 100644 --- a/src/garage/torch/algos/dqn.py +++ b/src/garage/torch/algos/dqn.py @@ -10,8 +10,9 @@ from garage import _Default, log_performance, make_optimizer from garage._functions import obtain_evaluation_episodes from garage.np.algos import RLAlgorithm +from garage.replay_buffer import PERReplayBuffer from garage.sampler import FragmentWorker -from garage.torch import global_device, np_to_torch +from garage.torch import global_device, np_to_torch, torch_to_np class DQN(RLAlgorithm): @@ -122,6 +123,9 @@ def __init__( self._qf_optimizer = make_optimizer(qf_optimizer, module=self._qf, lr=qf_lr) + + self._prioritized_replay = isinstance(self.replay_buffer, + PERReplayBuffer) self._eval_env = eval_env def train(self, trainer): @@ -192,10 +196,18 @@ def _train_once(self, itr, episodes): for _ in range(self._n_train_steps): if (self.replay_buffer.n_transitions_stored >= self._min_buffer_size): - timesteps = self.replay_buffer.sample_timesteps( - self._buffer_batch_size) - qf_loss, y, q = tuple(v.cpu().numpy() - for v in self._optimize_qf(timesteps)) + if self._prioritized_replay: + timesteps, weights, indices = ( + self.replay_buffer.sample_timesteps( + self._buffer_batch_size)) + qf_loss, y, q = tuple(v.cpu().numpy() + for v in self._optimize_qf( + timesteps, weights, indices)) + else: + timesteps = self.replay_buffer.sample_timesteps( + self._buffer_batch_size) + qf_loss, y, q = tuple( + v.cpu().numpy() for v in self._optimize_qf(timesteps)) self._episode_qf_losses.append(qf_loss) self._epoch_ys.append(y) @@ -228,11 +240,16 @@ def _log_eval_results(self, epoch): tabular.record('QFunction/AverageAbsY', np.mean(np.abs(self._epoch_ys))) - def _optimize_qf(self, timesteps): + def _optimize_qf(self, timesteps, weights=None, indices=None): """Perform algorithm optimizing. Args: timesteps (TimeStepBatch): Processed batch data. + weights (np.ndarray[float]): Weights used by PER when updating + the network. Should be None if PER is not being used. + indices (list[int or float]): Indices of the sampled + timesteps in the replay buffer. Should be None + if PER is not being used. Returns: qval_loss: Loss of Q-value predicted by the Q-network. 
@@ -274,7 +291,15 @@ def _optimize_qf(self, timesteps): # optimize qf qvals = self._qf(inputs) selected_qs = torch.sum(qvals * actions, axis=1) - qval_loss = F.smooth_l1_loss(selected_qs, y_target) + qval_loss = F.smooth_l1_loss(selected_qs, y_target, reduction='none') + + if self._prioritized_replay: + qval_loss *= np_to_torch(weights) + priorities = qval_loss + 1e-5 # offset to avoid 0 priorities + priorities = torch_to_np(priorities.data.cpu()) + self.replay_buffer.update_priorities(indices, priorities) + + qval_loss = qval_loss.mean() self._qf_optimizer.zero_grad() qval_loss.backward() diff --git a/tests/garage/replay_buffer/test_per_replay_buffer.py b/tests/garage/replay_buffer/test_per_replay_buffer.py new file mode 100644 index 0000000000..1f47f12529 --- /dev/null +++ b/tests/garage/replay_buffer/test_per_replay_buffer.py @@ -0,0 +1,120 @@ +import akro +import numpy as np +import pytest + +from garage import EnvSpec, EpisodeBatch, StepType +from garage.replay_buffer import PERReplayBuffer + +from tests.fixtures.envs.dummy import DummyDiscreteEnv + + +@pytest.fixture +def setup(): + obs_space = akro.Box(low=1, high=np.inf, shape=(1, ), dtype=np.float32) + act_space = akro.Discrete(1) + env_spec = EnvSpec(obs_space, act_space) + buffer = PERReplayBuffer(100, 100, env_spec) + return buffer, DummyDiscreteEnv() + + +def test_add_path(setup): + buff, env = setup + obs = env.reset() + buff.add_path({'observations': np.array([obs for _ in range(5)])}) + + # initial priorities for inserted timesteps should be 1 + assert (buff._priorities[:5] == 1.).all() + assert (buff._priorities[5:] == 0.).all() + + # test case where buffer is full and paths are split + # into two segments + num_obs = buff._capacity - buff._transitions_stored + buff.add_path( + {'observations': np.array([obs for _ in range(num_obs - 1)])}) + + # artificially set the priority of a transition to be high . + # the next path added to the buffer should wrap around the buffer + # and contain one timestep at the end and 5 at the beginning, all + # of which should have priority == max(buff._priorities). + buff._priorities[-1] = 100. + buff.add_path({'observations': np.array([obs for _ in range(6)])}) + + assert buff._priorities[-1] == 100. 
+ assert (buff._priorities[:5] == 100.).all() + + +def test_update_priorities(setup): + buff, env = setup + obs = env.reset() + buff.add_path({'observations': np.array([obs for _ in range(5)])}) + + assert (buff._priorities[:5] == 1.).all() + assert (buff._priorities[5:] == 0.).all() + + indices = list(range(2, 10)) + new_priorities = [0.5 for _ in range(2, 10)] + buff.update_priorities(indices, new_priorities) + + assert (buff._priorities[2:10] == 0.5).all() + assert (buff._priorities[:2] != 0.5).all() + assert (buff._priorities[10:] != 0.5).all() + + +@pytest.mark.parametrize('alpha, beta_init', [(0.5, 0.5), (0.4, 0.6), + (0.1, 0.9)]) +def test_sample_transitions(setup, alpha, beta_init): + buff, env = setup + obs = env.reset() + buff.add_path({ + 'observations': + np.array([np.full_like(obs, i, dtype=np.float32) for i in range(50)]), + }) + + buff._beta_init = beta_init + buff._alpha = alpha + transitions, weights, indices = buff.sample_transitions(50) + obses = transitions['observations'] + + # verify the indices returned correspond to the correct samples + for o, i in zip(obses, indices): + assert (o == i).all() + + # verify the weights are correct + probs = buff._priorities**buff._alpha + probs /= probs.sum() + + beta = buff._beta_init + 50 * (1.0 - buff._beta_init) / 100 + beta = min(1.0, beta) + w = (50 * probs[indices])**(-beta) + w /= w.max() + w = np.array(w) + + assert (w == weights).all() + + +def test_sample_timesteps(setup): + buff, env = setup + obs = env.reset() + buff.add_path({ + 'observations': + np.array([np.full_like(obs, i, dtype=np.float32) for i in range(50)]), + 'next_observations': + np.array([np.full_like(obs, i, dtype=np.float32) for i in range(50)]), + 'actions': + np.array([np.full_like(obs, i, dtype=np.float32) for i in range(50)]), + 'terminals': + np.array([[False] for _ in range(50)]), + 'rewards': + np.array([[1] for _ in range(50)]) + }) + + timesteps, weights, indices = buff.sample_timesteps(50) + + assert len(weights) == 50 + assert len(indices) == 50 + + obses, actions = timesteps.observations, timesteps.actions + + for a, o, i in zip(actions, obses, indices): + assert (o == i).all() + assert (a == i).all()
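
Reviewer note (not part of the patch): the sampling math that sample_transitions implements can be sanity-checked in isolation with plain numpy. The sketch below uses made-up priorities and hypothetical names (priorities, n_stored, idx); it only mirrors the formulas added in this patch, namely P(i) = p_i**alpha / sum_k p_k**alpha for the sampling distribution and w_i = (N * P(i))**(-beta) for the importance-sampling weights, normalized by the largest weight.

    import numpy as np

    # Made-up priorities for a buffer currently holding 4 transitions.
    priorities = np.array([1.0, 1.0, 4.0, 0.5], dtype=np.float32)
    alpha, beta = 0.6, 0.4
    n_stored = len(priorities)

    # Sampling distribution: P(i) = p_i ** alpha / sum_k p_k ** alpha.
    probs = priorities**alpha
    probs /= probs.sum()

    # Draw a small batch; higher-priority transitions are more likely.
    rng = np.random.default_rng(0)
    idx = rng.choice(n_stored, size=2, p=probs)

    # Importance-sampling weights, normalized so the largest weight is 1,
    # meaning weighted losses are only ever scaled down.
    w = (n_stored * probs[idx])**(-beta)
    w /= w.max()
    print(idx, w)

With all priorities equal the distribution is uniform and every weight is 1. A transition with priority 0 would never be sampled, which is why add_path seeds new timesteps with the current maximum priority (or 1 for an empty buffer).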