diff --git a/examples/torch/dqn_atari.py b/examples/torch/dqn_atari.py
index bd9fc23f42..490751db9c 100755
--- a/examples/torch/dqn_atari.py
+++ b/examples/torch/dqn_atari.py
@@ -31,7 +31,7 @@
 from garage.torch.q_functions import DiscreteCNNQFunction
 from garage.trainer import Trainer
 
-hyperparams = dict(n_epochs=500,
+hyperparams = dict(n_epochs=1000,
                    steps_per_epoch=20,
                    sampler_batch_size=500,
                    lr=1e-4,
@@ -42,7 +42,9 @@
                    buffer_batch_size=32,
                    max_epsilon=1.0,
                    double=True,
-                   dueling=True,
+                   dueling=False,
+                   noisy=True,
+                   noisy_sigma=0.5,
                    min_epsilon=0.01,
                    decay_ratio=0.1,
                    buffer_size=int(1e4),
@@ -157,27 +159,28 @@ def dqn_atari(ctxt=None,
     qf = DiscreteCNNQFunction(
         env_spec=env.spec,
-<<<<<<< HEAD
-=======
-        minibatch_size=hyperparams['buffer_batch_size'],
->>>>>>> Add torch DQN
         hidden_channels=hyperparams['hidden_channels'],
         kernel_sizes=hyperparams['kernel_sizes'],
         strides=hyperparams['strides'],
         dueling=hyperparams['dueling'],
+        noisy=hyperparams['noisy'],
+        noisy_sigma=hyperparams['noisy_sigma'],
        hidden_w_init=(
             lambda x: torch.nn.init.orthogonal_(x, gain=np.sqrt(2))),
         hidden_sizes=hyperparams['hidden_sizes'],
         is_image=True)
 
     policy = DiscreteQFArgmaxPolicy(env_spec=env.spec, qf=qf)
-    exploration_policy = EpsilonGreedyPolicy(
-        env_spec=env.spec,
-        policy=policy,
-        total_timesteps=num_timesteps,
-        max_epsilon=hyperparams['max_epsilon'],
-        min_epsilon=hyperparams['min_epsilon'],
-        decay_ratio=hyperparams['decay_ratio'])
+
+    exploration_policy = policy
+    if not hyperparams['noisy']:
+        exploration_policy = EpsilonGreedyPolicy(
+            env_spec=env.spec,
+            policy=policy,
+            total_timesteps=num_timesteps,
+            max_epsilon=hyperparams['max_epsilon'],
+            min_epsilon=hyperparams['min_epsilon'],
+            decay_ratio=hyperparams['decay_ratio'])
 
     algo = DQN(env_spec=env.spec,
                policy=policy,
diff --git a/src/garage/torch/algos/dqn.py b/src/garage/torch/algos/dqn.py
index 5419c55218..a4c35e36b8 100644
--- a/src/garage/torch/algos/dqn.py
+++ b/src/garage/torch/algos/dqn.py
@@ -227,6 +227,9 @@ def _log_eval_results(self, epoch):
         tabular.record('QFunction/MaxY', np.max(self._epoch_ys))
         tabular.record('QFunction/AverageAbsY',
                        np.mean(np.abs(self._epoch_ys)))
+        # log noise levels if using a NoisyNet.
+        if hasattr(self._qf, 'log_noise'):
+            self._qf.log_noise('QFunction/Noisy-Sigma')
 
     def _optimize_qf(self, timesteps):
         """Perform algorithm optimizing.
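With noisy=True the example no longer wraps the policy in EpsilonGreedyPolicy, so all exploration comes from the parameter noise inside the Q-function. As a minimal sketch (illustrative only, not introduced by this patch; it relies solely on the _noisy_layers attribute and NoisyMLPModule.set_deterministic() added below), that noise can be paused for a deterministic evaluation pass and then restored:

def set_qf_noise(qf, enabled):
    # Sketch only: qf is assumed to be a DiscreteCNNQFunction built with
    # noisy=True, so its fully-connected heads are NoisyMLPModule instances
    # collected in qf._noisy_layers by DiscreteCNNModule.__init__.
    for head in (qf._noisy_layers or []):
        # set_deterministic(True) makes forward() use only the mean weights,
        # i.e. no parameter noise is sampled.
        head.set_deterministic(not enabled)

set_qf_noise(qf, enabled=False)  # evaluate greedily, without parameter noise
# ... run evaluation episodes here ...
set_qf_noise(qf, enabled=True)   # resume noisy exploration for training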
diff --git a/src/garage/torch/modules/__init__.py b/src/garage/torch/modules/__init__.py
index 1e07d6b04a..69f689878f 100644
--- a/src/garage/torch/modules/__init__.py
+++ b/src/garage/torch/modules/__init__.py
@@ -10,6 +10,7 @@
 from garage.torch.modules.gaussian_mlp_module import GaussianMLPModule
 from garage.torch.modules.mlp_module import MLPModule
 from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule
+from garage.torch.modules.noisy_mlp_module import NoisyMLPModule
 # DiscreteCNNModule must go after MLPModule
 from garage.torch.modules.discrete_cnn_module import DiscreteCNNModule
 # yapf: enable
@@ -20,6 +21,7 @@
     'DiscreteCNNModule',
     'MLPModule',
     'MultiHeadedMLPModule',
+    'NoisyMLPModule',
     'GaussianMLPModule',
     'GaussianMLPIndependentStdModule',
     'GaussianMLPTwoHeadedModule',
diff --git a/src/garage/torch/modules/discrete_cnn_module.py b/src/garage/torch/modules/discrete_cnn_module.py
index 885352d85a..37c5ba62e4 100644
--- a/src/garage/torch/modules/discrete_cnn_module.py
+++ b/src/garage/torch/modules/discrete_cnn_module.py
@@ -1,8 +1,9 @@
 """Discrete CNN Q Function."""
+from dowel import tabular
 import torch
 from torch import nn
 
-from garage.torch.modules import CNNModule, MLPModule
+from garage.torch.modules import CNNModule, MLPModule, NoisyMLPModule
 
 
 # pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305
@@ -33,6 +34,13 @@ class DiscreteCNNModule(nn.Module):
             of two hidden layers, each with 32 hidden units.
         dueling (bool): Whether to use a dueling architecture for the
             fully-connected layer.
+        noisy (bool): Whether to use parameter noise for the fully-connected
+            layers. If True, hidden_w_init, hidden_b_init, output_w_init, and
+            output_b_init are ignored.
+        noisy_sigma (float): Level of scaling to apply to the parameter noise.
+            This is ignored if noisy is set to False.
+        std_noise (float): Standard deviation of the Gaussian parameter noise.
+            This is ignored if noisy is set to False.
         mlp_hidden_nonlinearity (callable): Activation function for
             intermediate dense layer(s) in the MLP. It should return a
             torch.Tensor. Set it to None to maintain a linear activation.
@@ -81,6 +89,9 @@ def __init__(self,
                  hidden_w_init=nn.init.xavier_uniform_,
                  hidden_b_init=nn.init.zeros_,
                  paddings=0,
+                 noisy=True,
+                 noisy_sigma=0.5,
+                 std_noise=1.,
                  padding_mode='zeros',
                  max_pool=False,
                  pool_shape=None,
@@ -94,6 +105,8 @@ def __init__(self,
         super().__init__()
 
         self._dueling = dueling
+        self._noisy = noisy
+        self._noisy_layers = None
 
         input_var = torch.zeros(input_shape)
         cnn_module = CNNModule(input_var=input_var,
@@ -116,26 +129,49 @@ def __init__(self,
             flat_dim = torch.flatten(cnn_out, start_dim=1).shape[1]
 
         if dueling:
-            self._val = MLPModule(flat_dim,
-                                  1,
-                                  hidden_sizes,
-                                  hidden_nonlinearity=mlp_hidden_nonlinearity,
-                                  hidden_w_init=hidden_w_init,
-                                  hidden_b_init=hidden_b_init,
-                                  output_nonlinearity=output_nonlinearity,
-                                  output_w_init=output_w_init,
-                                  output_b_init=output_b_init,
-                                  layer_normalization=layer_normalization)
-            self._act = MLPModule(flat_dim,
-                                  output_dim,
-                                  hidden_sizes,
-                                  hidden_nonlinearity=mlp_hidden_nonlinearity,
-                                  hidden_w_init=hidden_w_init,
-                                  hidden_b_init=hidden_b_init,
-                                  output_nonlinearity=output_nonlinearity,
-                                  output_w_init=output_w_init,
-                                  output_b_init=output_b_init,
-                                  layer_normalization=layer_normalization)
+            if noisy:
+                self._val = NoisyMLPModule(
+                    flat_dim,
+                    1,
+                    hidden_sizes,
+                    sigma_naught=noisy_sigma,
+                    std_noise=std_noise,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    output_nonlinearity=output_nonlinearity)
+                self._act = NoisyMLPModule(
+                    flat_dim,
+                    output_dim,
+                    hidden_sizes,
+                    sigma_naught=noisy_sigma,
+                    std_noise=std_noise,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    output_nonlinearity=output_nonlinearity)
+                self._noisy_layers = [self._val, self._act]
+            else:
+                self._val = MLPModule(
+                    flat_dim,
+                    1,
+                    hidden_sizes,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    hidden_w_init=hidden_w_init,
+                    hidden_b_init=hidden_b_init,
+                    output_nonlinearity=output_nonlinearity,
+                    output_w_init=output_w_init,
+                    output_b_init=output_b_init,
+                    layer_normalization=layer_normalization)
+
+                self._act = MLPModule(
+                    flat_dim,
+                    output_dim,
+                    hidden_sizes,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    hidden_w_init=hidden_w_init,
+                    hidden_b_init=hidden_b_init,
+                    output_nonlinearity=output_nonlinearity,
+                    output_w_init=output_w_init,
+                    output_b_init=output_b_init,
+                    layer_normalization=layer_normalization)
+
             if mlp_hidden_nonlinearity is None:
                 self._module = nn.Sequential(cnn_module, nn.Flatten())
             else:
@@ -144,16 +180,29 @@ def __init__(self,
                                              nn.Flatten())
 
         else:
-            mlp_module = MLPModule(flat_dim,
-                                   output_dim,
-                                   hidden_sizes,
-                                   hidden_nonlinearity=mlp_hidden_nonlinearity,
-                                   hidden_w_init=hidden_w_init,
-                                   hidden_b_init=hidden_b_init,
-                                   output_nonlinearity=output_nonlinearity,
-                                   output_w_init=output_w_init,
-                                   output_b_init=output_b_init,
-                                   layer_normalization=layer_normalization)
+            mlp_module = None
+            if noisy:
+                mlp_module = NoisyMLPModule(
+                    flat_dim,
+                    output_dim,
+                    hidden_sizes,
+                    sigma_naught=noisy_sigma,
+                    std_noise=std_noise,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    output_nonlinearity=output_nonlinearity)
+                self._noisy_layers = [mlp_module]
+            else:
+                mlp_module = MLPModule(
+                    flat_dim,
+                    output_dim,
+                    hidden_sizes,
+                    hidden_nonlinearity=mlp_hidden_nonlinearity,
+                    hidden_w_init=hidden_w_init,
+                    hidden_b_init=hidden_b_init,
+                    output_nonlinearity=output_nonlinearity,
+                    output_w_init=output_w_init,
+                    output_b_init=output_b_init,
+                    layer_normalization=layer_normalization)
 
             if mlp_hidden_nonlinearity is None:
                 self._module = nn.Sequential(cnn_module, nn.Flatten(),
@@ -182,3 +231,21 @@ def forward(self, inputs):
             return val + act
 
         return self._module(inputs)
+
+    def log_noise(self, key):
+        """Log sigma levels for noisy layers.
+
+        Args:
+            key (str): Prefix to use for logging.
+
+        """
+        if self._noisy:
+            layer_num = 0
+            for layer in self._noisy_layers:
+                for name, param in layer.named_parameters():
+                    if name.endswith('weight_sigma'):
+                        layer_num += 1
+                        sigma_mean = float(
+                            (param**2).mean().sqrt().data.cpu().numpy())
+                        tabular.record(key + '_layer_' + str(layer_num),
+                                       sigma_mean)
diff --git a/src/garage/torch/modules/multi_headed_mlp_module.py b/src/garage/torch/modules/multi_headed_mlp_module.py
index fcb4479744..7c625459bd 100644
--- a/src/garage/torch/modules/multi_headed_mlp_module.py
+++ b/src/garage/torch/modules/multi_headed_mlp_module.py
@@ -7,6 +7,8 @@
 from garage.torch import NonLinearity
 
 
+# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305
+# pylint: disable=abstract-method
 class MultiHeadedMLPModule(nn.Module):
     """MultiHeadedMLPModule Model.
 
@@ -71,8 +73,6 @@ def __init__(self,
         output_nonlinearities = self._check_parameter_for_output_layer(
             'output_nonlinearities', output_nonlinearities, n_heads)
 
-        self._layers = nn.ModuleList()
-
         prev_size = input_dim
         for size in hidden_sizes:
             hidden_layers = nn.Sequential()
diff --git a/src/garage/torch/modules/noisy_mlp_module.py b/src/garage/torch/modules/noisy_mlp_module.py
new file mode 100644
index 0000000000..f6c1d4fca5
--- /dev/null
+++ b/src/garage/torch/modules/noisy_mlp_module.py
@@ -0,0 +1,221 @@
+"""Noisy MLP Module."""
+
+import math
+
+import torch
+from torch.autograd import Variable
+import torch.nn as nn
+import torch.nn.functional as F
+
+from garage.torch import NonLinearity
+
+
+# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305
+# pylint: disable=abstract-method
+class NoisyMLPModule(nn.Module):
+    """MLP with factorised Gaussian parameter noise.
+
+    See https://arxiv.org/pdf/1706.10295.pdf.
+
+    This module creates a multilayered perceptron (MLP) in which the
+    linear layers are replaced with :class:`~NoisyLinear` layers.
+    See the docstring of :class:`~NoisyLinear` and the linked paper
+    for more details.
+
+    Args:
+        input_dim (int): Dimension of the network input.
+        output_dim (int): Dimension of the network output.
+        hidden_sizes (list[int]): Output dimension of dense layer(s).
+            For example, (32, 32) means this MLP consists of two
+            hidden layers, each with 32 hidden units.
+        hidden_nonlinearity (callable or torch.nn.Module): Activation function
+            for intermediate dense layer(s). It should return a torch.Tensor.
+            Set it to None to maintain a linear activation.
+        sigma_naught (float): Hyperparameter that specifies the initial noise
+            scaling factor. See the paper for details.
+        std_noise (float): Standard deviation of the Gaussian noise
+            distribution to sample from.
+        output_nonlinearity (callable or torch.nn.Module): Activation function
+            for output dense layer. It should return a torch.Tensor.
+            Set it to None to maintain a linear activation.
+ """ + + def __init__(self, + input_dim, + output_dim, + hidden_sizes, + hidden_nonlinearity=F.relu, + sigma_naught=0.5, + std_noise=1., + output_nonlinearity=None): + super().__init__() + self._layers = nn.ModuleList() + self._noisy_layers = [] + + prev_size = input_dim + for size in hidden_sizes: + hidden_layers = nn.Sequential() + linear_layer = NoisyLinear(prev_size, size, sigma_naught, + std_noise) + self._noisy_layers.append(linear_layer) + hidden_layers.add_module('linear', linear_layer) + + if hidden_nonlinearity: + hidden_layers.add_module('non_linearity', + NonLinearity(hidden_nonlinearity)) + + self._layers.append(hidden_layers) + prev_size = size + + self._output_layers = nn.ModuleList() + output_layer = nn.Sequential() + linear_layer = NoisyLinear(prev_size, output_dim, sigma_naught, + std_noise) + self._noisy_layers.append(linear_layer) + output_layer.add_module('linear', linear_layer) + + if output_nonlinearity: + output_layer.add_module('non_linearity', + NonLinearity(output_nonlinearity)) + + self._output_layers.append(output_layer) + + def forward(self, input_val): + """Forward method. + + Args: + input_val (torch.Tensor): Input values with (N, *, input_dim) + shape. + + Returns: + List[torch.Tensor]: Output values + + """ + x = input_val + for layer in self._layers: + x = layer(x) + + return self._output_layers[0](x) + + def set_deterministic(self, deterministic): + """Set whether or not noise is applied. + + This is useful when determinstic evaluation of + a policy is desired. Determinism is off by default. + + Args: + deterministic (bool): If False, noise is applied, else + it is not. + """ + for layer in self._noisy_layers: + layer.set_deterministic(deterministic) + + +class NoisyLinear(nn.Module): + r"""Noisy linear layer with Factorised Gaussian noise. + + See https://arxiv.org/pdf/1706.10295.pdf. + + Each NoisyLinear layer applies the following transformation + + :math:`y = (\mu^w + \sigma^w \odot \epsilon ^w) + \mu^b + \sigma^b \odot + \epsilon^b` + + where :math:`\mu^w, \mu^b, \sigma^w, and \sigma^b` are learned parameters + and :math:`\epislon^w, \epsilon^b` are zero-mean gaussian noise samples. + + Args: + input_dim (int) : Dimension of the network input. + output_dim (int): Dimension of the network output. + sigma_naught (float): Hyperparameter that specifies the intial noise + scaling factor. See the paper for details. + std_noise (float): Standard deviation of the gaussian noise + distribution to sample from. + """ + + def __init__(self, input_dim, output_dim, sigma_naught=0.5, std_noise=1.): + super().__init__() + self._input_dim = input_dim + self._output_dim = output_dim + self._sigma_naught = sigma_naught + self._std_noise = std_noise + self._deterministic = False + + self._output_dim = output_dim + + self._weight_mu = nn.Parameter(torch.FloatTensor( + output_dim, input_dim)) + self._weight_sigma = nn.Parameter( + torch.FloatTensor(output_dim, input_dim)) + + self._bias_mu = nn.Parameter(torch.FloatTensor(output_dim)) + self._bias_sigma = nn.Parameter(torch.FloatTensor(output_dim)) + + # epsilon noise + self.register_buffer('weight_epsilon', + torch.FloatTensor(output_dim, input_dim)) + self.register_buffer('bias_epsilon', torch.FloatTensor(output_dim)) + + self.reset_parameters() + + def set_deterministic(self, deterministic): + """Set whether or not noise is applied. + + This is useful when determinstic evaluation of + a policy is desired. Determinism is off by default. + + Args: + deterministic (bool): If False, noise is applied, else + it is not. 
+ """ + self._deterministic = deterministic + + def forward(self, input_value): + """Forward method. + + Args: + input_value (torch.Tensor): Input values with (N, *, input_dim) + shape. + + Returns: + torch.Tensor: Output value + """ + if self._deterministic: + return F.linear(input_value, self._weight_mu, self._bias_mu) + + self._sample_noise() + w = self._weight_mu + self._weight_sigma.mul( + Variable(self.weight_epsilon)) + b = self._bias_mu + self._bias_sigma.mul(Variable(self.bias_epsilon)) + return F.linear(input_value, w, b) + + def reset_parameters(self): + """Reset all learnable parameters.""" + mu_range = 1 / math.sqrt(self._weight_mu.size(1)) + + self._weight_mu.data.uniform_(-mu_range, mu_range) + self._weight_sigma.data.fill_(self._sigma_naught / + math.sqrt(self._weight_sigma.size(1))) + + self._bias_mu.data.uniform_(-mu_range, mu_range) + self._bias_sigma.data.fill_(self._sigma_naught / + math.sqrt(self._bias_sigma.size(0))) + + def _sample_noise(self): + r"""Sample and assign new values for :math:`\epsilon`.""" + in_noise = self._get_noise(self._input_dim) + out_noise = self._get_noise(self._output_dim) + self.weight_epsilon.copy_(out_noise.ger(in_noise)) + self.bias_epsilon.copy_(self._get_noise(self._output_dim)) + + def _get_noise(self, size): + """Retrieve scaled zero-mean gaussian noise. + + Args: + size (int): size of the noise vector. + + Returns: + torch.Tensor: noise vector of the specified size. + """ + x = torch.normal(torch.zeros(size), self._std_noise * torch.ones(size)) + return x.sign().mul(x.abs().sqrt()) diff --git a/src/garage/torch/q_functions/discrete_cnn_q_function.py b/src/garage/torch/q_functions/discrete_cnn_q_function.py index 02a5987dba..47d54f5967 100644 --- a/src/garage/torch/q_functions/discrete_cnn_q_function.py +++ b/src/garage/torch/q_functions/discrete_cnn_q_function.py @@ -29,6 +29,13 @@ class DiscreteCNNQFunction(DiscreteCNNModule): and the second one outputs 32 channels. dueling (bool): Whether to use a dueling architecture for the fully-connected layer. + noisy (bool): Whether to use parameter noise for the fully-connected + layers. If True, hidden_w_init, hidden_b_init, output_w_init, and + output_b_init are ignored. + noisy_sigma (float): Level of scaling to apply to the parameter noise. + This is ignored if noisy is set to False. + std_noise (float): Standard deviation of the gaussian parameters noise. + This is ignored if noisy is set to False. hidden_sizes (list[int]): Output dimension of dense layer(s) for the MLP for mean. For example, (32, 32) means the MLP consists of two hidden layers, each with 32 hidden units. 
@@ -73,6 +80,9 @@ def __init__(self,
                  hidden_channels,
                  strides,
                  dueling=False,
+                 noisy=False,
+                 noisy_sigma=0.5,
+                 std_noise=1.,
                  hidden_sizes=(32, 32),
                  cnn_hidden_nonlinearity=torch.nn.ReLU,
                  mlp_hidden_nonlinearity=torch.nn.ReLU,
@@ -98,6 +108,9 @@ def __init__(self,
                          strides=strides,
                          hidden_sizes=hidden_sizes,
                          dueling=dueling,
+                         noisy=noisy,
+                         noisy_sigma=noisy_sigma,
+                         std_noise=std_noise,
                          hidden_channels=hidden_channels,
                          cnn_hidden_nonlinearity=cnn_hidden_nonlinearity,
                          mlp_hidden_nonlinearity=mlp_hidden_nonlinearity,
diff --git a/tests/garage/torch/algos/test_dqn.py b/tests/garage/torch/algos/test_dqn.py
index 1f40761bd1..ea3c157d7a 100644
--- a/tests/garage/torch/algos/test_dqn.py
+++ b/tests/garage/torch/algos/test_dqn.py
@@ -14,6 +14,9 @@
 from garage.replay_buffer import PathBuffer
 from garage.sampler import LocalSampler
 from garage.torch import np_to_torch
+from garage.torch.algos import DQN
+from garage.torch.policies import DiscreteQFArgmaxPolicy
+from garage.torch.q_functions import DiscreteMLPQFunction
 from garage.trainer import Trainer
 
 from tests.fixtures import snapshot_config
@@ -26,6 +29,7 @@ def setup():
     steps_per_epoch = 10
     sampler_batch_size = 512
     num_timesteps = 100 * steps_per_epoch * sampler_batch_size
+
     env = GymEnv('CartPole-v0')
 
     replay_buffer = PathBuffer(capacity_in_transitions=int(1e6))
diff --git a/tests/garage/torch/modules/test_noisy_mlp_module.py b/tests/garage/torch/modules/test_noisy_mlp_module.py
new file mode 100644
index 0000000000..75eaa94eee
--- /dev/null
+++ b/tests/garage/torch/modules/test_noisy_mlp_module.py
@@ -0,0 +1,123 @@
+"""Test NoisyMLPModule."""
+import math
+
+import numpy as np
+import pytest
+import torch
+from torch.autograd import Variable
+from torch.nn import functional as F
+
+from garage.torch.modules import NoisyMLPModule
+from garage.torch.modules.noisy_mlp_module import NoisyLinear
+
+
+@pytest.mark.parametrize('input_dim, output_dim, sigma_naught, hidden_sizes',
+                         [(1, 1, 0.1, (32, 32)), (2, 2, 0.5, (32, 64)),
+                          (2, 3, 1., (5, 5, 5))])
+def test_forward(input_dim, output_dim, sigma_naught, hidden_sizes):
+    noisy_mlp = NoisyMLPModule(input_dim,
+                               output_dim,
+                               hidden_nonlinearity=None,
+                               sigma_naught=sigma_naught,
+                               hidden_sizes=hidden_sizes)
+
+    # mock the noise
+    for layer in noisy_mlp._noisy_layers:
+        layer._get_noise = lambda x: torch.Tensor(np.ones(x))
+
+    input_val = torch.Tensor(np.ones(input_dim, dtype=np.float32))
+    x = input_val
+    for layer in noisy_mlp._noisy_layers:
+        x = layer(x)
+
+    out = noisy_mlp.forward(input_val)
+    assert (x == out).all()
+
+
+@pytest.mark.parametrize('input_dim, output_dim, sigma_naught, hidden_sizes',
+                         [(1, 1, 0.1, (32, 32)), (2, 2, 0.5, (32, 64)),
+                          (2, 3, 1., (5, 5, 5))])
+def test_forward_deterministic(input_dim, output_dim, sigma_naught,
+                               hidden_sizes):
+    noisy_mlp = NoisyMLPModule(input_dim,
+                               output_dim,
+                               hidden_nonlinearity=None,
+                               sigma_naught=sigma_naught,
+                               hidden_sizes=hidden_sizes)
+    noisy_mlp.set_deterministic(True)
+    input_val = torch.Tensor(np.ones(input_dim, dtype=np.float32))
+    x = input_val
+    for layer in noisy_mlp._noisy_layers:
+        x = layer(x)
+
+    out = noisy_mlp.forward(input_val)
+    assert (x == out).all()
+
+
+@pytest.mark.parametrize('input_dim, output_dim, sigma_naught', [(1, 1, 0.1),
+                                                                 (2, 2, 0.5),
+                                                                 (2, 3, 1.)])
+def test_noisy_linear_reset_parameters(input_dim, output_dim, sigma_naught):
+    layer = NoisyLinear(input_dim, output_dim, sigma_naught=0)
+
+    mu_range = 1 / math.sqrt(input_dim)
+    assert (layer._weight_sigma == 0.).all()
+    assert (layer._bias_sigma ==
+            0.).all()
+
+    layer._sigma_naught = sigma_naught
+    layer.reset_parameters()
+
+    bias_sig = sigma_naught / math.sqrt(output_dim)
+    weight_sig = sigma_naught / math.sqrt(input_dim)
+
+    # sigma
+    assert (layer._weight_sigma == weight_sig).all()
+    assert (layer._bias_sigma == bias_sig).all()
+
+    # mu
+    assert ((layer._bias_mu <= mu_range).all()
+            and (layer._bias_mu >= -mu_range).all())
+
+    assert ((layer._weight_mu <= mu_range).all()
+            and (layer._weight_mu >= -mu_range).all())
+
+
+@pytest.mark.parametrize('input_dim, output_dim, sigma_naught', [(1, 1, 0.1),
+                                                                 (2, 2, 0.5),
+                                                                 (2, 3, 1.)])
+def test_noisy_linear_forward(input_dim, output_dim, sigma_naught):
+    layer = NoisyLinear(input_dim, output_dim, sigma_naught=sigma_naught)
+
+    input_val = torch.Tensor(np.ones(input_dim, dtype=np.float32))
+    val = layer.forward(input_val).detach()
+    w = layer._weight_mu + \
+        layer._weight_sigma.mul(Variable(layer.weight_epsilon))
+    b = layer._bias_mu + layer._bias_sigma.mul(Variable(layer.bias_epsilon))
+
+    expected = F.linear(input_val, w, b).detach()
+
+    assert (val == expected).all()
+
+    # test deterministic mode
+
+    layer.set_deterministic(True)
+    val = layer.forward(input_val)
+    expected = F.linear(input_val, layer._weight_mu, layer._bias_mu)
+
+    assert (val == expected).all()
+
+
+@pytest.mark.parametrize('input_dim, output_dim', [(1, 1), (2, 2), (2, 3)])
+def test_sample_noise(input_dim, output_dim):
+    layer = NoisyLinear(input_dim, output_dim)
+
+    # mock the noise
+    layer._get_noise = lambda x: torch.Tensor(np.ones(x))
+
+    out_noise = layer._get_noise(output_dim)
+    in_noise = layer._get_noise(input_dim)
+    layer._sample_noise()
+
+    expected = out_noise.ger(in_noise).detach()
+    assert (layer.weight_epsilon == expected).all()
+    assert (layer.bias_epsilon == layer._get_noise(output_dim)).all()
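Finally, a quick interactive sanity check of the new module (a sketch; it assumes the patch is applied so that NoisyMLPModule is importable from garage.torch.modules):

import torch

from garage.torch.modules import NoisyMLPModule

mlp = NoisyMLPModule(input_dim=4, output_dim=2, hidden_sizes=(8, 8))
x = torch.ones(1, 4)

# With noise enabled (the default), consecutive forward passes differ
# because fresh epsilon values are sampled on every call.
y1, y2 = mlp(x), mlp(x)
print(torch.allclose(y1, y2))  # almost surely False

# With noise disabled, only the mean weights are used, so the output is
# repeatable -- this is what set_deterministic(True) is for at eval time.
mlp.set_deterministic(True)
print(torch.allclose(mlp(x), mlp(x)))  # True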