From a8db5d00055f351bc766005e90c8614f6ca2e43a Mon Sep 17 00:00:00 2001 From: Mishari Aliesa Date: Mon, 21 Sep 2020 17:19:25 -0700 Subject: [PATCH 1/5] Add torch DQN This also adds several smaller features: - torch/examples/watch_atari.py: use a trained agent to play atari. - Error handling in the snapshotter for invalid arguments. - torch/examples/dqn_atari.py: train on atari environments. --- examples/torch/dqn_atari.py | 4 ++++ tests/garage/torch/algos/test_dqn.py | 3 --- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/torch/dqn_atari.py b/examples/torch/dqn_atari.py index e6d11fe11a..231cf80c94 100755 --- a/examples/torch/dqn_atari.py +++ b/examples/torch/dqn_atari.py @@ -144,6 +144,10 @@ def dqn_atari(ctxt=None, qf = DiscreteCNNQFunction( env_spec=env.spec, +<<<<<<< HEAD +======= + minibatch_size=hyperparams['buffer_batch_size'], +>>>>>>> Add torch DQN hidden_channels=hyperparams['hidden_channels'], kernel_sizes=hyperparams['kernel_sizes'], strides=hyperparams['strides'], diff --git a/tests/garage/torch/algos/test_dqn.py b/tests/garage/torch/algos/test_dqn.py index ea3c157d7a..200b34512a 100644 --- a/tests/garage/torch/algos/test_dqn.py +++ b/tests/garage/torch/algos/test_dqn.py @@ -14,9 +14,6 @@ from garage.replay_buffer import PathBuffer from garage.sampler import LocalSampler from garage.torch import np_to_torch -from garage.torch.algos import DQN -from garage.torch.policies import DiscreteQFArgmaxPolicy -from garage.torch.q_functions import DiscreteMLPQFunction from garage.trainer import Trainer from tests.fixtures import snapshot_config From c95d205c678b18104efa310ad6424a865f8c0970 Mon Sep 17 00:00:00 2001 From: Mishari Aliesa Date: Mon, 21 Sep 2020 17:19:25 -0700 Subject: [PATCH 2/5] Add torch DQN This also adds several smaller features: - torch/examples/watch_atari.py: use a trained agent to play atari. - Error handling in the snapshotter for invalid arguments. - torch/examples/dqn_atari.py: train on atari environments. --- examples/torch/watch_atari.py | 4 ++++ src/garage/_functions.py | 13 +++++++++++++ .../exploration_policies/epsilon_greedy_policy.py | 9 +++++++++ tests/garage/torch/algos/test_dqn.py | 1 - 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/examples/torch/watch_atari.py b/examples/torch/watch_atari.py index b780d332c8..f208fe269f 100755 --- a/examples/torch/watch_atari.py +++ b/examples/torch/watch_atari.py @@ -58,7 +58,11 @@ def watch_atari(saved_dir, env=None, num_episodes=10): episode_data = rollout(env, exploration_policy.policy, animated=True, +<<<<<<< HEAD pause_per_frame=0.02) +======= + sleep=0.02) +>>>>>>> Add torch DQN ep_rewards = np.append(ep_rewards, np.sum(episode_data['rewards'])) print('Average Reward {}'.format(np.mean(ep_rewards))) diff --git a/src/garage/_functions.py b/src/garage/_functions.py index 19eae59767..352cbdb20f 100644 --- a/src/garage/_functions.py +++ b/src/garage/_functions.py @@ -69,7 +69,11 @@ def rollout(env, *, max_episode_length=np.inf, animated=False, +<<<<<<< HEAD pause_per_frame=None, +======= + sleep=None, +>>>>>>> Add torch DQN deterministic=False): """Sample a single episode of the agent in the environment. @@ -79,7 +83,11 @@ def rollout(env, max_episode_length (int): If the episode reaches this many timesteps, it is truncated. animated (bool): If true, render the environment after each step. +<<<<<<< HEAD pause_per_frame (float): Time to sleep between steps. Only relevant if +======= + sleep (float): Time to sleep between steps. 
Only relevant if +>>>>>>> Add torch DQN animated == true. deterministic (bool): If true, use the mean action returned by the stochastic policy instead of sampling from the returned action @@ -114,8 +122,13 @@ def rollout(env, if animated: env.visualize() while episode_length < (max_episode_length or np.inf): +<<<<<<< HEAD if pause_per_frame is not None: time.sleep(pause_per_frame) +======= + if sleep is not None: + time.sleep(sleep) +>>>>>>> Add torch DQN a, agent_info = agent.get_action(last_obs) if deterministic and 'mean' in agent_info: a = agent_info['mean'] diff --git a/src/garage/np/exploration_policies/epsilon_greedy_policy.py b/src/garage/np/exploration_policies/epsilon_greedy_policy.py index 3a93e44b2b..630adff9f3 100644 --- a/src/garage/np/exploration_policies/epsilon_greedy_policy.py +++ b/src/garage/np/exploration_policies/epsilon_greedy_policy.py @@ -48,6 +48,15 @@ def __init__(self, self._total_env_steps = 0 self._last_total_env_steps = 0 + @property + def epsilon(self): + """Float: the instantaneous level of exploration noise.""" + return self._epsilon + + @epsilon.setter + def epsilon(self, epsilon): + self._episilon = epsilon + def get_action(self, observation): """Get action from this policy for the input observation. diff --git a/tests/garage/torch/algos/test_dqn.py b/tests/garage/torch/algos/test_dqn.py index 200b34512a..1f40761bd1 100644 --- a/tests/garage/torch/algos/test_dqn.py +++ b/tests/garage/torch/algos/test_dqn.py @@ -26,7 +26,6 @@ def setup(): steps_per_epoch = 10 sampler_batch_size = 512 num_timesteps = 100 * steps_per_epoch * sampler_batch_size - env = GymEnv('CartPole-v0') replay_buffer = PathBuffer(capacity_in_transitions=int(1e6)) From 235a6779bf61826cdceac46c61c2b9b58e671bf3 Mon Sep 17 00:00:00 2001 From: Mishari Aliesa Date: Fri, 25 Sep 2020 17:48:29 -0700 Subject: [PATCH 3/5] Add Double DQN --- examples/torch/dqn_atari.py | 11 ++++++++++ examples/torch/watch_atari.py | 4 ---- src/garage/_functions.py | 13 ------------ .../epsilon_greedy_policy.py | 9 --------- src/garage/torch/algos/dqn.py | 20 +++++++++++++++---- 5 files changed, 27 insertions(+), 30 deletions(-) diff --git a/examples/torch/dqn_atari.py b/examples/torch/dqn_atari.py index 231cf80c94..c027860064 100755 --- a/examples/torch/dqn_atari.py +++ b/examples/torch/dqn_atari.py @@ -3,6 +3,8 @@ Here it creates a gym environment CartPole, and trains a DQN with 50k steps. """ +import math + import click import gym import numpy as np @@ -54,11 +56,13 @@ @click.option('--seed', default=24) @click.option('--n', type=int, default=psutil.cpu_count(logical=False)) @click.option('--buffer_size', type=int, default=None) +@click.option('--n_steps', type=float, default=None) @click.option('--max_episode_length', type=int, default=None) def main(env=None, seed=24, n=psutil.cpu_count(logical=False), buffer_size=None, + n_steps=None, max_episode_length=None): """Wrapper to setup the logging directory. @@ -73,6 +77,9 @@ def main(env=None, buffer_size (int): size of the replay buffer in transitions. If None, defaults to hyperparams['buffer_size']. This is used by the integration tests. + n_steps (float): Total number of environment steps to run for, not + not including evaluation. If this is not None, n_epochs will + be recalculated based on this value. max_episode_length (int): Max length of an episode. If None, defaults to the timelimit specific to the environment. Used by integration tests. 
@@ -81,6 +88,10 @@ def main(env=None, env += 'NoFrameskip-v4' logdir = 'data/local/experiment/' + env + if n_steps is not None: + hyperparams['n_epochs'] = math.ceil( + int(n_steps) / (hyperparams['steps_per_epoch'] * + hyperparams['sampler_batch_size'])) if buffer_size is not None: hyperparams['buffer_size'] = buffer_size diff --git a/examples/torch/watch_atari.py b/examples/torch/watch_atari.py index f208fe269f..b780d332c8 100755 --- a/examples/torch/watch_atari.py +++ b/examples/torch/watch_atari.py @@ -58,11 +58,7 @@ def watch_atari(saved_dir, env=None, num_episodes=10): episode_data = rollout(env, exploration_policy.policy, animated=True, -<<<<<<< HEAD pause_per_frame=0.02) -======= - sleep=0.02) ->>>>>>> Add torch DQN ep_rewards = np.append(ep_rewards, np.sum(episode_data['rewards'])) print('Average Reward {}'.format(np.mean(ep_rewards))) diff --git a/src/garage/_functions.py b/src/garage/_functions.py index 352cbdb20f..19eae59767 100644 --- a/src/garage/_functions.py +++ b/src/garage/_functions.py @@ -69,11 +69,7 @@ def rollout(env, *, max_episode_length=np.inf, animated=False, -<<<<<<< HEAD pause_per_frame=None, -======= - sleep=None, ->>>>>>> Add torch DQN deterministic=False): """Sample a single episode of the agent in the environment. @@ -83,11 +79,7 @@ def rollout(env, max_episode_length (int): If the episode reaches this many timesteps, it is truncated. animated (bool): If true, render the environment after each step. -<<<<<<< HEAD pause_per_frame (float): Time to sleep between steps. Only relevant if -======= - sleep (float): Time to sleep between steps. Only relevant if ->>>>>>> Add torch DQN animated == true. deterministic (bool): If true, use the mean action returned by the stochastic policy instead of sampling from the returned action @@ -122,13 +114,8 @@ def rollout(env, if animated: env.visualize() while episode_length < (max_episode_length or np.inf): -<<<<<<< HEAD if pause_per_frame is not None: time.sleep(pause_per_frame) -======= - if sleep is not None: - time.sleep(sleep) ->>>>>>> Add torch DQN a, agent_info = agent.get_action(last_obs) if deterministic and 'mean' in agent_info: a = agent_info['mean'] diff --git a/src/garage/np/exploration_policies/epsilon_greedy_policy.py b/src/garage/np/exploration_policies/epsilon_greedy_policy.py index 630adff9f3..3a93e44b2b 100644 --- a/src/garage/np/exploration_policies/epsilon_greedy_policy.py +++ b/src/garage/np/exploration_policies/epsilon_greedy_policy.py @@ -48,15 +48,6 @@ def __init__(self, self._total_env_steps = 0 self._last_total_env_steps = 0 - @property - def epsilon(self): - """Float: the instantaneous level of exploration noise.""" - return self._epsilon - - @epsilon.setter - def epsilon(self, epsilon): - self._episilon = epsilon - def get_action(self, observation): """Get action from this policy for the input observation. diff --git a/src/garage/torch/algos/dqn.py b/src/garage/torch/algos/dqn.py index f6c5f06716..5419c55218 100644 --- a/src/garage/torch/algos/dqn.py +++ b/src/garage/torch/algos/dqn.py @@ -32,6 +32,8 @@ class DQN(RLAlgorithm): n_train_steps (int): Training steps. eval_env (Environment): Evaluation environment. If None, a copy of the main environment is used for evaluation. + double_q (bool): Whether to use Double DQN. + See https://arxiv.org/abs/1509.06461. max_episode_length_eval (int or None): Maximum length of episodes used for off-policy evaluation. If `None`, defaults to `env_spec.max_episode_length`. 
@@ -67,6 +69,7 @@ def __init__( replay_buffer, exploration_policy=None, eval_env=None, + double_q=True, qf_optimizer=torch.optim.Adam, *, # Everything after this is numbers. steps_per_epoch=20, @@ -100,6 +103,7 @@ def __init__( self._steps_per_epoch = steps_per_epoch self._n_train_steps = n_train_steps self._buffer_batch_size = buffer_batch_size + self._double_q = double_q self._discount = discount self._reward_scale = reward_scale self.max_episode_length = env_spec.max_episode_length @@ -246,10 +250,18 @@ def _optimize_qf(self, timesteps): next_inputs = next_observations inputs = observations with torch.no_grad(): - # discrete, outputs Qs for all possible actions - target_qvals = self._target_qf(next_inputs) - best_qvals, _ = torch.max(target_qvals, 1) - best_qvals = best_qvals.unsqueeze(1) + if self._double_q: + # Use online qf to get optimal actions + selected_actions = torch.argmax(self._qf(next_inputs), axis=1) + # use target qf to get Q values for those actions + selected_actions = selected_actions.long().unsqueeze(1) + best_qvals = torch.gather(self._target_qf(next_inputs), + dim=1, + index=selected_actions) + else: + target_qvals = self._target_qf(next_inputs) + best_qvals, _ = torch.max(target_qvals, 1) + best_qvals = best_qvals.unsqueeze(1) rewards_clipped = rewards if self._clip_reward is not None: From a6c34382a53240d0e2f09d13b70b83801019d110 Mon Sep 17 00:00:00 2001 From: Mishari Aliesa Date: Fri, 9 Oct 2020 19:04:26 -0700 Subject: [PATCH 4/5] Add Dueling DQN --- examples/torch/dqn_atari.py | 6 +- .../torch/modules/discrete_cnn_module.py | 72 +++++++++++++++---- .../q_functions/discrete_cnn_q_function.py | 4 ++ .../torch/modules/test_discrete_cnn_module.py | 67 +++++++++++++++++ 4 files changed, 134 insertions(+), 15 deletions(-) diff --git a/examples/torch/dqn_atari.py b/examples/torch/dqn_atari.py index c027860064..bd9fc23f42 100755 --- a/examples/torch/dqn_atari.py +++ b/examples/torch/dqn_atari.py @@ -41,6 +41,8 @@ target_update_freq=2, buffer_batch_size=32, max_epsilon=1.0, + double=True, + dueling=True, min_epsilon=0.01, decay_ratio=0.1, buffer_size=int(1e4), @@ -104,7 +106,7 @@ def main(env=None, # pylint: disable=unused-argument -@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=30) +@wrap_experiment(snapshot_mode='gap_overwrite', snapshot_gap=50) def dqn_atari(ctxt=None, env=None, seed=24, @@ -162,6 +164,7 @@ def dqn_atari(ctxt=None, hidden_channels=hyperparams['hidden_channels'], kernel_sizes=hyperparams['kernel_sizes'], strides=hyperparams['strides'], + dueling=hyperparams['dueling'], hidden_w_init=( lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))), hidden_sizes=hyperparams['hidden_sizes'], @@ -183,6 +186,7 @@ def dqn_atari(ctxt=None, replay_buffer=replay_buffer, steps_per_epoch=steps_per_epoch, qf_lr=hyperparams['lr'], + double_q=hyperparams['double'], clip_gradient=hyperparams['clip_gradient'], discount=hyperparams['discount'], min_buffer_size=hyperparams['min_buffer_size'], diff --git a/src/garage/torch/modules/discrete_cnn_module.py b/src/garage/torch/modules/discrete_cnn_module.py index be29e5e8a0..885352d85a 100644 --- a/src/garage/torch/modules/discrete_cnn_module.py +++ b/src/garage/torch/modules/discrete_cnn_module.py @@ -31,6 +31,8 @@ class DiscreteCNNModule(nn.Module): hidden_sizes (list[int]): Output dimension of dense layer(s) for the MLP for mean. For example, (32, 32) means the MLP consists of two hidden layers, each with 32 hidden units. + dueling (bool): Whether to use a dueling architecture for the + fully-connected layer. 
mlp_hidden_nonlinearity (callable): Activation function for intermediate dense layer(s) in the MLP. It should return a torch.Tensor. Set it to None to maintain a linear activation. @@ -73,6 +75,7 @@ def __init__(self, hidden_channels, strides, hidden_sizes=(32, 32), + dueling=False, cnn_hidden_nonlinearity=nn.ReLU, mlp_hidden_nonlinearity=nn.ReLU, hidden_w_init=nn.init.xavier_uniform_, @@ -90,6 +93,8 @@ def __init__(self, super().__init__() + self._dueling = dueling + input_var = torch.zeros(input_shape) cnn_module = CNNModule(input_var=input_var, kernel_sizes=kernel_sizes, @@ -109,22 +114,54 @@ def __init__(self, with torch.no_grad(): cnn_out = cnn_module(input_var) flat_dim = torch.flatten(cnn_out, start_dim=1).shape[1] - mlp_module = MLPModule(flat_dim, - output_dim, - hidden_sizes, - hidden_nonlinearity=mlp_hidden_nonlinearity, - hidden_w_init=hidden_w_init, - hidden_b_init=hidden_b_init, - output_nonlinearity=output_nonlinearity, - output_w_init=output_w_init, - output_b_init=output_b_init, - layer_normalization=layer_normalization) - if mlp_hidden_nonlinearity is None: - self._module = nn.Sequential(cnn_module, nn.Flatten(), mlp_module) + if dueling: + self._val = MLPModule(flat_dim, + 1, + hidden_sizes, + hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) + self._act = MLPModule(flat_dim, + output_dim, + hidden_sizes, + hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) + if mlp_hidden_nonlinearity is None: + self._module = nn.Sequential(cnn_module, nn.Flatten()) + else: + self._module = nn.Sequential(cnn_module, + mlp_hidden_nonlinearity(), + nn.Flatten()) + else: - self._module = nn.Sequential(cnn_module, mlp_hidden_nonlinearity(), - nn.Flatten(), mlp_module) + mlp_module = MLPModule(flat_dim, + output_dim, + hidden_sizes, + hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) + + if mlp_hidden_nonlinearity is None: + self._module = nn.Sequential(cnn_module, nn.Flatten(), + mlp_module) + else: + self._module = nn.Sequential(cnn_module, + mlp_hidden_nonlinearity(), + nn.Flatten(), mlp_module) def forward(self, inputs): """Forward method. @@ -137,4 +174,11 @@ def forward(self, inputs): torch.Tensor: Output tensor of shape :math:`(N, output_dim)`. """ + if self._dueling: + out = self._module(inputs) + val = self._val(out) + act = self._act(out) + act = act - act.mean(1).unsqueeze(1) + return val + act + return self._module(inputs) diff --git a/src/garage/torch/q_functions/discrete_cnn_q_function.py b/src/garage/torch/q_functions/discrete_cnn_q_function.py index 4550ef52da..02a5987dba 100644 --- a/src/garage/torch/q_functions/discrete_cnn_q_function.py +++ b/src/garage/torch/q_functions/discrete_cnn_q_function.py @@ -27,6 +27,8 @@ class DiscreteCNNQFunction(DiscreteCNNModule): For example, (3, 32) means there are two convolutional layers. The filter for the first conv layer outputs 3 channels and the second one outputs 32 channels. 
+ dueling (bool): Whether to use a dueling architecture for the + fully-connected layer. hidden_sizes (list[int]): Output dimension of dense layer(s) for the MLP for mean. For example, (32, 32) means the MLP consists of two hidden layers, each with 32 hidden units. @@ -70,6 +72,7 @@ def __init__(self, kernel_sizes, hidden_channels, strides, + dueling=False, hidden_sizes=(32, 32), cnn_hidden_nonlinearity=torch.nn.ReLU, mlp_hidden_nonlinearity=torch.nn.ReLU, @@ -94,6 +97,7 @@ def __init__(self, kernel_sizes=kernel_sizes, strides=strides, hidden_sizes=hidden_sizes, + dueling=dueling, hidden_channels=hidden_channels, cnn_hidden_nonlinearity=cnn_hidden_nonlinearity, mlp_hidden_nonlinearity=mlp_hidden_nonlinearity, diff --git a/tests/garage/torch/modules/test_discrete_cnn_module.py b/tests/garage/torch/modules/test_discrete_cnn_module.py index 82353edaf6..ef902723fc 100644 --- a/tests/garage/torch/modules/test_discrete_cnn_module.py +++ b/tests/garage/torch/modules/test_discrete_cnn_module.py @@ -65,6 +65,73 @@ def test_output_values(output_dim, kernel_sizes, hidden_channels, strides, assert torch.all(torch.eq(output.detach(), module(obs).detach())) +@pytest.mark.parametrize( + 'output_dim, kernel_sizes, hidden_channels, strides, paddings', [ + (1, (1, ), (32, ), (1, ), (0, )), + (2, (3, ), (32, ), (1, ), (0, )), + (5, (3, ), (32, ), (2, ), (0, )), + (5, (5, ), (12, ), (1, ), (2, )), + (5, (1, 1), (32, 64), (1, 1), (0, 0)), + (10, (3, 3), (32, 64), (1, 1), (0, 0)), + (10, (3, 3), (32, 64), (2, 2), (0, 0)), + ]) +def test_dueling_output_values(output_dim, kernel_sizes, hidden_channels, + strides, paddings): + + batch_size = 64 + input_width = 32 + input_height = 32 + in_channel = 3 + input_shape = (batch_size, in_channel, input_height, input_width) + obs = torch.rand(input_shape) + + module = DiscreteCNNModule(input_shape=input_shape, + output_dim=output_dim, + hidden_channels=hidden_channels, + hidden_sizes=hidden_channels, + kernel_sizes=kernel_sizes, + strides=strides, + paddings=paddings, + padding_mode='zeros', + dueling=True, + hidden_w_init=nn.init.ones_, + output_w_init=nn.init.ones_, + is_image=False) + + cnn = CNNModule(input_var=obs, + hidden_channels=hidden_channels, + kernel_sizes=kernel_sizes, + strides=strides, + paddings=paddings, + padding_mode='zeros', + hidden_w_init=nn.init.ones_, + is_image=False) + flat_dim = torch.flatten(cnn(obs).detach(), start_dim=1).shape[1] + + mlp_adv = MLPModule( + flat_dim, + output_dim, + hidden_channels, + hidden_w_init=nn.init.ones_, + output_w_init=nn.init.ones_, + ) + + mlp_val = MLPModule( + flat_dim, + 1, + hidden_channels, + hidden_w_init=nn.init.ones_, + output_w_init=nn.init.ones_, + ) + + cnn_out = cnn(obs) + val = mlp_val(torch.flatten(cnn_out, start_dim=1)) + adv = mlp_adv(torch.flatten(cnn_out, start_dim=1)) + output = val + (adv - adv.mean(1).unsqueeze(1)) + + assert torch.all(torch.eq(output.detach(), module(obs).detach())) + + @pytest.mark.parametrize('output_dim, hidden_channels, kernel_sizes, strides', [(1, (32, ), (1, ), (1, ))]) def test_without_nonlinearity(output_dim, hidden_channels, kernel_sizes, From f40ee60e512453e1cb3611c5d701931bbad5e243 Mon Sep 17 00:00:00 2001 From: Mishari Aliesa Date: Wed, 21 Oct 2020 15:46:06 -0700 Subject: [PATCH 5/5] Add noisy MLP --- examples/torch/dqn_atari.py | 29 +-- src/garage/torch/algos/dqn.py | 3 + src/garage/torch/modules/__init__.py | 2 + .../torch/modules/discrete_cnn_module.py | 129 +++++++--- .../torch/modules/multi_headed_mlp_module.py | 4 +- 
src/garage/torch/modules/noisy_mlp_module.py | 221 ++++++++++++++++++ .../q_functions/discrete_cnn_q_function.py | 13 ++ tests/garage/torch/algos/test_dqn.py | 4 + .../torch/modules/test_noisy_mlp_module.py | 123 ++++++++++ 9 files changed, 482 insertions(+), 46 deletions(-) create mode 100644 src/garage/torch/modules/noisy_mlp_module.py create mode 100644 tests/garage/torch/modules/test_noisy_mlp_module.py diff --git a/examples/torch/dqn_atari.py b/examples/torch/dqn_atari.py index bd9fc23f42..490751db9c 100755 --- a/examples/torch/dqn_atari.py +++ b/examples/torch/dqn_atari.py @@ -31,7 +31,7 @@ from garage.torch.q_functions import DiscreteCNNQFunction from garage.trainer import Trainer -hyperparams = dict(n_epochs=500, +hyperparams = dict(n_epochs=1000, steps_per_epoch=20, sampler_batch_size=500, lr=1e-4, @@ -42,7 +42,9 @@ buffer_batch_size=32, max_epsilon=1.0, double=True, - dueling=True, + dueling=False, + noisy=True, + noisy_sigma=0.5, min_epsilon=0.01, decay_ratio=0.1, buffer_size=int(1e4), @@ -157,27 +159,28 @@ def dqn_atari(ctxt=None, qf = DiscreteCNNQFunction( env_spec=env.spec, -<<<<<<< HEAD -======= - minibatch_size=hyperparams['buffer_batch_size'], ->>>>>>> Add torch DQN hidden_channels=hyperparams['hidden_channels'], kernel_sizes=hyperparams['kernel_sizes'], strides=hyperparams['strides'], dueling=hyperparams['dueling'], + noisy=hyperparams['noisy'], + noisy_sigma=hyperparams['noisy_sigma'], hidden_w_init=( lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))), hidden_sizes=hyperparams['hidden_sizes'], is_image=True) policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf) - exploration_policy = EpsilonGreedyPolicy( - env_spec=env.spec, - policy=policy, - total_timesteps=num_timesteps, - max_epsilon=hyperparams['max_epsilon'], - min_epsilon=hyperparams['min_epsilon'], - decay_ratio=hyperparams['decay_ratio']) + + exploration_policy = policy + if not hyperparams['noisy']: + exploration_policy = EpsilonGreedyPolicy( + env_spec=env.spec, + policy=policy, + total_timesteps=num_timesteps, + max_epsilon=hyperparams['max_epsilon'], + min_epsilon=hyperparams['min_epsilon'], + decay_ratio=hyperparams['decay_ratio']) algo = DQN(env_spec=env.spec, policy=policy, diff --git a/src/garage/torch/algos/dqn.py b/src/garage/torch/algos/dqn.py index 5419c55218..a4c35e36b8 100644 --- a/src/garage/torch/algos/dqn.py +++ b/src/garage/torch/algos/dqn.py @@ -227,6 +227,9 @@ def _log_eval_results(self, epoch): tabular.record('QFunction/MaxY', np.max(self._epoch_ys)) tabular.record('QFunction/AverageAbsY', np.mean(np.abs(self._epoch_ys))) + # log noise levels if using a NoisyNet. + if hasattr(self._qf, 'log_noise'): + self._qf.log_noise('QFunction/Noisy-Sigma') def _optimize_qf(self, timesteps): """Perform algorithm optimizing. 
diff --git a/src/garage/torch/modules/__init__.py b/src/garage/torch/modules/__init__.py index 1e07d6b04a..69f689878f 100644 --- a/src/garage/torch/modules/__init__.py +++ b/src/garage/torch/modules/__init__.py @@ -10,6 +10,7 @@ from garage.torch.modules.gaussian_mlp_module import GaussianMLPModule from garage.torch.modules.mlp_module import MLPModule from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule +from garage.torch.modules.noisy_mlp_module import NoisyMLPModule # DiscreteCNNModule must go after MLPModule from garage.torch.modules.discrete_cnn_module import DiscreteCNNModule # yapf: enable @@ -20,6 +21,7 @@ 'DiscreteCNNModule', 'MLPModule', 'MultiHeadedMLPModule', + 'NoisyMLPModule', 'GaussianMLPModule', 'GaussianMLPIndependentStdModule', 'GaussianMLPTwoHeadedModule', diff --git a/src/garage/torch/modules/discrete_cnn_module.py b/src/garage/torch/modules/discrete_cnn_module.py index 885352d85a..37c5ba62e4 100644 --- a/src/garage/torch/modules/discrete_cnn_module.py +++ b/src/garage/torch/modules/discrete_cnn_module.py @@ -1,8 +1,9 @@ """Discrete CNN Q Function.""" +from dowel import tabular import torch from torch import nn -from garage.torch.modules import CNNModule, MLPModule +from garage.torch.modules import CNNModule, MLPModule, NoisyMLPModule # pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305 @@ -33,6 +34,13 @@ class DiscreteCNNModule(nn.Module): of two hidden layers, each with 32 hidden units. dueling (bool): Whether to use a dueling architecture for the fully-connected layer. + noisy (bool): Whether to use parameter noise for the fully-connected + layers. If True, hidden_w_init, hidden_b_init, output_w_init, and + output_b_init are ignored. + noisy_sigma (float): Level of scaling to apply to the parameter noise. + This is ignored if noisy is set to False. + std_noise (float): Standard deviation of the gaussian parameters noise. + This is ignored if noisy is set to False. mlp_hidden_nonlinearity (callable): Activation function for intermediate dense layer(s) in the MLP. It should return a torch.Tensor. Set it to None to maintain a linear activation. 
@@ -81,6 +89,9 @@ def __init__(self, hidden_w_init=nn.init.xavier_uniform_, hidden_b_init=nn.init.zeros_, paddings=0, + noisy=True, + noisy_sigma=0.5, + std_noise=1., padding_mode='zeros', max_pool=False, pool_shape=None, @@ -94,6 +105,8 @@ def __init__(self, super().__init__() self._dueling = dueling + self._noisy = noisy + self._noisy_layers = None input_var = torch.zeros(input_shape) cnn_module = CNNModule(input_var=input_var, @@ -116,26 +129,49 @@ def __init__(self, flat_dim = torch.flatten(cnn_out, start_dim=1).shape[1] if dueling: - self._val = MLPModule(flat_dim, - 1, - hidden_sizes, - hidden_nonlinearity=mlp_hidden_nonlinearity, - hidden_w_init=hidden_w_init, - hidden_b_init=hidden_b_init, - output_nonlinearity=output_nonlinearity, - output_w_init=output_w_init, - output_b_init=output_b_init, - layer_normalization=layer_normalization) - self._act = MLPModule(flat_dim, - output_dim, - hidden_sizes, - hidden_nonlinearity=mlp_hidden_nonlinearity, - hidden_w_init=hidden_w_init, - hidden_b_init=hidden_b_init, - output_nonlinearity=output_nonlinearity, - output_w_init=output_w_init, - output_b_init=output_b_init, - layer_normalization=layer_normalization) + if noisy: + self._val = NoisyMLPModule( + flat_dim, + 1, + hidden_sizes, + sigma_naught=noisy_sigma, + std_noise=std_noise, + hidden_nonlinearity=mlp_hidden_nonlinearity, + output_nonlinearity=output_nonlinearity) + self._act = NoisyMLPModule( + flat_dim, + output_dim, + hidden_sizes, + sigma_naught=noisy_sigma, + std_noise=std_noise, + hidden_nonlinearity=mlp_hidden_nonlinearity, + output_nonlinearity=output_nonlinearity) + self._noisy_layers = [self._val, self._act] + else: + self._val = MLPModule( + flat_dim, + 1, + hidden_sizes, + hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) + + self._act = MLPModule( + flat_dim, + output_dim, + hidden_sizes, + hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) + if mlp_hidden_nonlinearity is None: self._module = nn.Sequential(cnn_module, nn.Flatten()) else: @@ -144,16 +180,29 @@ def __init__(self, nn.Flatten()) else: - mlp_module = MLPModule(flat_dim, - output_dim, - hidden_sizes, - hidden_nonlinearity=mlp_hidden_nonlinearity, - hidden_w_init=hidden_w_init, - hidden_b_init=hidden_b_init, - output_nonlinearity=output_nonlinearity, - output_w_init=output_w_init, - output_b_init=output_b_init, - layer_normalization=layer_normalization) + mlp_module = None + if noisy: + mlp_module = NoisyMLPModule( + flat_dim, + output_dim, + hidden_sizes, + sigma_naught=noisy_sigma, + std_noise=std_noise, + hidden_nonlinearity=mlp_hidden_nonlinearity, + output_nonlinearity=output_nonlinearity) + self._noisy_layers = [mlp_module] + else: + mlp_module = MLPModule( + flat_dim, + output_dim, + hidden_sizes, + hidden_nonlinearity=mlp_hidden_nonlinearity, + hidden_w_init=hidden_w_init, + hidden_b_init=hidden_b_init, + output_nonlinearity=output_nonlinearity, + output_w_init=output_w_init, + output_b_init=output_b_init, + layer_normalization=layer_normalization) if mlp_hidden_nonlinearity is None: self._module = nn.Sequential(cnn_module, nn.Flatten(), @@ -182,3 +231,21 @@ def forward(self, inputs): return val + act 
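+        # Without dueling, self._module already ends in an MLP (or
+        # NoisyMLP) head that maps the flattened CNN features to one
+        # Q-value per action, so its output is returned directly.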
return self._module(inputs) + + def log_noise(self, key): + """Log sigma levels for noisy layers. + + Args: + key (str): Prefix to use for logging. + + """ + if self._noisy: + layer_num = 0 + for layer in self._noisy_layers: + for name, param in layer.named_parameters(): + if name.endswith('weight_sigma'): + layer_num += 1 + sigma_mean = float( + (param**2).mean().sqrt().data.cpu().numpy()) + tabular.record(key + '_layer_' + str(layer_num), + sigma_mean) diff --git a/src/garage/torch/modules/multi_headed_mlp_module.py b/src/garage/torch/modules/multi_headed_mlp_module.py index fcb4479744..7c625459bd 100644 --- a/src/garage/torch/modules/multi_headed_mlp_module.py +++ b/src/garage/torch/modules/multi_headed_mlp_module.py @@ -7,6 +7,8 @@ from garage.torch import NonLinearity +# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305 +# pylint: disable=abstract-method class MultiHeadedMLPModule(nn.Module): """MultiHeadedMLPModule Model. @@ -71,8 +73,6 @@ def __init__(self, output_nonlinearities = self._check_parameter_for_output_layer( 'output_nonlinearities', output_nonlinearities, n_heads) - self._layers = nn.ModuleList() - prev_size = input_dim for size in hidden_sizes: hidden_layers = nn.Sequential() diff --git a/src/garage/torch/modules/noisy_mlp_module.py b/src/garage/torch/modules/noisy_mlp_module.py new file mode 100644 index 0000000000..f6c1d4fca5 --- /dev/null +++ b/src/garage/torch/modules/noisy_mlp_module.py @@ -0,0 +1,221 @@ +"""Noisy MLP Module.""" + +import math + +import torch +from torch.autograd import Variable +import torch.nn as nn +import torch.nn.functional as F + +from garage.torch import NonLinearity + + +# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305 +# pylint: disable=abstract-method +class NoisyMLPModule(nn.Module): + """MLP with factorised gaussian parameter noise. + + See https://arxiv.org/pdf/1706.10295.pdf. + + This module creates a multilayered perceptron (MLP) in where the + linear layers are replaced with :class:`~NoisyLinear` layers. + See the docstring of :class:`~NoisyLinear` and the linked paper + for more details. + + Args: + input_dim (int) : Dimension of the network input. + output_dim (int): Dimension of the network output. + hidden_sizes (list[int]): Output dimension of dense layer(s). + For example, (32, 32) means this MLP consists of two + hidden layers, each with 32 hidden units. + hidden_nonlinearity (callable or torch.nn.Module): Activation function + for intermediate dense layer(s). It should return a torch.Tensor. + Set it to None to maintain a linear activation. + sigma_naught (float): Hyperparameter that specifies the intial noise + scaling factor. See the paper for details. + std_noise (float): Standard deviation of the gaussian noise + distribution to sample from. + output_nonlinearity (callable or torch.nn.Module): Activation function + for output dense layer. It should return a torch.Tensor. + Set it to None to maintain a linear activation. 
+    """
+
+    def __init__(self,
+                 input_dim,
+                 output_dim,
+                 hidden_sizes,
+                 hidden_nonlinearity=F.relu,
+                 sigma_naught=0.5,
+                 std_noise=1.,
+                 output_nonlinearity=None):
+        super().__init__()
+        self._layers = nn.ModuleList()
+        self._noisy_layers = []
+
+        prev_size = input_dim
+        for size in hidden_sizes:
+            hidden_layers = nn.Sequential()
+            linear_layer = NoisyLinear(prev_size, size, sigma_naught,
+                                       std_noise)
+            self._noisy_layers.append(linear_layer)
+            hidden_layers.add_module('linear', linear_layer)
+
+            if hidden_nonlinearity:
+                hidden_layers.add_module('non_linearity',
+                                         NonLinearity(hidden_nonlinearity))
+
+            self._layers.append(hidden_layers)
+            prev_size = size
+
+        self._output_layers = nn.ModuleList()
+        output_layer = nn.Sequential()
+        linear_layer = NoisyLinear(prev_size, output_dim, sigma_naught,
+                                   std_noise)
+        self._noisy_layers.append(linear_layer)
+        output_layer.add_module('linear', linear_layer)
+
+        if output_nonlinearity:
+            output_layer.add_module('non_linearity',
+                                    NonLinearity(output_nonlinearity))
+
+        self._output_layers.append(output_layer)
+
+    def forward(self, input_val):
+        """Forward method.
+
+        Args:
+            input_val (torch.Tensor): Input values with (N, *, input_dim)
+                shape.
+
+        Returns:
+            torch.Tensor: Output value
+
+        """
+        x = input_val
+        for layer in self._layers:
+            x = layer(x)
+
+        return self._output_layers[0](x)
+
+    def set_deterministic(self, deterministic):
+        """Set whether or not noise is applied.
+
+        This is useful when deterministic evaluation of
+        a policy is desired. Determinism is off by default.
+
+        Args:
+            deterministic (bool): If False, noise is applied, else
+                it is not.
+        """
+        for layer in self._noisy_layers:
+            layer.set_deterministic(deterministic)
+
+
+class NoisyLinear(nn.Module):
+    r"""Noisy linear layer with Factorised Gaussian noise.
+
+    See https://arxiv.org/pdf/1706.10295.pdf.
+
+    Each NoisyLinear layer applies the following transformation
+
+    :math:`y = (\mu^w + \sigma^w \odot \epsilon^w) x + \mu^b +
+        \sigma^b \odot \epsilon^b`
+
+    where :math:`\mu^w, \mu^b, \sigma^w, and \sigma^b` are learned parameters
+    and :math:`\epsilon^w, \epsilon^b` are zero-mean gaussian noise samples.
+
+    Args:
+        input_dim (int) : Dimension of the network input.
+        output_dim (int): Dimension of the network output.
+        sigma_naught (float): Hyperparameter that specifies the initial noise
+            scaling factor. See the paper for details.
+        std_noise (float): Standard deviation of the gaussian noise
+            distribution to sample from.
+    """
+
+    def __init__(self, input_dim, output_dim, sigma_naught=0.5, std_noise=1.):
+        super().__init__()
+        self._input_dim = input_dim
+        self._output_dim = output_dim
+        self._sigma_naught = sigma_naught
+        self._std_noise = std_noise
+        self._deterministic = False
+
+        self._output_dim = output_dim
+
+        self._weight_mu = nn.Parameter(torch.FloatTensor(
+            output_dim, input_dim))
+        self._weight_sigma = nn.Parameter(
+            torch.FloatTensor(output_dim, input_dim))
+
+        self._bias_mu = nn.Parameter(torch.FloatTensor(output_dim))
+        self._bias_sigma = nn.Parameter(torch.FloatTensor(output_dim))
+
+        # epsilon noise
+        self.register_buffer('weight_epsilon',
+                             torch.FloatTensor(output_dim, input_dim))
+        self.register_buffer('bias_epsilon', torch.FloatTensor(output_dim))
+
+        self.reset_parameters()
+
+    def set_deterministic(self, deterministic):
+        """Set whether or not noise is applied.
+
+        This is useful when deterministic evaluation of
+        a policy is desired. Determinism is off by default.
+
+        Args:
+            deterministic (bool): If False, noise is applied, else
+                it is not.
+ """ + self._deterministic = deterministic + + def forward(self, input_value): + """Forward method. + + Args: + input_value (torch.Tensor): Input values with (N, *, input_dim) + shape. + + Returns: + torch.Tensor: Output value + """ + if self._deterministic: + return F.linear(input_value, self._weight_mu, self._bias_mu) + + self._sample_noise() + w = self._weight_mu + self._weight_sigma.mul( + Variable(self.weight_epsilon)) + b = self._bias_mu + self._bias_sigma.mul(Variable(self.bias_epsilon)) + return F.linear(input_value, w, b) + + def reset_parameters(self): + """Reset all learnable parameters.""" + mu_range = 1 / math.sqrt(self._weight_mu.size(1)) + + self._weight_mu.data.uniform_(-mu_range, mu_range) + self._weight_sigma.data.fill_(self._sigma_naught / + math.sqrt(self._weight_sigma.size(1))) + + self._bias_mu.data.uniform_(-mu_range, mu_range) + self._bias_sigma.data.fill_(self._sigma_naught / + math.sqrt(self._bias_sigma.size(0))) + + def _sample_noise(self): + r"""Sample and assign new values for :math:`\epsilon`.""" + in_noise = self._get_noise(self._input_dim) + out_noise = self._get_noise(self._output_dim) + self.weight_epsilon.copy_(out_noise.ger(in_noise)) + self.bias_epsilon.copy_(self._get_noise(self._output_dim)) + + def _get_noise(self, size): + """Retrieve scaled zero-mean gaussian noise. + + Args: + size (int): size of the noise vector. + + Returns: + torch.Tensor: noise vector of the specified size. + """ + x = torch.normal(torch.zeros(size), self._std_noise * torch.ones(size)) + return x.sign().mul(x.abs().sqrt()) diff --git a/src/garage/torch/q_functions/discrete_cnn_q_function.py b/src/garage/torch/q_functions/discrete_cnn_q_function.py index 02a5987dba..47d54f5967 100644 --- a/src/garage/torch/q_functions/discrete_cnn_q_function.py +++ b/src/garage/torch/q_functions/discrete_cnn_q_function.py @@ -29,6 +29,13 @@ class DiscreteCNNQFunction(DiscreteCNNModule): and the second one outputs 32 channels. dueling (bool): Whether to use a dueling architecture for the fully-connected layer. + noisy (bool): Whether to use parameter noise for the fully-connected + layers. If True, hidden_w_init, hidden_b_init, output_w_init, and + output_b_init are ignored. + noisy_sigma (float): Level of scaling to apply to the parameter noise. + This is ignored if noisy is set to False. + std_noise (float): Standard deviation of the gaussian parameters noise. + This is ignored if noisy is set to False. hidden_sizes (list[int]): Output dimension of dense layer(s) for the MLP for mean. For example, (32, 32) means the MLP consists of two hidden layers, each with 32 hidden units. 
@@ -73,6 +80,9 @@ def __init__(self, hidden_channels, strides, dueling=False, + noisy=False, + noisy_sigma=0.5, + std_noise=1., hidden_sizes=(32, 32), cnn_hidden_nonlinearity=torch.nn.ReLU, mlp_hidden_nonlinearity=torch.nn.ReLU, @@ -98,6 +108,9 @@ def __init__(self, strides=strides, hidden_sizes=hidden_sizes, dueling=dueling, + noisy=noisy, + noisy_sigma=noisy_sigma, + std_noise=std_noise, hidden_channels=hidden_channels, cnn_hidden_nonlinearity=cnn_hidden_nonlinearity, mlp_hidden_nonlinearity=mlp_hidden_nonlinearity, diff --git a/tests/garage/torch/algos/test_dqn.py b/tests/garage/torch/algos/test_dqn.py index 1f40761bd1..ea3c157d7a 100644 --- a/tests/garage/torch/algos/test_dqn.py +++ b/tests/garage/torch/algos/test_dqn.py @@ -14,6 +14,9 @@ from garage.replay_buffer import PathBuffer from garage.sampler import LocalSampler from garage.torch import np_to_torch +from garage.torch.algos import DQN +from garage.torch.policies import DiscreteQFArgmaxPolicy +from garage.torch.q_functions import DiscreteMLPQFunction from garage.trainer import Trainer from tests.fixtures import snapshot_config @@ -26,6 +29,7 @@ def setup(): steps_per_epoch = 10 sampler_batch_size = 512 num_timesteps = 100 * steps_per_epoch * sampler_batch_size + env = GymEnv('CartPole-v0') replay_buffer = PathBuffer(capacity_in_transitions=int(1e6)) diff --git a/tests/garage/torch/modules/test_noisy_mlp_module.py b/tests/garage/torch/modules/test_noisy_mlp_module.py new file mode 100644 index 0000000000..75eaa94eee --- /dev/null +++ b/tests/garage/torch/modules/test_noisy_mlp_module.py @@ -0,0 +1,123 @@ +"""Test NoisyMLPModule.""" +import math + +import numpy as np +import pytest +import torch +from torch.autograd import Variable +from torch.nn import functional as F + +from garage.torch.modules import NoisyMLPModule +from garage.torch.modules.noisy_mlp_module import NoisyLinear + + +@pytest.mark.parametrize('input_dim, output_dim, sigma_naught, hidden_sizes', + [(1, 1, 0.1, (32, 32)), (2, 2, 0.5, (32, 64)), + (2, 3, 1., (5, 5, 5))]) +def test_forward(input_dim, output_dim, sigma_naught, hidden_sizes): + noisy_mlp = NoisyMLPModule(input_dim, + output_dim, + hidden_nonlinearity=None, + sigma_naught=sigma_naught, + hidden_sizes=hidden_sizes) + + # mock the noise + for layer in noisy_mlp._noisy_layers: + layer._get_noise = lambda x: torch.Tensor(np.ones(x)) + + input_val = torch.Tensor(np.ones(input_dim, dtype=np.float32)) + x = input_val + for layer in noisy_mlp._noisy_layers: + x = layer(x) + + out = noisy_mlp.forward(input_val) + assert (x == out).all() + + +@pytest.mark.parametrize('input_dim, output_dim, sigma_naught, hidden_sizes', + [(1, 1, 0.1, (32, 32)), (2, 2, 0.5, (32, 64)), + (2, 3, 1., (5, 5, 5))]) +def test_forward_deterministic(input_dim, output_dim, sigma_naught, + hidden_sizes): + noisy_mlp = NoisyMLPModule(input_dim, + output_dim, + hidden_nonlinearity=None, + sigma_naught=sigma_naught, + hidden_sizes=hidden_sizes) + noisy_mlp.set_deterministic(True) + input_val = torch.Tensor(np.ones(input_dim, dtype=np.float32)) + x = input_val + for layer in noisy_mlp._noisy_layers: + x = layer(x) + + out = noisy_mlp.forward(input_val) + assert (x == out).all() + + +@pytest.mark.parametrize('input_dim, output_dim, sigma_naught', [(1, 1, 0.1), + (2, 2, 0.5), + (2, 3, 1.)]) +def test_noisy_linear_reset_parameters(input_dim, output_dim, sigma_naught): + layer = NoisyLinear(input_dim, output_dim, sigma_naught=0) + + mu_range = 1 / math.sqrt(input_dim) + assert (layer._weight_sigma == 0.).all() + assert (layer._bias_sigma == 
0.).all() + + layer._sigma_naught = sigma_naught + layer.reset_parameters() + + bias_sig = sigma_naught / math.sqrt(output_dim) + weight_sig = sigma_naught / math.sqrt(input_dim) + + # sigma + assert (layer._weight_sigma == weight_sig).all() + assert (layer._bias_sigma == bias_sig).all() + + # mu + assert ((layer._bias_mu <= mu_range).all() + and (layer._bias_mu >= -mu_range).all()) + + assert ((layer._weight_mu <= mu_range).all() + and (layer._weight_mu >= -mu_range).all()) + + +@pytest.mark.parametrize('input_dim, output_dim, sigma_naught', [(1, 1, 0.1), + (2, 2, 0.5), + (2, 3, 1.)]) +def test_noisy_linear_forward(input_dim, output_dim, sigma_naught): + layer = NoisyLinear(input_dim, output_dim, sigma_naught=sigma_naught) + + input_val = torch.Tensor(np.ones(input_dim, dtype=np.float32)) + val = layer.forward(input_val).detach() + w = layer._weight_mu + \ + layer._weight_sigma.mul(Variable(layer.weight_epsilon)) + b = layer._bias_mu + layer._bias_sigma.mul(Variable(layer.bias_epsilon)) + + expected = F.linear(input_val, w, b).detach() + + assert (val == expected).all() + + # test deterministic mode + + layer.set_deterministic(True) + val = layer.forward(input_val) + expected = F.linear(input_val, layer._weight_mu, layer._bias_mu) + + assert (val == expected).all() + + +@pytest.mark.parametrize('input_dim, output_dim', [(1, 1), (2, 2), (2, 3)]) +def test_sample_noise(input_dim, output_dim): + layer = NoisyLinear(input_dim, output_dim) + + # mock the noise + layer._get_noise = lambda x: torch.Tensor(np.ones(x)) + + out_noise = layer._get_noise(output_dim) + in_noise = layer._get_noise(input_dim) + layer._sample_noise() + + expected = out_noise.ger(in_noise).detach() + assert (layer.weight_epsilon == expected).all() + assert (layer.bias_epsilon == layer._get_noise(output_dim)).all()
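
Usage sketch: the snippet below exercises the NoisyMLPModule added in PATCH 5/5
on its own, using only the constructor arguments and methods defined in
src/garage/torch/modules/noisy_mlp_module.py. The feature size, action count,
and batch size are illustrative values, not taken from the patches.

import torch

from garage.torch.modules import NoisyMLPModule

# A small noisy Q-value head: 512-dim features in, 6 action-values out.
noisy_head = NoisyMLPModule(input_dim=512,
                            output_dim=6,
                            hidden_sizes=(512, ),
                            sigma_naught=0.5,
                            std_noise=1.)

features = torch.rand(32, 512)

# Each stochastic forward pass draws fresh factorised noise, so two passes
# over the same input generally differ. This is the exploration mechanism
# that lets dqn_atari.py skip EpsilonGreedyPolicy when hyperparams['noisy']
# is True.
q_a = noisy_head(features)
q_b = noisy_head(features)
print(torch.allclose(q_a, q_b))  # almost always False

# In deterministic mode only the mu parameters are used, so repeated passes
# agree exactly; useful for evaluation rollouts.
noisy_head.set_deterministic(True)
assert torch.equal(noisy_head(features), noisy_head(features))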