Skip to content

Commit

Permalink
add an option to perturb gradient of each data point during optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
runjerry committed Nov 3, 2024
1 parent 2ffb9fb commit 633bafe
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
20 changes: 17 additions & 3 deletions alf/algorithms/oaec_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@ def __init__(self,
config: TrainerConfig = None,
critic_loss_ctor=None,
num_rollout_sampled_actions=10,
num_sampled_target_q_actions=0,
num_bootstrap_critics=1,
critic_replicas_deepcopy=True,
bootstrap_mask_prob=0.8,
opt_ptb_single_data=True,
opt_ptb_dist="exponential",
# correct_optimization_noise=False,
# align_optimization_noise=False,
Expand Down Expand Up @@ -158,13 +160,18 @@ def __init__(self,
only useful if priority replay is enabled.
num_rollout_sampled_actions (int): number of sampled actions in rollout.
The one with the highest Q_value + epistemic_std will be executed.
num_sampled_target_q_actions (int): number of sampled actions for target
critics, default is 0, indicating no sampling, i.e., using the mean
of the policy output.
num_bootstrap_critics (int): a positive number of bootstrapped critics
for uncertainty estimation. Default is 1.
critic_replicas_deepcopy (bool): whether to deepcopy the critic_network
for replicas. Default is True. When False, each critic_replica
will have different independently instantiated parameters.
bootstrap_mask_prob (float): the parameter of the Binomial distribution
for independently masking out a transition to simulate bootstrapping.
opt_ptb_single_data (bool): whether to perturb the gradient of each
individual training data point during optimization.
opt_ptb_dist (str): the distribution for sampling optimization perturbation.
Options are ["exponential", "uniform"].
correct_optimization_noise (bool): whether to correct the optimization
Expand Down Expand Up @@ -246,6 +253,7 @@ def __init__(self,
self._std_for_overestimate = std_for_overestimate
self._num_rollout_sampled_actions = num_rollout_sampled_actions
self._bootstrap_mask_prob = bootstrap_mask_prob
self._opt_ptb_single_data = opt_ptb_single_data
self._opt_ptb_dist = opt_ptb_dist
# self._opt_ptb_dist = torch.distributions.Exponential(1.0)
self._device = alf.get_default_device()
Expand Down Expand Up @@ -320,8 +328,11 @@ def __init__(self,
'TrainerConfig.mini_batch_size')
self._mini_batch_length = alf.get_config_value(
'TrainerConfig.mini_batch_length')
self._opt_ptb_weights = torch.empty(
(self._mini_batch_length, self._mini_batch_size, self._num_opt_ptb_critics))
if opt_ptb_single_data:
self._opt_ptb_weights = torch.empty(
(self._mini_batch_length, self._mini_batch_size, self._num_opt_ptb_critics))
else:
self._opt_ptb_weights = torch.empty((self._num_opt_ptb_critics,))

def _predict_action(self,
observation,
Expand Down Expand Up @@ -585,7 +596,10 @@ def calc_loss(self, info: OaecInfo):
self._opt_ptb_dist.uniform_(0.5, 1.5)
n_start = 1 + self._num_bootstrap_critics
for i in range(self._num_opt_ptb_critics):
weights = self._opt_ptb_weights[:, :, i]
if self._opt_ptb_single_data:
weights = self._opt_ptb_weights[:, :, i]
else:
weights = self._opt_ptb_weights[i]
critic_losses[n_start + i] = weights * self._critic_losses[n_start + i](
info=info,
value=info.critic.q_values[:, :, n_start + i, ...],
Expand Down
5 changes: 3 additions & 2 deletions alf/examples/oaec_dmc.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"desc": [
"DM Control tasks with 4 seeds on each environment"
],
"version": "hopper_oaec_0e292_nb_2-ces",
"version": "hopper_oaec_2ffb9_nb_2-ces",
"use_gpu": true,
"gpus": [
0, 1
Expand All @@ -17,8 +17,9 @@
"OaecAlgorithm.beta_lb": "[.1]",
"OaecAlgorithm.output_target_critic": "[True]",
"OaecAlgorithm.std_for_overestimate": "['opt']",
"OaecAlgorithm.opt_ptb_single_data": "[False]",
"OaecAlgorithm.num_bootstrap_critics": "[2]",
"OaecAlgorithm.bootstrap_mask_prob": "[.8]",
"OaecAlgorithm.bootstrap_mask_prob": "[0.5, 0.8]",
"OaecAlgorithm.use_target_actor": "[False]",
"OaecAlgorithm.target_update_tau": "[0.005]",
"TrainerConfig.random_seed": "list(range(2))"
Expand Down
1 change: 1 addition & 0 deletions alf/examples/oaec_dmc_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
beta_lb=.5,
output_target_critic=True,
std_for_overestimate='opt',
opt_ptb_single_data=True,
reward_noise_scale=None,
num_rollout_sampled_actions=10,
num_bootstrap_critics=2,
Expand Down

0 comments on commit 633bafe

Please sign in to comment.