From b05c856da4e8b0c61b804ddbf616d1148388cec0 Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Mon, 6 Nov 2023 11:37:22 +0800 Subject: [PATCH 01/16] add action --- ding/policy/plan_diffuser.py | 2 +- ding/utils/data/dataset.py | 2 ++ dizoo/d4rl/config/antmaze_umaze_pd_config.py | 2 +- dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py | 2 +- dizoo/d4rl/config/halfcheetah_medium_pd_config.py | 2 +- dizoo/d4rl/config/hopper_medium_expert_pd_config.py | 2 +- dizoo/d4rl/config/hopper_medium_pd_config.py | 2 +- dizoo/d4rl/config/walker2d_medium_expert_pd_config.py | 2 +- dizoo/d4rl/config/walker2d_medium_pd_config.py | 2 +- 9 files changed, 10 insertions(+), 8 deletions(-) diff --git a/ding/policy/plan_diffuser.py b/ding/policy/plan_diffuser.py index 7e6854789f..ad58546a15 100755 --- a/ding/policy/plan_diffuser.py +++ b/ding/policy/plan_diffuser.py @@ -178,7 +178,7 @@ def _init_learn(self) -> None: self.step_start_update_target = self._cfg.learn.step_start_update_target self.target_weight = self._cfg.learn.target_weight self.value_step = self._cfg.learn.value_step - self.use_target = True + self.use_target = False self.horizon = self._cfg.model.diffuser_model_cfg.horizon self.include_returns = self._cfg.learn.include_returns diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 013f8b688d..23db0fcdf9 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -1090,11 +1090,13 @@ def __getitem__(self, idx, eps=1e-4): 'trajectories': trajectories, 'returns': returns, 'done': done, + 'action': actions, } else: batch = { 'trajectories': trajectories, 'done': done, + 'action': actions, } batch.update(self.get_conditions(observations)) diff --git a/dizoo/d4rl/config/antmaze_umaze_pd_config.py b/dizoo/d4rl/config/antmaze_umaze_pd_config.py index 8dadd63a13..96ca022545 100755 --- a/dizoo/d4rl/config/antmaze_umaze_pd_config.py +++ b/dizoo/d4rl/config/antmaze_umaze_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=37, dim=32, diff --git a/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py b/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py index d3a0dbb4b8..66c8ba8d91 100755 --- a/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, dim=32, diff --git a/dizoo/d4rl/config/halfcheetah_medium_pd_config.py b/dizoo/d4rl/config/halfcheetah_medium_pd_config.py index 2386c278ec..674395a4e1 100755 --- a/dizoo/d4rl/config/halfcheetah_medium_pd_config.py +++ b/dizoo/d4rl/config/halfcheetah_medium_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, dim=32, diff --git a/dizoo/d4rl/config/hopper_medium_expert_pd_config.py b/dizoo/d4rl/config/hopper_medium_expert_pd_config.py index 6205018751..3df47f8d1b 100755 --- a/dizoo/d4rl/config/hopper_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/hopper_medium_expert_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=14, dim=32, diff --git 
a/dizoo/d4rl/config/hopper_medium_pd_config.py b/dizoo/d4rl/config/hopper_medium_pd_config.py index 49caaec5d2..8dfee5d824 100755 --- a/dizoo/d4rl/config/hopper_medium_pd_config.py +++ b/dizoo/d4rl/config/hopper_medium_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=14, dim=32, diff --git a/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py b/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py index 18cb45559b..3d4c060e83 100755 --- a/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py +++ b/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, dim=32, diff --git a/dizoo/d4rl/config/walker2d_medium_pd_config.py b/dizoo/d4rl/config/walker2d_medium_pd_config.py index 8b2c0b0a4a..29fce259c8 100755 --- a/dizoo/d4rl/config/walker2d_medium_pd_config.py +++ b/dizoo/d4rl/config/walker2d_medium_pd_config.py @@ -24,7 +24,7 @@ model=dict( diffuser_model='GaussianDiffusion', diffuser_model_cfg=dict( - model='TemporalUnet', + model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, dim=32, From a459fd054f987ff07498a116fa65ac72287235dc Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Mon, 6 Nov 2023 11:40:01 +0800 Subject: [PATCH 02/16] change entry --- dizoo/d4rl/entry/d4rl_pd_main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dizoo/d4rl/entry/d4rl_pd_main.py b/dizoo/d4rl/entry/d4rl_pd_main.py index 73e08288ed..1ca3c5b299 100755 --- a/dizoo/d4rl/entry/d4rl_pd_main.py +++ b/dizoo/d4rl/entry/d4rl_pd_main.py @@ -16,6 +16,6 @@ def train(args): parser = argparse.ArgumentParser() parser.add_argument('--seed', '-s', type=int, default=10) - parser.add_argument('--config', '-c', type=str, default='hopper_expert_cql_config.py') + parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py') args = parser.parse_args() train(args) \ No newline at end of file From e97725ccbf07679712d1f1695ec86a98ac83da59 Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Wed, 20 Dec 2023 15:34:56 +0800 Subject: [PATCH 03/16] add meta diffusion and prompt dt --- ding/model/template/decision_transformer.py | 44 ++- ding/model/template/diffusion.py | 214 +++++++++++- ding/policy/meta_diffuser.py | 347 ++++++++++++++++++++ ding/policy/prompt_dt.py | 192 +++++++++++ ding/torch_utils/network/diffusion.py | 18 +- ding/utils/data/dataset.py | 273 +++++++++++++++ 6 files changed, 1078 insertions(+), 10 deletions(-) mode change 100644 => 100755 ding/model/template/decision_transformer.py create mode 100755 ding/policy/meta_diffuser.py create mode 100755 ding/policy/prompt_dt.py diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py old mode 100644 new mode 100755 index 3d35497383..fb01c9a0b0 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -176,7 +176,8 @@ def __init__( drop_p: float, max_timestep: int = 4096, state_encoder: Optional[nn.Module] = None, - continuous: bool = False + continuous: bool = False, + use_prompt: bool = False, ): """ Overview: @@ -206,6 +207,8 @@ def __init__( # projection heads (project to embedding) self.embed_ln = nn.LayerNorm(h_dim) self.embed_timestep = nn.Embedding(max_timestep, h_dim) + if use_prompt: + 
self.prompt_embed_timestep = nn.Embedding(max_timestep, h_dim) self.drop = nn.Dropout(drop_p) self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim)) @@ -218,14 +221,21 @@ def __init__( self.embed_state = torch.nn.Linear(state_dim, h_dim) self.predict_rtg = torch.nn.Linear(h_dim, 1) self.predict_state = torch.nn.Linear(h_dim, state_dim) + if use_prompt: + self.prompt_embed_state = torch.nn.Linear(state_dim, h_dim) + self.prompt_embed_rtg = torch.nn.Linear(1, h_dim) if continuous: # continuous actions self.embed_action = torch.nn.Linear(act_dim, h_dim) use_action_tanh = True # True for continuous actions + if use_prompt: + self.prompt_embed_action = torch.nn.Linear(act_dim, h_dim) else: # discrete actions self.embed_action = torch.nn.Embedding(act_dim, h_dim) use_action_tanh = False # False for discrete actions + if use_prompt: + self.prompt_embed_action = torch.nn.Embedding(act_dim, h_dim) self.predict_action = nn.Sequential( *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else [])) ) @@ -243,7 +253,8 @@ def forward( states: torch.Tensor, actions: torch.Tensor, returns_to_go: torch.Tensor, - tar: Optional[int] = None + tar: Optional[int] = None, + prompt: dict = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Overview: @@ -299,6 +310,35 @@ def forward( t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) h = self.embed_ln(t_p) + + if prompt is not None: + prompt_states, prompt_actions, prompt_returns_to_go,\ + prompt_timesteps, prompt_attention_mask = prompt + prompt_seq_length = prompt_states.shape[1] + prompt_state_embeddings = self.prompt_embed_state(prompt_states) + prompt_action_embeddings = self.prompt_embed_action(prompt_actions) + if prompt_returns_to_go.shape[1] % 10 == 1: + prompt_returns_embeddings = self.prompt_embed_rtg(prompt_returns_to_go[:,:-1]) + else: + prompt_returns_embeddings = self.prompt_embed_rtg(prompt_returns_to_go) + prompt_time_embeddings = self.prompt_embed_timestep(prompt_timesteps) + + prompt_state_embeddings = prompt_state_embeddings + prompt_time_embeddings + prompt_action_embeddings = prompt_action_embeddings + prompt_time_embeddings + prompt_returns_embeddings = prompt_returns_embeddings + prompt_time_embeddings + prompt_stacked_attention_mask = torch.stack( + (prompt_attention_mask, prompt_attention_mask, prompt_attention_mask), dim=1 + ).permute(0, 2, 1).reshape(prompt_states.shape[0], 3 * prompt_seq_length) + + if prompt_stacked_inputs.shape[1] == 3 * T: # if only smaple one prompt + prompt_stacked_inputs = prompt_stacked_inputs.reshape(1, -1, self.hidden_size) + prompt_stacked_attention_mask = prompt_stacked_attention_mask.reshape(1, -1) + h = torch.cat((prompt_stacked_inputs.repeat(B, 1, 1), h), dim=1) + stacked_attention_mask = torch.cat((prompt_stacked_attention_mask.repeat(B, 1), stacked_attention_mask), dim=1) + else: # if sample one prompt for each traj in batch + h = torch.cat((prompt_stacked_inputs, h), dim=1) + stacked_attention_mask = torch.cat((prompt_stacked_attention_mask, stacked_attention_mask), dim=1) + # transformer and prediction h = self.transformer(h) # get h reshaped such that its size = (B x 3 x T x h_dim) and diff --git a/ding/model/template/diffusion.py b/ding/model/template/diffusion.py index f8b48f3061..934e3d7c73 100755 --- a/ding/model/template/diffusion.py +++ b/ding/model/template/diffusion.py @@ -26,9 +26,12 @@ def default_sample_fn(model, x, cond, t): return model_mean + nonzero_mask * (0.5 
* model_log_variance).exp() * noise, values -def get_guide_output(guide, x, cond, t): +def get_guide_output(guide, x, cond, t, returns=None): x.requires_grad_() - y = guide(x, cond, t).squeeze(dim=-1) + if returns: + y = guide(x, cond, t, returns).squeeze(dim=-1) + else: + y = guide(x, cond, t).squeeze(dim=-1) grad = torch.autograd.grad([y.sum()], [x])[0] x.detach() return y, grad @@ -69,6 +72,50 @@ def n_step_guided_p_sample( return model_mean + model_std * noise, y +def free_guidance_sample( + model, + x, + cond, + t, + guide1, + guide2, + returns=None, + scale=1, + t_stopgrad=0, + n_guide_steps=1, + scale_grad_by_std=True, + +): + weight = extract(model.sqrt_one_minus_alphas_cumprod, t, x.shape) + model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) + model_std = torch.exp(0.5 * model_log_variance) + model_var = torch.exp(model_log_variance) + + for _ in range(n_guide_steps): + with torch.enable_grad(): + y1, grad1 = get_guide_output(guide1, x, cond, t, returns) # get reward + y2, grad2 = get_guide_output(guide2, x, cond, t) # get state + grad = grad1 + scale * grad2 + + if scale_grad_by_std: + grad = weight * grad + + grad[t < t_stopgrad] = 0 + + if returns: + # epsilon could be epsilon or x0 itself + epsilon_cond = model.model(x, cond, t, returns, use_dropout=False) + epsilon_uncond = model.model(x, cond, t, returns, force_dropout=True) + epsilon = epsilon_uncond + model.condition_guidance_w * (epsilon_cond - epsilon_uncond) + else: + epsilon = model.model(x, cond, t) + epsilon += grad + + model_mean, _, model_log_variance = model.p_mean_variance(x=x, cond=cond, t=t, epsilon=epsilon) + noise = torch.randn_like(x) + noise[t == 0] = 0 + + return model_mean + model_std * noise, class GaussianDiffusion(nn.Module): """ @@ -299,12 +346,12 @@ class ValueDiffusion(GaussianDiffusion): Gaussian diffusion model for value function. """ - def p_losses(self, x_start, cond, target, t): + def p_losses(self, x_start, cond, target, t, returns=None): noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) - pred = self.model(x_noisy, cond, t) + pred = self.model(x_noisy, cond, t, returns) loss = F.mse_loss(pred, target, reduction='none').mean() log = { 'mean_pred': pred.mean().item(), @@ -314,8 +361,8 @@ def p_losses(self, x_start, cond, target, t): return loss, log - def forward(self, x, cond, t): - return self.model(x, cond, t) + def forward(self, x, cond, t, returns=None): + return self.model(x, cond, t, returns) @MODEL_REGISTRY.register('pd') @@ -643,3 +690,158 @@ def p_losses(self, x_start, cond, t, returns=None): def forward(self, cond, *args, **kwargs): return self.conditional_sample(cond=cond, *args, **kwargs) + +class GuidenceFreeDifffuser(GaussianInvDynDiffusion): + + def p_mean_variance(self, x, cond, t, epsilon): + x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon) + + if self.clip_denoised: + x_recon.clamp_(-1., 1.) 
+ else: + assert RuntimeError() + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + def p_sample_loop(self, shape, cond, sample_fn=None, plan_size=1, **sample_kwargs): + device = self.betas.device + + batch_size = shape[0] + x = torch.randn(shape, device=device) + x = apply_conditioning(x, cond, self.action_dim) + + assert sample_fn != None + for i in reversed(range(0, self.n_timesteps)): + t = torch.full((batch_size, ), i, device=device, dtype=torch.long) + x = sample_fn(self, x, cond, t, **sample_kwargs) + x = apply_conditioning(x, cond, self.action_dim) + + return x + + + def conditional_sample(self, cond, horizon=None, **sample_kwargs): + device = self.betas.device + batch_size = len(cond[0]) + horizon = horizon or self.horizon + shape = (batch_size, horizon, self.obs_dim) + return self.p_sample_loop(shape, cond, **sample_kwargs) + + def p_losses(self, x_start, cond, t, returns=None): + noise = torch.randn_like(x_start) + + batch_size = len(cond[0]) + mask_rand = torch.rand([batch_size]) + mask = torch.bernoulli(mask_rand, 0.7) + returns = returns * mask + + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + x_noisy = apply_conditioning(x_noisy, cond, 0) + + x_recon = self.model(x_noisy, cond, t, returns) + + if not self.predict_epsilon: + x_recon = apply_conditioning(x_recon, cond, 0) + + assert noise.shape == x_recon.shape + + if self.predict_epsilon: + loss = F.mse_loss(x_recon, noise, reduction='none') + a0_loss = (loss[:, 0, :self.action_dim] / self.loss_weights[0, :self.action_dim].to(loss.device)).mean() + loss = (loss * self.loss_weights.to(loss.device)).mean() + else: + loss = F.mse_loss(x_recon, x_start, reduction='none') + a0_loss = (loss[:, 0, :self.action_dim] / self.loss_weights[0, :self.action_dim].to(loss.device)).mean() + loss = (loss * self.loss_weights.to(loss.device)).mean() + return loss, a0_loss + + +@MODEL_REGISTRY.register('metadiffuser') +class MetaDiffuser(nn.Module): + + def __init__( + self, + dim: int, + obs_dim: Union[int, SequenceType], + action_dim: Union[int, SequenceType], + reward_cfg: dict, + diffuser_cfg: dict, + horizon: int, + **sample_kwargs, + ): + + self.obs_dim = obs_dim + self.action_dim = action_dim + self.horizon = horizon + self.sample_kwargs = sample_kwargs + + self.embed = nn.Sequential( + nn.Linear((obs_dim * 2 + action_dim + 1) * horizon, dim * 4), + nn.Mish(), + nn.Linear(dim * 4, dim * 4), + nn.Mish(), + nn.Linear(dim * 4, dim * 4), + nn.Mish(), + nn.Linear(dim * 4, dim * 4), + nn.Mish(), + nn.Linear(dim * 4, dim) + ) + + self.reward_model = ValueDiffusion(**reward_cfg) + + self.dynamic_model = nn.Sequential( + nn.Linear(obs_dim + action_dim + dim, 200), + nn.ReLU(), + nn.Linear(200, 200), + nn.ReLU(), + nn.Linear(200, 200), + nn.ReLU(), + nn.Linear(200, 200), + nn.ReLU(), + nn.Linear(200, 200), + nn.ReLU(), + nn.Linear(200, obs_dim), + ) + + self.diffuser = GuidenceFreeDifffuser(**diffuser_cfg) + + def diffuser_loss(self, x_start, cond, t): + return self.diffuser.p_losses(x_start, cond, t) + + def pre_train_loss(self, traj, target, t, cond): + input_emb = traj.reshape(target.shape[0], -1) + task_idx = self.embed(input_emb) + + states = traj[:,:, self.action_dim:self.action_dim + self.obs_dim] + actions = traj[:, :, :self.action_dim] + input = torch.cat([actions, states], dim=-1) + target_reward = target[:,-1] + + target_next_state = target[:, :-1] + task_idxs = torch.full(states.shape[:-1], task_idx, 
device=task_idx.device, dtype=torch.long) + + reward_loss, reward_log = self.reward_model.p_losses(input, cond, target_reward, t, task_idxs) + + + + n = states.shape[1] + + state_loss = 0 + for i in range(n): + next_state = self.dynamic_model(input) + state_loss += F.mse_loss(next_state, target_next_state, reduction='none').mean() + state_loss /= n + return state_loss, reward_loss + + def get_eval(self, cond, id, batch_size = 1): + if batch_size > 1: + cond = self.repeat_cond(cond, batch_size) + + samples = self.diffuser(cond, returns=id, sample_fn=free_guidance_sample, plan_size=batch_size, + guide1=self.reward_model, guide2=self.dynamic_model **self.sample_kwargs) + return samples[:, 0, :,self.action_dim] + + def repeat_cond(self, cond, batch_size): + for k, v in cond.items(): + cond[k] = v.repeat_interleave(batch_size, dim=0) + return cond diff --git a/ding/policy/meta_diffuser.py b/ding/policy/meta_diffuser.py new file mode 100755 index 0000000000..48202e618f --- /dev/null +++ b/ding/policy/meta_diffuser.py @@ -0,0 +1,347 @@ +from typing import List, Dict, Any, Optional, Tuple, Union +from collections import namedtuple, defaultdict +import copy +import numpy as np +import torch +import torch.nn.functional as F +from torch.distributions import Normal, Independent + +from ding.torch_utils import Adam, to_device +from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \ + qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data +from ding.policy import Policy +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY, DatasetNormalizer +from ding.utils.data import default_collate, default_decollate +from .common_utils import default_preprocess_learn + +@POLICY_REGISTRY.register('metadiffuser') +class MDPolicy(Policy): + r""" + Overview: + Implicit Meta Diffuser + https://arxiv.org/pdf/2305.19923.pdf + + """ + config = dict( + type='pd', + # (bool) Whether to use cuda for network. + cuda=False, + # (bool type) priority: Determine whether to use priority in buffer sample. + # Default False in SAC. + priority=False, + # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True. + priority_IS_weight=False, + # (int) Number of training samples(randomly collected) in replay buffer when training starts. + # Default 10000 in SAC. + random_collect_size=10000, + nstep=1, + # normalizer type + normalizer='GaussianNormalizer', + model=dict( + diffuser_cfg=dict( + # the type of model + model='TemporalUnet', + # config of model + model_cfg=dict( + # model dim, In GaussianInvDynDiffusion, it is obs_dim. In others, it is obs_dim + action_dim + transition_dim=23, + dim=32, + dim_mults=[1, 2, 4, 8], + # whether use return as a condition + returns_condition=True, + condition_dropout=0.1, + # whether use calc energy + calc_energy=False, + kernel_size=5, + # whether use attention + attention=False, + ), + # horizon of tarjectory which generated by model + horizon=80, + # timesteps of diffusion + n_timesteps=1000, + # hidden dim of action model + # Whether predict epsilon + predict_epsilon=True, + # discount of loss + loss_discount=1.0, + # whether clip denoise + clip_denoised=False, + action_weight=10, + ), + value_model='ValueDiffusion', + value_model_cfg=dict( + # the type of model + model='TemporalValue', + # config of model + model_cfg=dict( + horizon=4, + # model dim, In GaussianInvDynDiffusion, it is obs_dim. 
In others, it is obs_dim + action_dim + transition_dim=23, + dim=32, + dim_mults=[1, 2, 4, 8], + # whether use calc energy + kernel_size=5, + ), + # horizon of tarjectory which generated by model + horizon=80, + # timesteps of diffusion + n_timesteps=1000, + # hidden dim of action model + predict_epsilon=True, + # discount of loss + loss_discount=1.0, + # whether clip denoise + clip_denoised=False, + action_weight=1.0, + ), + # guide_steps for p sample + n_guide_steps=2, + # scale of grad for p sample + scale=1, + # t of stopgrad for p sample + t_stopgrad=2, + # whether use std as a scale for grad + scale_grad_by_std=True, + ), + learn=dict( + + # How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + update_per_collect=1, + # (int) Minibatch size for gradient descent. + batch_size=100, + + # (float type) learning_rate_q: Learning rate for model. + # Default to 3e-4. + # Please set to 1e-3, when model.value_network is True. + learning_rate=3e-4, + # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum) + # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers. + # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000. + # However, interaction with HalfCheetah always gets done with done is False, + # Since we inplace done==True with done==False to keep + # TD-error accurate computation(``gamma * (1 - done) * next_v + reward``), + # when the episode step is greater than max episode step. + ignore_done=False, + + # (float type) target_theta: Used for soft update of the target network, + # aka. Interpolation factor in polyak averaging for target networks. + # Default to 0.005. + target_theta=0.005, + # (float) discount factor for the discounted sum of rewards, aka. gamma. + discount_factor=0.99, + gradient_accumulate_every=2, + # train_epoch = train_epoch * gradient_accumulate_every + train_epoch=60000, + # batch_size of every env when eval + plan_batch_size=64, + + # step start update target model and frequence + step_start_update_target=2000, + update_target_freq=10, + # update weight of target net + target_weight=0.995, + value_step=200e3, + + # dataset weight include returns + include_returns=True, + + # (float) Weight uniform initialization range in the last output layer + init_w=3e-3, + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + return 'md', ['ding.model.template.diffusion'] + + def _init_learn(self) -> None: + r""" + Overview: + Learn mode init method. Called by ``self.__init__``. + Init q, value and policy's optimizers, algorithm config, main and target models. 
+ """ + # Init + self._priority = self._cfg.priority + self._priority_IS_weight = self._cfg.priority_IS_weight + self.action_dim = self._cfg.model.diffuser_model_cfg.action_dim + self.obs_dim = self._cfg.model.diffuser_model_cfg.obs_dim + self.n_timesteps = self._cfg.model.diffuser_model_cfg.n_timesteps + self.gradient_accumulate_every = self._cfg.learn.gradient_accumulate_every + self.plan_batch_size = self._cfg.learn.plan_batch_size + self.gradient_steps = 1 + self.update_target_freq = self._cfg.learn.update_target_freq + self.step_start_update_target = self._cfg.learn.step_start_update_target + self.target_weight = self._cfg.learn.target_weight + self.value_step = self._cfg.learn.value_step + self.horizon = self._cfg.model.diffuser_model_cfg.horizon + self.include_returns = self._cfg.learn.include_returns + self.eval_batch_size = self._cfg.learn.eval_batch_size + self.warm_batch_size = self._cfg.learn.warm_batch_size + + self._plan_optimizer = Adam( + self._model.diffuser.model.parameters(), + lr=self._cfg.learn.learning_rate, + ) + + self._pre_train_optimizer = Adam( + list(self._model.reward_model.model.parameters()) + list(self._model.embed.parameters()) \ + + list(self._model.dynamic_model.parameters()), + lr=self._cfg.learn.learning_rate, + ) + + self._gamma = self._cfg.learn.discount_factor + + self._target_model = copy.deepcopy(self._model) + + self._learn_model = model_wrap(self._model, wrapper_name='base') + self._learn_model.reset() + + def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: + loss_dict = {} + + if self._cuda: + data = to_device(data, self._device) + timesteps, obs, acts, rewards, rtg, masks, cond_id, cond_vals = data + obs, next_obs = obs[:-1], obs[1:] + acts = acts[:-1] + rewards = rewards[:-1] + conds = {cond_id: cond_vals} + + + self._learn_model.train() + pre_traj = torch.cat([acts, obs, rewards, next_obs], dim=1) + target = torch.cat([next_obs, rewards], dim=1) + traj = torch.cat([acts, obs], dim=1) + + batch_size = len(traj) + t = torch.randint(0, self.n_timesteps, (batch_size, ), device=traj.device).long() + state_loss, reward_loss = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + loss_dict = {'state_loss': state_loss, 'reward_loss': reward_loss} + total_loss = state_loss + reward_loss + + self._pre_train_optimizer.zero() + total_loss.backward() + self._pre_train_optimizer.step() + self.update_model_average(self._target_model, self._learn_model) + + diffuser_loss = self._learn_model.diffuser_loss(traj, conds, t) + self._plan_optimizer.zero() + diffuser_loss.backward() + self._plan_optimizer.step() + self.update_model_average(self._target_model, self._learn_model) + + return loss_dict + + + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + if old_weight is None: + ma_params.data = up_weight + else: + old_weight * self.target_weight + (1 - self.target_weight) * up_weight + + def init_dataprocess_func(self, dataloader: torch.utils.data.Dataset): + self.dataloader = dataloader + + def _monitor_vars_learn(self) -> List[str]: + return [ + 'diffuse_loss', + 'reward_loss', + 'dynamic_loss', + 'max_return', + 'min_return', + 'mean_return', + 'a0_loss', + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'plan_optimizer': self._plan_optimizer.state_dict(), + 
'pre_train_optimizer': self._pre_train_optimizer.state_dict(), + } + + def _init_eval(self): + self._eval_model = model_wrap(self._target_model, wrapper_name='base') + self._eval_model.reset() + self.task_id = [0] * self.eval_batch_size + + + obs, acts, rewards, cond_ids, cond_vals = \ + self.dataloader.get_pretrain_data(self.task_id[0], self.warm_batch_size * self.eval_batch_size) + obs = to_device(obs, self._device) + acts = to_device(acts, self._device) + rewards = to_device(rewards, self._device) + cond_vals = to_device(cond_vals, self._device) + + obs, next_obs = obs[:-1], obs[1:] + acts = acts[:-1] + rewards = rewards[:-1] + pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=1) + target = torch.cat([next_obs, rewards], dim=1) + batch_size = len(pre_traj) + conds = {cond_ids: cond_vals} + + t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() + state_loss, reward_loss = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + total_loss = state_loss + reward_loss + self._pre_train_optimizer.zero() + total_loss.backward() + self._pre_train_optimizer.step() + self.update_model_average(self._target_model, self._learn_model) + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + data_id = list(data.keys()) + data = default_collate(list(data.values())) + + self._eval_model.eval() + obs = [] + for i in range(self.eval_batch_size): + obs.append(self.dataloader.normalize(data, 'observations', self.task_id[i])) + + with torch.no_grad(): + obs = torch.tensor(obs) + if self._cuda: + obs = to_device(obs, self._device) + conditions = {0: obs} + action = self._eval_model.get_eval(conditions, self.plan_batch_size) + if self._cuda: + action = to_device(action, 'cpu') + for i in range(self.eval_batch_size): + action[i] = self.dataloader.unnormalize(action, 'actions', self.task_id[i]) + action = torch.tensor(action).to('cpu') + output = {'action': action} + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + self.task_id[data_id] += 1 + + obs, acts, rewards, cond_ids, cond_vals = \ + self.dataloader.get_pretrain_data(self.task_id[data_id], self.warm_batch_size) + obs = to_device(obs, self._device) + acts = to_device(acts, self._device) + rewards = to_device(rewards, self._device) + cond_vals = to_device(cond_vals, self._device) + + obs, next_obs = obs[:-1], obs[1:] + acts = acts[:-1] + rewards = rewards[:-1] + pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=1) + target = torch.cat([next_obs, rewards], dim=1) + batch_size = len(pre_traj) + conds = {cond_ids: cond_vals} + + t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() + state_loss, reward_loss = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + total_loss = state_loss + reward_loss + self._pre_train_optimizer.zero() + total_loss.backward() + self._pre_train_optimizer.step() + self.update_model_average(self._target_model, self._learn_model) \ No newline at end of file diff --git a/ding/policy/prompt_dt.py b/ding/policy/prompt_dt.py new file mode 100755 index 0000000000..32dee63659 --- /dev/null +++ b/ding/policy/prompt_dt.py @@ -0,0 +1,192 @@ +from typing import List, Dict, Any, Tuple, Optional +from collections import namedtuple +import torch.nn.functional as F +import torch +import numpy as np +from ding.torch_utils import to_device +from ding.utils import POLICY_REGISTRY +from ding.utils.data import default_decollate +from 
ding.policy.dt import DTPolicy + +@POLICY_REGISTRY.register('promptdt') +class PDTPolicy(DTPolicy): + """ + Overview: + Policy class of Decision Transformer algorithm in discrete environments. + Paper link: https://arxiv.org/pdf/2206.13499. + """ + def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: + """ + Overview: + Policy forward function of learn mode (training policy and updating parameters). Forward means \ + that the policy inputs some training batch data from the offline dataset and then returns the output \ + result, including various training information such as loss, current learning rate. + Arguments: + - data (:obj:`List[torch.Tensor]`): The input data used for policy forward, including a series of \ + processed torch.Tensor data, i.e., timesteps, states, actions, returns_to_go, traj_mask. + Returns: + - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \ + recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \ + detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method. + + .. note:: + The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \ + For the data type that not supported, the main reason is that the corresponding model does not support it. \ + You can implement you own model rather than use the default model. For more information, please raise an \ + issue in GitHub repo and we will continue to follow up. + + """ + self._learn_model.train() + + prompt, timesteps, states, actions, rewards, returns_to_go, traj_mask = data + + # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), + # and we need a 3-dim tensor + if len(returns_to_go.shape) == 2: + returns_to_go = returns_to_go.unsqueeze(-1) + + if self._basic_discrete_env: + actions = actions.to(torch.long) + actions = actions.squeeze(-1) + action_target = torch.clone(actions).detach().to(self._device) + + if self._atari_env: + state_preds, action_preds, return_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1, prompt=prompt + ) + else: + state_preds, action_preds, return_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, prompt=prompt + ) + + if self._atari_env: + action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) + else: + traj_mask = traj_mask.view(-1, ) + + # only consider non padded elements + action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0] + + if self._cfg.model.continuous: + action_target = action_target.view(-1, self.act_dim)[traj_mask > 0] + action_loss = F.mse_loss(action_preds, action_target) + else: + action_target = action_target.view(-1)[traj_mask > 0] + action_loss = F.cross_entropy(action_preds, action_target) + + self._optimizer.zero_grad() + action_loss.backward() + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p) + self._optimizer.step() + self._scheduler.step() + + return { + 'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'], + 'action_loss': action_loss.detach().cpu().item(), + 'total_loss': action_loss.detach().cpu().item(), + } + + def get_dataloader(self, dataloader): + self.dataloader = dataloader + + def _init_eval(self) -> None: + self.task_id = 
[0] * self.eval_batch_size + super()._init_eval() + + def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: + prompt = [] + for i in range(self.eval_batch_size): + prompt.append(self.dataloader.get_prompt(is_test=True, id=self.task_id[i])) + + prompt = torch.tensor(prompt, device=self._device) + + data_id = list(data.keys()) + + self._eval_model.eval() + with torch.no_grad(): + if self._atari_env: + states = torch.zeros( + ( + self.eval_batch_size, + self.context_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) + timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self._device) + else: + states = torch.zeros( + (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device + ) + timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device) + if not self._cfg.model.continuous: + actions = torch.zeros( + (self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device + ) + else: + actions = torch.zeros( + (self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device + ) + rewards_to_go = torch.zeros( + (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device + ) + for i in data_id: + if self._atari_env: + self.states[i, self.t[i]] = data[i]['obs'].to(self._device) + else: + self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std + self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device) + self.rewards_to_go[i, self.t[i]] = self.running_rtg[i] + + if self.t[i] <= self.context_len: + if self._atari_env: + timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( + (1, 1), dtype=torch.int64 + ).to(self._device) + else: + timesteps[i] = self.timesteps[i, :self.context_len] + states[i] = self.states[i, :self.context_len] + actions[i] = self.actions[i, :self.context_len] + rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] + else: + if self._atari_env: + timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( + (1, 1), dtype=torch.int64 + ).to(self._device) + else: + timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + if self._basic_discrete_env: + actions = actions.squeeze(-1) + + _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go, prompt=prompt) + del timesteps, states, actions, rewards_to_go + + logits = act_preds[:, -1, :] + if not self._cfg.model.continuous: + if self._atari_env: + probs = F.softmax(logits, dim=-1) + act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device) + for i in data_id: + act[i] = torch.multinomial(probs[i], num_samples=1) + else: + act = torch.argmax(logits, axis=1).unsqueeze(1) + for i in data_id: + self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when exceed max_t + self.t[i] += 1 + + if self._cuda: + act = to_device(act, 'cpu') + output = {'action': act} + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + + + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: + self.task_id[data_id] += 1 + super()._reset_eval(data_id) 
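For reference, a minimal sketch of the prompt tuple layout that PDTPolicy passes into the prompt-conditioned DecisionTransformer.forward added above. The ordering follows the unpacking prompt_states, prompt_actions, prompt_returns_to_go, prompt_timesteps, prompt_attention_mask = prompt in the model code; the batch size, prompt length and observation/action dimensions below are illustrative assumptions, not values taken from the patch.

import torch

B, K = 4, 5                  # assumed batch size and prompt context length
state_dim, act_dim = 17, 6   # assumed observation/action dimensions

# Prompt tuple in the order the model unpacks it:
# (states, actions, returns_to_go, timesteps, attention_mask)
prompt = (
    torch.zeros(B, K, state_dim),          # prompt_states
    torch.zeros(B, K, act_dim),            # prompt_actions (continuous-action case)
    torch.zeros(B, K, 1),                  # prompt_returns_to_go
    torch.zeros(B, K, dtype=torch.long),   # prompt_timesteps
    torch.ones(B, K),                      # prompt_attention_mask, 1 = valid step
)

Each prompt modality is embedded with its own prompt_embed_* layer plus the prompt timestep embedding, interleaved per step analogously to the main returns/state/action stacking, and prepended to the main token sequence (with the prompt attention mask concatenated in front of the main mask) before the transformer blocks.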
\ No newline at end of file diff --git a/ding/torch_utils/network/diffusion.py b/ding/torch_utils/network/diffusion.py index 8dfa9d3a14..da295bd00e 100755 --- a/ding/torch_utils/network/diffusion.py +++ b/ding/torch_utils/network/diffusion.py @@ -35,7 +35,7 @@ def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32): return torch.tensor(betas_clipped, dtype=dtype) -def apply_conditioning(x, conditions, action_dim): +def apply_conditioning(x, conditions, action_dim, mask = None): """ Overview: add condition into x @@ -431,6 +431,7 @@ class TemporalValue(nn.Module): - time_dim (:obj:'): dim of time - dim_mults (:obj:'SequenceType'): mults of dim - kernel_size (:obj:'int'): kernel_size of conv1d + - returns_condition (:obj:'bool'): whether use an additionly condition """ def __init__( @@ -442,6 +443,7 @@ def __init__( out_dim: int = 1, kernel_size: int = 5, dim_mults: SequenceType = [1, 2, 4, 8], + returns_condition: bool = False, ) -> None: super().__init__() dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] @@ -454,6 +456,13 @@ def __init__( nn.Mish(), nn.Linear(dim * 4, dim), ) + if returns_condition: + self.returns_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, dim * 4), + nn.Mish(), + nn.Linear(dim * 4, dim), + ) self.blocks = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(in_out): @@ -488,10 +497,15 @@ def __init__( nn.Linear(fc_dim // 2, out_dim), ) - def forward(self, x, cond, time, *args): + def forward(self, x, cond, time, returns=None, *args): # [batch, horizon, transition ] -> [batch, transition , horizon] x = x.transpose(1, 2) t = self.time_mlp(time) + + if returns: + returns_embed = self.returns_mlp(returns) + t = torch.cat([t, returns_embed], dim=-1) + for resnet, resnet2, downsample in self.blocks: x = resnet(x, t) x = resnet2(x, t) diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 23db0fcdf9..79f92b270c 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -1101,6 +1101,279 @@ def __getitem__(self, idx, eps=1e-4): batch.update(self.get_conditions(observations)) return batch + +@DATASET_REGISTRY.register('meta_traj') +class MetaTraj(Dataset): + def __init__(self, cfg): + dataset_path = cfg.dataset.data_dir_prefix + self.rtg_scale = cfg.dataset.rtg_scale + self.context_len = cfg.dataset.context_len + self.env_type = cfg.dataset.env_type + self.no_state_normalize = cfg.policy.no_state_normalize + self.task_num = cfg.policy.task_num + self.state_dim = cfg.policy.model.obs_shape + self.act_dim = cfg.policy.model.act_shape + self.max_len = cfg.policy.max_len + self.max_ep_len = cfg.policy.max_ep_len + self.batch_size = cfg.policy.batch_size + self.stochastic_prompt = cfg.dataset.stochastic_prompt + self.need_prompt = cfg.dataset.need_prompt + self.task_id = 0 + self.test_id = cfg.dataset.test_id + self.cond = None + if 'cond' in cfg.dataset: + self.cond = cfg.dataset.cond + + try: + import h5py + import collections + except ImportError: + import sys + logging.warning("not found h5py package, please install it trough `pip install h5py ") + sys.exit(1) + + data_ = collections.defaultdict(list) + + file_paths = [dataset_path + i for i in range(1, self.task_num + 1)] + # train_env_dataset + self.traj = [] + self.state_means = [] + self.state_stds = [] + + # test_env_dataset + self.test_traj = [] + self.test_state_means = [] + self.test_state_stds = [] + + # for MetaDiffuser + if self.cond: + self.action_means = [] + self.action_stds = [] + self.test_action_means = [] + 
self.test_action_stds = [] + + # for prompt-DT + if self.need_prompt: + self.returns = [] + self.test_returns = [] + + id = 0 + for file_path in file_paths: + paths = [] + states = [] + if self.cond: + actions = [] + if self.need_prompt: + retruns = [] + total_reward = 0 + with h5py.File(file_path, 'r') as hf: + use_timeouts = False + if 'timeouts' in hf: + use_timeouts = True + N = hf['rewards'].shape[0] + for i in range(N): + done_bool = bool(hf['terminals'][i]) + if use_timeouts: + final_timestep = hf['timeouts'][i] + else: + final_timestep = (episode_step == 1000 - 1) + for k in ['observations', 'actions', 'rewards', 'terminals']: + data_[k].append(hf[k][i]) + if k == 'observations': + states.append[hf[k][i]] + if self.cond and k == 'actions': + actions.append(hf[k][i]) + if self.need_prompt and k == 'rewards': + total_reward += hf[k][i] + if done_bool or final_timestep: + episode_step = 0 + episode_data = {} + for k in data_: + episode_data[k] = np.array(data_[k]) + paths.append(episode_data) + data_ = collections.defaultdict(list) + + if self.need_prompt: + retruns.append(total_reward) + episode_step += 1 + states = np.array(states) + state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 + if self.cond: + action_mean, action_std = np.mean(actions, axis=0), np.std(actions, axis=0) + 1e-6 + + if id not in self.test_id: + self.traj.append(paths) + self.state_means.append(state_mean) + self.state_stds.append(state_std) + if self.cond: + self.action_means.append(action_mean) + self.action_stds.append(action_std) + if self.need_prompt: + self.returns.append(retruns) + else: + self.test_traj.append(paths) + self.test_state_means.append(state_mean) + self.test_state_stds.append(state_std) + if self.cond: + self.test_action_means.append(action_mean) + self.test_action_stds.append(action_std) + if self.need_prompt: + self.test_returns.append(retruns) + + id += 1 + + if self.need_prompt: + self.prompt_trajectories = [] + for i in range(len(self.traj)): + idx = np.argsort(self.returns) # lowest to highest + # set 10% highest traj as prompt + idx = idx[-(len(self.traj[i]) / 20) : ] + + self.prompt_trajectories.append(self.traj[i][idx]) + + self.test_prompt_trajectories = [] + for i in range(len(self.test_traj)): + idx = np.argsort(self.test_returns) + idx = idx[-(len(self.test_traj[i]) / 20) : ] + + self.test_prompt_trajectories.append(self.test_traj[i][idx]) + + + def get_prompt(self, sample_size=1, is_test=False, id=0): + if not is_test: + batch_inds = np.random.choice( + np.arange(len(self.prompt_trajectories[self.task_id])), + size=sample_size, + replace=True, + # p=p_sample, # reweights so we sample according to timesteps + ) + prompt_trajectories = self.prompt_trajectories[id] + sorted_inds = np.argsort(self.returns[id]) + else: + batch_inds = np.random.choice( + np.arange(len(self.test_prompt_trajectories[id])), + size=sample_size, + replace=True, + # p=p_sample, # reweights so we sample according to timesteps + ) + prompt_trajectories = self.test_prompt_trajectories[id] + sorted_inds = np.argsort(self.test_returns[id]) + + if self.stochastic_prompt: + traj = prompt_trajectories[int(batch_inds[i])] # random select traj + else: + traj = prompt_trajectories[int(sorted_inds[-i])] # select the best traj with highest rewards + # traj = prompt_trajectories[i] + si = max(0, traj['rewards'].shape[0] - self.max_len -1) # select the last traj with length max_len + + # get sequences from dataset + s = traj['observations'][si:si + self.max_len] + a = traj['actions'][si:si + 
self.max_len] + r = traj['rewards'][si:si + self.max_len] + + timesteps = np.arange(si, si + self.max_len) + rtg.append(discount_cumsum(traj['rewards'][si:], gamma=1.)[:s.shape[0] + 1]) + if rtg.shape[0] <= s.shape[0]: + rtg = np.concatenate([rtg, np.zeros((1, 1, 1))], axis=1) + + # padding and state + reward normalization + tlen = s.shape[0] + # if tlen !=args.K: + # print('tlen not equal to k') + s = np.concatenate([np.zeros((self.max_len - tlen, self.state_dim)), s], axis=0) + if not self.no_state_normalize: + s = (s - self.state_means[self.task_id]) / self.state_stds[self.task_id] + a = np.concatenate([np.ones((self.max_len - tlen, self.act_dim)) * -10., a], axis=0) + r = np.concatenate([np.zeros((self.max_len - tlen, 1)), r], axis=0) + d = np.concatenate([np.ones((self.max_len - tlen)) * 2, d], axis=0) + rtg = np.concatenate([np.zeros((self.max_len - tlen, 1)), rtg], axis=0) / self.rtg_scale + timesteps = np.concatenate([np.zeros((self.max_len - tlen)), timesteps], axis=0) + mask = np.concatenate([np.zeros((self.max_len - tlen)), np.ones((tlen))], axis=0) + + return s, a, rtg, timesteps, mask + + # set task id + def set_task_id(self, id: int): + self.task_id = id + + def normalize(self, data: np.array, type: str, task_id: int): + if type == 'obs': + return (data - self.test_state_means[task_id]) / self.test_state_stds[task_id] + else: + return (data - self.test_action_means[task_id]) / self.test_action_stds[task_id] + + def unnormalize(self, data: np.array, type: str, task_id: int): + if type == 'obs': + return data * self.test_state_stds[task_id] + self.test_state_means[task_id] + else: + return data * self.test_action_stds[task_id] + self.test_action_means[task_id] + + # get warm start data + def get_pretrain_data(self, task_id: int, batch_size: int): + # get warm data + trajs = self.test_traj[task_id] + batch_idx = np.random.choice( + np.arange(len(trajs)), + size=batch_size, + ) + + traj = trajs[int(batch_idx)] + + si = np.random.randint(0, traj[0]['reward'].shape[0]) + traj = traj[:,si:si + self.max_len,:] + + s = traj['observations'] + a = traj['actions'] + r = traj['rewards'] + + tlen = s.shape[1] + s = np.concatenate([np.zeros((self.max_len - tlen, self.state_dim)), s], axis=1) + if not self.no_state_normalize: + s = (s - self.state_means[self.task_id]) / self.state_stds[self.task_id] + a = np.concatenate([np.ones((self.max_len - tlen, self.act_dim)) * -10., a], axis=1) + r = np.concatenate([np.zeros((self.max_len - tlen, 1)), r], axis=1) + + s = torch.from_numpy(s).to(dtype=torch.float32) + a = torch.from_numpy(a).to(dtype=torch.float32) + r = torch.from_numpy(r).to(dtype=torch.float32) + + cond_id = 0 + cond_val = s[:,0] + return s, a, r, cond_id, cond_val + + def __getitem__(self, index): + traj = self.traj[self.task_id][index] + si = np.random.randint(0, traj['rewards'].shape[0]) + + s = traj['observations'][si:si + self.max_len] + a = traj['actions'][si:si + self.max_len] + r = traj['rewards'][si:si + self.max_len] + timesteps = np.arange(si, si + self.max_len) + rtg = discount_cumsum(traj['rewards'][si:], gamma=1.)[:s.shape[0] + 1] / self.rtg_scale + if rtg.shape[0] <= s.shape[0]: + rtg = np.concatenate([rtg, np.zeros((1, 1, 1))], axis=1) + + tlen = s.shape[0] + s = np.concatenate([np.zeros((self.max_len - tlen, self.state_dim)), s], axis=0) + if not self.no_state_normalize: + s = (s - self.state_means[self.task_id]) / self.state_stds[self.task_id] + a = np.concatenate([np.ones((self.max_len - tlen, self.act_dim)) * -10., a], axis=0) + r = 
np.concatenate([np.zeros((self.max_len - tlen, 1)), r], axis=0) + d = np.concatenate([np.ones((self.max_len - tlen)) * 2, d], axis=0) + rtg = np.concatenate([np.zeros((self.max_len - tlen, 1)), rtg], axis=0) / self.rtg_scale + timesteps = np.concatenate([np.zeros((self.max_len - tlen)), timesteps], axis=0) + + mask = np.concatenate([np.zeros((self.max_len - tlen)), np.ones((tlen))], axis=0) + + if self.need_prompt: + prompt = self.get_prompt() + return prompt, timesteps, s, a, r, rtg, mask + elif self.cond: + cond_id = 0 + cond_val = s[0] + return timesteps, s, a, r, rtg, mask, cond_id, cond_val + else: + return timesteps, s, a, r, rtg, mask def hdf5_save(exp_data, expert_data_path): From 32ccf3f6b3a678802b1b4772b25fa75719d1c7da Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Tue, 2 Jan 2024 16:59:58 +0800 Subject: [PATCH 04/16] add metadiffuser --- ding/policy/meta_diffuser.py | 8 +- ding/utils/data/dataset.py | 7 +- .../interaction_serial_meta_evaluator.py | 75 +++++++++++ .../config/walker2d_metadiffuser_config.py | 116 ++++++++++++++++++ dizoo/meta_mujoco/entry/meta_entry.py | 21 ++++ dizoo/meta_mujoco/envs/meta_env.py | 80 ++++++++++++ 6 files changed, 300 insertions(+), 7 deletions(-) create mode 100644 ding/worker/collector/interaction_serial_meta_evaluator.py create mode 100644 dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py create mode 100644 dizoo/meta_mujoco/entry/meta_entry.py create mode 100755 dizoo/meta_mujoco/envs/meta_env.py diff --git a/ding/policy/meta_diffuser.py b/ding/policy/meta_diffuser.py index 48202e618f..4babbe4701 100755 --- a/ding/policy/meta_diffuser.py +++ b/ding/policy/meta_diffuser.py @@ -39,9 +39,11 @@ class MDPolicy(Policy): # normalizer type normalizer='GaussianNormalizer', model=dict( + dim=32, + obs_dim=17, + action_dim=6, diffuser_cfg=dict( # the type of model - model='TemporalUnet', # config of model model_cfg=dict( # model dim, In GaussianInvDynDiffusion, it is obs_dim. 
In others, it is obs_dim + action_dim @@ -70,8 +72,7 @@ class MDPolicy(Policy): clip_denoised=False, action_weight=10, ), - value_model='ValueDiffusion', - value_model_cfg=dict( + reward_cfg=dict( # the type of model model='TemporalValue', # config of model @@ -96,6 +97,7 @@ class MDPolicy(Policy): clip_denoised=False, action_weight=1.0, ), + horizon=80, # guide_steps for p sample n_guide_steps=2, # scale of grad for p sample diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 79f92b270c..2cea23a1f2 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -1108,14 +1108,13 @@ def __init__(self, cfg): dataset_path = cfg.dataset.data_dir_prefix self.rtg_scale = cfg.dataset.rtg_scale self.context_len = cfg.dataset.context_len - self.env_type = cfg.dataset.env_type self.no_state_normalize = cfg.policy.no_state_normalize self.task_num = cfg.policy.task_num - self.state_dim = cfg.policy.model.obs_shape - self.act_dim = cfg.policy.model.act_shape + self.state_dim = cfg.policy.model.obs_dim + self.act_dim = cfg.policy.model.act_dim self.max_len = cfg.policy.max_len self.max_ep_len = cfg.policy.max_ep_len - self.batch_size = cfg.policy.batch_size + self.batch_size = cfg.policy.learn.batch_size self.stochastic_prompt = cfg.dataset.stochastic_prompt self.need_prompt = cfg.dataset.need_prompt self.task_id = 0 diff --git a/ding/worker/collector/interaction_serial_meta_evaluator.py b/ding/worker/collector/interaction_serial_meta_evaluator.py new file mode 100644 index 0000000000..aa098cb1af --- /dev/null +++ b/ding/worker/collector/interaction_serial_meta_evaluator.py @@ -0,0 +1,75 @@ +from typing import Optional, Callable, Tuple, Dict, List +from collections import namedtuple, defaultdict +import numpy as np +import torch + +from ...envs import BaseEnvManager +from ...envs import BaseEnvManager + +from ding.envs import BaseEnvManager +from ding.torch_utils import to_tensor, to_ndarray, to_item +from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY +from ding.utils import get_world_size, get_rank, broadcast_object_list +from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor +from .interaction_serial_evaluator import InteractionSerialEvaluator + +class InteractionSerialMetaEvaluator(InteractionSerialEvaluator): + """ + Overview: + Interaction serial evaluator class, policy interacts with env. This class evaluator algorithm + with test environment list. + Interfaces: + __init__, reset, reset_policy, reset_env, close, should_eval, eval + Property: + env, policy + """ + config = dict( + # (int) Evaluate every "eval_freq" training iterations. + eval_freq=1000, + render=dict( + # Tensorboard video render is disabled by default. + render_freq=-1, + mode='train_iter', + ), + # (str) File path for visualize environment information. 
+ figure_path=None, + # test env list + test_env_list=None, + ) + + def __init__( + self, + cfg: dict, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'evaluator', + ) -> None: + super()._init_eval(cfg, env, policy, tb_logger, exp_name, instance_name) + self.test_env_num = len(cfg.test_env_list) + + def eval( + self, + save_ckpt_fn: Callable = None, + train_iter: int = -1, + envstep: int = -1, + n_episode: Optional[int] = None, + force_render: bool = False, + policy_kwargs: Optional[Dict] = {}, + ) -> Tuple[bool, Dict[str, List]]: + top_flags, episode_infos = [], defaultdict(list) + for i in range(self.test_env_num): + self._env.reset_task(self._cfg.test_env_list[i]) + top_flag, episode_info = super().eval(save_ckpt_fn, train_iter, envstep, n_episode, \ + force_render, policy_kwargs) + top_flags.append(top_flag) + for key, val in episode_info.items(): + if i == 0: + episode_infos[key] = [] + episode_infos[key].append(val) + + meta_infos = defaultdict(list) + for key, val in episode_infos.items(): + meta_infos[key] = episode_infos[key].mean() + return top_flags, meta_infos \ No newline at end of file diff --git a/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py new file mode 100644 index 0000000000..efba4bf2e0 --- /dev/null +++ b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py @@ -0,0 +1,116 @@ +from easydict import EasyDict + +main_config = dict( + exp_name="walker2d_medium_expert_pd_seed0", + env=dict( + env_id='walker2d-medium-expert-v2', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + returns_scale=1.0, + termination_penalty=-100, + max_path_length=1000, + use_padding=True, + include_returns=True, + normed=False, + stop_value=8000, + horizon=32, + obs_dim=17, + action_dim=6, + ), + policy=dict( + cuda=True, + max_len=1000, + max_ep_len=1000, + task_num=40, + obs_dim=17, + action_dim=6, + model=dict( + diffuser_model_cfg=dict( + model='DiffusionUNet1d', + model_cfg=dict( + transition_dim=23, + dim=32, + dim_mults=[1, 2, 4, 8], + returns_condition=False, + kernel_size=5, + attention=False, + ), + horizon=32, + obs_dim=17, + action_dim=6, + n_timesteps=20, + predict_epsilon=False, + loss_discount=1, + action_weight=10, + ), + reward_cfg=dict( + model='TemporalValue', + model_cfg=dict( + horizon = 32, + transition_dim=23, + dim=32, + dim_mults=[1, 2, 4, 8], + kernel_size=5, + ), + horizon=32, + obs_dim=17, + action_dim=6, + n_timesteps=20, + predict_epsilon=True, + loss_discount=1, + ), + horizon=80, + n_guide_steps=2, + scale=0.1, + t_stopgrad=2, + scale_grad_by_std=True, + ), + normalizer='GaussianNormalizer', + learn=dict( + data_path=None, + train_epoch=60000, + gradient_accumulate_every=2, + batch_size=32, + learning_rate=2e-4, + discount_factor=0.99, + learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )), + ), + collect=dict(data_type='diffuser_traj', ), + eval=dict( + evaluator=dict( + eval_freq=500, + test_env_list=[5,10,22,31,18,1,12,9,25,38], + ), + test_ret=0.9, + ), + dateset=dict( + data_dir_prefix=1, + rtg_scale=1, + context_len=1, + stochastic_prompt=False, + need_prompt=False, + test_id=[5,10,22,31,18,1,12,9,25,38], + cond=True + ), + other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ), + ), +) + +main_config = EasyDict(main_config) +main_config = main_config + +create_config = dict( + env=dict( + 
type='d4rl', + import_names=['dizoo.d4rl.envs.d4rl_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='pd', + ), + replay_buffer=dict(type='naive', ), +) +create_config = EasyDict(create_config) +create_config = create_config \ No newline at end of file diff --git a/dizoo/meta_mujoco/entry/meta_entry.py b/dizoo/meta_mujoco/entry/meta_entry.py new file mode 100644 index 0000000000..d11844521b --- /dev/null +++ b/dizoo/meta_mujoco/entry/meta_entry.py @@ -0,0 +1,21 @@ +from ding.entry import serial_entry_meta_offline +from ding.config import read_config +from pathlib import Path + + +def train(args): + # launch from anywhere + config = Path(__file__).absolute().parent.parent / 'config' / args.config + config = read_config(str(config)) + config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) + serial_entry_meta_offline(config, seed=args.seed) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--seed', '-s', type=int, default=10) + parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py') + args = parser.parse_args() + train(args) \ No newline at end of file diff --git a/dizoo/meta_mujoco/envs/meta_env.py b/dizoo/meta_mujoco/envs/meta_env.py new file mode 100755 index 0000000000..577363dde1 --- /dev/null +++ b/dizoo/meta_mujoco/envs/meta_env.py @@ -0,0 +1,80 @@ +from typing import Any, Union, List +import copy +import gym +from easydict import EasyDict + +from CORRO.environments.make_env import make_env + +from ding.torch_utils import to_ndarray, to_list +from ding.envs import BaseEnv, BaseEnvTimestep +from ding.envs.common.common_function import affine_transform +from ding.utils import ENV_REGISTRY + +@ENV_REGISTRY.register('meta') +class MujocoEnv(BaseEnv): + + def __init__(self, cfg: dict) -> None: + self._init_flag = False + self._use_act_scale = cfg.use_act_scale + self._cfg = cfg + + def reset(self) -> Any: + if not self._init_flag: + self._env = make_env(self._cfg.env_id, 1, seed=self._cfg.seed) + self._env.observation_space.dtype = np.float32 + self._observation_space = self._env.observation_space + self._action_space = self._env.action_space + self._reward_space = gym.spaces.Box( + low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 + ) + self._init_flag = True + obs = self._env.reset() + obs = to_ndarray(obs).astype('float32') + self._eval_episode_return = 0. 
+ return obs + + def close(self) -> None: + if self._init_flag: + self._env.close() + self._init_flag = False + + def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep: + action = to_ndarray(action) + if self._use_act_scale: + action_range = {'min': self.action_space.low[0], 'max': self.action_space.high[0], 'dtype': np.float32} + action = affine_transform(action, min_val=action_range['min'], max_val=action_range['max']) + obs, rew, done, info = self._env.step(action) + self._eval_episode_return += rew + obs = to_ndarray(obs).astype('float32') + rew = to_ndarray([rew]) + if done: + info['eval_episode_return'] = self._eval_episode_return + return BaseEnvTimestep(obs, rew, done, info) + + def __repr__(self) -> str: + return "DI-engine D4RL Env({})".format(self._cfg.env_id) + + @staticmethod + def create_collector_env_cfg(cfg: dict) -> List[dict]: + collector_cfg = copy.deepcopy(cfg) + collector_env_num = collector_cfg.pop('collector_env_num', 1) + return [collector_cfg for _ in range(collector_env_num)] + + @staticmethod + def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + evaluator_cfg = copy.deepcopy(cfg) + evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1) + evaluator_cfg.get('norm_reward', EasyDict(use_norm=False, )).use_norm = False + return [evaluator_cfg for _ in range(evaluator_env_num)] + + @property + def observation_space(self) -> gym.spaces.Space: + return self._observation_space + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + + @property + def reward_space(self) -> gym.spaces.Space: + return self._reward_space \ No newline at end of file From 94648d18b4d2bab8b9c40cafd5148b0654cf76b0 Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Sat, 20 Jan 2024 16:05:29 +0800 Subject: [PATCH 05/16] change --- ding/entry/serial_entry_meta_offline.py | 114 ++++++++ ding/model/template/decision_transformer.py | 13 +- ding/model/template/diffusion.py | 117 ++++---- ding/policy/meta_diffuser.py | 216 +++++++++------ ding/policy/prompt_dt.py | 59 +++- ding/torch_utils/network/diffusion.py | 51 +++- ding/utils/data/dataset.py | 255 ++++++++++-------- .../interaction_serial_meta_evaluator.py | 172 ++++++++++-- .../config/walker2d_metadiffuser_config.py | 67 +++-- .../config/walker2d_promptdt_config.py | 107 ++++++++ dizoo/meta_mujoco/envs/meta_env.py | 18 +- 11 files changed, 873 insertions(+), 316 deletions(-) create mode 100755 ding/entry/serial_entry_meta_offline.py mode change 100644 => 100755 ding/worker/collector/interaction_serial_meta_evaluator.py mode change 100644 => 100755 dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py create mode 100755 dizoo/meta_mujoco/config/walker2d_promptdt_config.py diff --git a/ding/entry/serial_entry_meta_offline.py b/ding/entry/serial_entry_meta_offline.py new file mode 100755 index 0000000000..89bc43a6ed --- /dev/null +++ b/ding/entry/serial_entry_meta_offline.py @@ -0,0 +1,114 @@ +from typing import Union, Optional, List, Any, Tuple +import os +import torch +from functools import partial +from tensorboardX import SummaryWriter +from copy import deepcopy +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialMetaEvaluator +from ding.config import read_config, compile_config +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_world_size, get_rank +from ding.utils.data import 
create_dataset + +def serial_pipeline_meta_offline( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + env_setting: Optional[List[Any]] = None, + model: Optional[torch.nn.Module] = None, + max_train_iter: Optional[int] = int(1e10), +) -> 'Policy': # noqa + """ + Overview: + Serial pipeline entry. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ + ``BaseEnv`` subclass, collector env config, and evaluator env config. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + Returns: + - policy (:obj:`Policy`): Converged policy. + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = deepcopy(input_cfg) + create_cfg.policy.type = create_cfg.policy.type + '_command' + cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) + + cfg.env['seed'] = seed + + # Dataset + dataset = create_dataset(cfg) + + sampler, shuffle = None, True + if get_world_size() > 1: + sampler, shuffle = DistributedSampler(dataset), False + dataloader = DataLoader( + dataset, + # Dividing by get_world_size() here simply to make multigpu + # settings mathmatically equivalent to the singlegpu setting. + # If the training efficiency is the bottleneck, feel free to + # use the original batch size per gpu and increase learning rate + # correspondingly. + cfg.policy.learn.batch_size // get_world_size(), + # cfg.policy.learn.batch_size + shuffle=shuffle, + sampler=sampler, + collate_fn=lambda x: x, + pin_memory=cfg.policy.cuda, + ) + + # Env, policy + env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval']) + + if hasattr(policy, 'set_statistic'): + # useful for setting action bounds for ibc + policy.set_statistic(dataset.statistics) + + if cfg.policy.need_init_dataprocess: + policy.init_dataprocess_func(dataset) + + if get_rank() == 0: + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + else: + tb_logger = None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + evaluator = InteractionSerialMetaEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + evaluator.init_params(dataset.params) + + learner.call_hook('before_run') + stop = False + + for epoch in range(cfg.policy.learn.train_epoch): + if get_world_size() > 1: + dataloader.sampler.set_epoch(epoch) + for i in range(cfg.policy.train_num): + dataset.set_task_id(i) + for train_data in dataloader: + learner.train(train_data) + + # Evaluate policy at most once per epoch. 
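The nested loop above trains round-robin over tasks by pointing the shared dataset at one task at a time with `set_task_id` before iterating the dataloader. A stripped-down sketch of that pattern (illustrative only — far simpler than the MetaTraj dataset added in this patch):

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TaskIndexedDataset(Dataset):
    """Illustrative dataset holding one offline buffer per task.

    set_task_id selects which task __getitem__ samples from, so a single
    DataLoader can be reused for every task in the outer training loop.
    """

    def __init__(self, per_task_data):
        self.per_task_data = per_task_data  # list of tensors, one per task
        self.task_id = 0

    def set_task_id(self, task_id: int) -> None:
        self.task_id = task_id

    def __len__(self):
        return len(self.per_task_data[self.task_id])

    def __getitem__(self, idx):
        return self.per_task_data[self.task_id][idx]


data = [torch.randn(100, 8) + task for task in range(3)]  # 3 fake tasks
dataset = TaskIndexedDataset(data)
loader = DataLoader(dataset, batch_size=16, shuffle=True)
for task_id in range(3):  # mirrors `for i in range(cfg.policy.train_num)`
    dataset.set_task_id(task_id)
    for batch in loader:
        pass  # learner.train(batch) in the real pipeline
```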
+ if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) + + if stop or learner.train_iter >= max_train_iter: + stop = True + break + + learner.call_hook('after_run') + print('final reward is: {}'.format(reward)) + return policy, stop \ No newline at end of file diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py index fb01c9a0b0..4eea2ca01e 100755 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ding.utils import SequenceType +from ding.utils import SequenceType, MODEL_REGISTRY class MaskedCausalAttention(nn.Module): @@ -156,7 +156,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # x = x + self.mlp(self.ln2(x)) return x - +@MODEL_REGISTRY.register('dt') class DecisionTransformer(nn.Module): """ Overview: @@ -303,7 +303,9 @@ def forward( # time embeddings are treated similar to positional embeddings state_embeddings = self.embed_state(states) + time_embeddings action_embeddings = self.embed_action(actions) + time_embeddings - returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings + returns_embeddings = self.embed_rtg(returns_to_go) + returns_embeddings += time_embeddings + # stack rtg, states and actions and reshape sequence as # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) @@ -326,12 +328,15 @@ def forward( prompt_state_embeddings = prompt_state_embeddings + prompt_time_embeddings prompt_action_embeddings = prompt_action_embeddings + prompt_time_embeddings prompt_returns_embeddings = prompt_returns_embeddings + prompt_time_embeddings + prompt_stacked_inputs = torch.stack( + (prompt_returns_embeddings, prompt_state_embeddings, prompt_action_embeddings), dim=1 + ).permute(0, 2, 1, 3).reshape(prompt_states.shape[0], 3 * prompt_seq_length, self.h_dim) prompt_stacked_attention_mask = torch.stack( (prompt_attention_mask, prompt_attention_mask, prompt_attention_mask), dim=1 ).permute(0, 2, 1).reshape(prompt_states.shape[0], 3 * prompt_seq_length) if prompt_stacked_inputs.shape[1] == 3 * T: # if only smaple one prompt - prompt_stacked_inputs = prompt_stacked_inputs.reshape(1, -1, self.hidden_size) + prompt_stacked_inputs = prompt_stacked_inputs.reshape(1, -1, self.h_dim) prompt_stacked_attention_mask = prompt_stacked_attention_mask.reshape(1, -1) h = torch.cat((prompt_stacked_inputs.repeat(B, 1, 1), h), dim=1) stacked_attention_mask = torch.cat((prompt_stacked_attention_mask.repeat(B, 1), stacked_attention_mask), dim=1) diff --git a/ding/model/template/diffusion.py b/ding/model/template/diffusion.py index 934e3d7c73..cc6479f79c 100755 --- a/ding/model/template/diffusion.py +++ b/ding/model/template/diffusion.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType from ding.torch_utils.network.diffusion import extract, cosine_beta_schedule, apply_conditioning, \ - DiffusionUNet1d, TemporalValue + DiffusionUNet1d, TemporalValue, Mish Sample = namedtuple('Sample', 'trajectories values chains') @@ -26,10 +26,16 @@ def default_sample_fn(model, x, cond, t): return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, values -def get_guide_output(guide, x, cond, t, returns=None): +def get_guide_output(guide, x, cond, t, returns=None, is_dynamic=False): x.requires_grad_() - if returns: - y = guide(x, cond, t, returns).squeeze(dim=-1) + if 
returns is not None: + if not is_dynamic: + y = guide(x, cond, t, returns).squeeze(dim=-1) + else: + returns = returns.unsqueeze(1).repeat_interleave(x.shape[1],dim=1) + input = torch.cat([x, returns], dim=-1) + input = input.reshape(-1, input.shape[-1]) + y = guide(input) else: y = guide(x, cond, t).squeeze(dim=-1) grad = torch.autograd.grad([y.sum()], [x])[0] @@ -94,7 +100,7 @@ def free_guidance_sample( for _ in range(n_guide_steps): with torch.enable_grad(): y1, grad1 = get_guide_output(guide1, x, cond, t, returns) # get reward - y2, grad2 = get_guide_output(guide2, x, cond, t) # get state + y2, grad2 = get_guide_output(guide2, x, cond, t, returns, is_dynamic=True) # get state grad = grad1 + scale * grad2 if scale_grad_by_std: @@ -102,7 +108,7 @@ def free_guidance_sample( grad[t < t_stopgrad] = 0 - if returns: + if returns is not None: # epsilon could be epsilon or x0 itself epsilon_cond = model.model(x, cond, t, returns, use_dropout=False) epsilon_uncond = model.model(x, cond, t, returns, force_dropout=True) @@ -115,7 +121,7 @@ def free_guidance_sample( noise = torch.randn_like(x) noise[t == 0] = 0 - return model_mean + model_std * noise, + return model_mean + model_std * noise class GaussianDiffusion(nn.Module): """ @@ -614,7 +620,7 @@ def p_sample_loop(self, shape, cond, returns=None, verbose=True, return_diffusio batch_size = shape[0] x = 0.5 * torch.randn(shape, device=device) # In this model, init state must be given by the env and without noise. - x = apply_conditioning(x, cond, 0) + x = apply_conditioning(x, cond, self.action_dim) if return_diffusion: diffusion = [x] @@ -622,7 +628,7 @@ def p_sample_loop(self, shape, cond, returns=None, verbose=True, return_diffusio for i in reversed(range(0, self.n_timesteps)): timesteps = torch.full((batch_size, ), i, device=device, dtype=torch.long) x = self.p_sample(x, cond, timesteps, returns) - x = apply_conditioning(x, cond, 0) + x = apply_conditioning(x, cond, self.action_dim) if return_diffusion: diffusion.append(x) @@ -670,12 +676,12 @@ def p_losses(self, x_start, cond, t, returns=None): noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - x_noisy = apply_conditioning(x_noisy, cond, 0) + x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) x_recon = self.model(x_noisy, cond, t, returns) if not self.predict_epsilon: - x_recon = apply_conditioning(x_recon, cond, 0) + x_recon = apply_conditioning(x_recon, cond, self.action_dim) assert noise.shape == x_recon.shape @@ -693,6 +699,20 @@ def forward(self, cond, *args, **kwargs): class GuidenceFreeDifffuser(GaussianInvDynDiffusion): + def get_loss_weights(self, discount: int): + self.action_weight = 1 + dim_weights = torch.ones(self.transition_dim, dtype=torch.float32) + + # decay loss with trajectory timestep: discount**t + discounts = discount ** torch.arange(self.horizon, dtype=torch.float) + discounts = discounts / discounts.mean() + loss_weights = torch.einsum('h,t->ht', discounts, dim_weights) + # Cause things are conditioned on t=0 + if self.predict_epsilon: + loss_weights[0, :] = 0 + + return loss_weights + def p_mean_variance(self, x, cond, t, epsilon): x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon) @@ -724,24 +744,25 @@ def conditional_sample(self, cond, horizon=None, **sample_kwargs): device = self.betas.device batch_size = len(cond[0]) horizon = horizon or self.horizon - shape = (batch_size, horizon, self.obs_dim) + shape = (batch_size, horizon, self.obs_dim + self.action_dim) return self.p_sample_loop(shape, cond, 
**sample_kwargs) def p_losses(self, x_start, cond, t, returns=None): noise = torch.randn_like(x_start) + batch_size = len(cond[0]) - mask_rand = torch.rand([batch_size]) - mask = torch.bernoulli(mask_rand, 0.7) + mask_rand = torch.rand((batch_size,1)) + mask = torch.bernoulli(mask_rand, 0.7).to(returns.device) returns = returns * mask x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - x_noisy = apply_conditioning(x_noisy, cond, 0) + x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) x_recon = self.model(x_noisy, cond, t, returns) if not self.predict_epsilon: - x_recon = apply_conditioning(x_recon, cond, 0) + x_recon = apply_conditioning(x_recon, cond, self.action_dim) assert noise.shape == x_recon.shape @@ -765,11 +786,12 @@ def __init__( obs_dim: Union[int, SequenceType], action_dim: Union[int, SequenceType], reward_cfg: dict, - diffuser_cfg: dict, + diffuser_model_cfg: dict, horizon: int, **sample_kwargs, ): - + super().__init__() + self.obs_dim = obs_dim self.action_dim = action_dim self.horizon = horizon @@ -777,13 +799,14 @@ def __init__( self.embed = nn.Sequential( nn.Linear((obs_dim * 2 + action_dim + 1) * horizon, dim * 4), - nn.Mish(), - nn.Linear(dim * 4, dim * 4), - nn.Mish(), + #nn.Mish(), + Mish(), nn.Linear(dim * 4, dim * 4), - nn.Mish(), + #nn.Mish(), + Mish(), nn.Linear(dim * 4, dim * 4), - nn.Mish(), + #nn.Mish(), + Mish(), nn.Linear(dim * 4, dim) ) @@ -794,52 +817,50 @@ def __init__( nn.ReLU(), nn.Linear(200, 200), nn.ReLU(), - nn.Linear(200, 200), - nn.ReLU(), - nn.Linear(200, 200), - nn.ReLU(), - nn.Linear(200, 200), - nn.ReLU(), nn.Linear(200, obs_dim), ) - self.diffuser = GuidenceFreeDifffuser(**diffuser_cfg) + self.diffuser = GuidenceFreeDifffuser(**diffuser_model_cfg) - def diffuser_loss(self, x_start, cond, t): - return self.diffuser.p_losses(x_start, cond, t) + def get_task_id(self, traj): + input_emb = traj.reshape(traj.shape[0], -1) + task_idx = self.embed(input_emb) + return task_idx + + def diffuser_loss(self, x_start, cond, t, returns=None): + return self.diffuser.p_losses(x_start, cond, t, returns) def pre_train_loss(self, traj, target, t, cond): input_emb = traj.reshape(target.shape[0], -1) task_idx = self.embed(input_emb) - states = traj[:,:, self.action_dim:self.action_dim + self.obs_dim] + states = traj[:, :, self.action_dim:self.action_dim + self.obs_dim] actions = traj[:, :, :self.action_dim] input = torch.cat([actions, states], dim=-1) - target_reward = target[:,-1] + target_reward = target[:, :, -1] - target_next_state = target[:, :-1] - task_idxs = torch.full(states.shape[:-1], task_idx, device=task_idx.device, dtype=torch.long) + target_next_state = target[:, :, :self.obs_dim].reshape(-1, self.obs_dim) - reward_loss, reward_log = self.reward_model.p_losses(input, cond, target_reward, t, task_idxs) + reward_loss, reward_log = self.reward_model.p_losses(input, cond, target_reward, t, task_idx) - - n = states.shape[1] + task_idxs = task_idx.unsqueeze(1).repeat_interleave(self.horizon,dim=1) + + input = torch.cat([input, task_idxs], dim=-1) + input = input.reshape(-1, input.shape[-1]) - state_loss = 0 - for i in range(n): - next_state = self.dynamic_model(input) - state_loss += F.mse_loss(next_state, target_next_state, reduction='none').mean() - state_loss /= n - return state_loss, reward_loss + next_state = self.dynamic_model(input) + state_loss = F.mse_loss(next_state, target_next_state, reduction='none').mean() + + return state_loss, reward_loss, reward_log def get_eval(self, cond, id, batch_size = 1): if batch_size > 1: cond = 
self.repeat_cond(cond, batch_size) - + id = torch.stack(id, dim=0) samples = self.diffuser(cond, returns=id, sample_fn=free_guidance_sample, plan_size=batch_size, - guide1=self.reward_model, guide2=self.dynamic_model **self.sample_kwargs) - return samples[:, 0, :,self.action_dim] + guide1=self.reward_model, guide2=self.dynamic_model, **self.sample_kwargs) + return samples[:, 0, :self.action_dim] def repeat_cond(self, cond, batch_size): for k, v in cond.items(): diff --git a/ding/policy/meta_diffuser.py b/ding/policy/meta_diffuser.py index 4babbe4701..af01363b64 100755 --- a/ding/policy/meta_diffuser.py +++ b/ding/policy/meta_diffuser.py @@ -39,10 +39,10 @@ class MDPolicy(Policy): # normalizer type normalizer='GaussianNormalizer', model=dict( - dim=32, + dim=64, obs_dim=17, action_dim=6, - diffuser_cfg=dict( + diffuser_model_cfg=dict( # the type of model # config of model model_cfg=dict( @@ -70,7 +70,6 @@ class MDPolicy(Policy): loss_discount=1.0, # whether clip denoise clip_denoised=False, - action_weight=10, ), reward_cfg=dict( # the type of model @@ -138,8 +137,6 @@ class MDPolicy(Policy): gradient_accumulate_every=2, # train_epoch = train_epoch * gradient_accumulate_every train_epoch=60000, - # batch_size of every env when eval - plan_batch_size=64, # step start update target model and frequence step_start_update_target=2000, @@ -157,7 +154,7 @@ class MDPolicy(Policy): ) def default_model(self) -> Tuple[str, List[str]]: - return 'md', ['ding.model.template.diffusion'] + return 'metadiffuser', ['ding.model.template.diffusion'] def _init_learn(self) -> None: r""" @@ -172,7 +169,6 @@ def _init_learn(self) -> None: self.obs_dim = self._cfg.model.diffuser_model_cfg.obs_dim self.n_timesteps = self._cfg.model.diffuser_model_cfg.n_timesteps self.gradient_accumulate_every = self._cfg.learn.gradient_accumulate_every - self.plan_batch_size = self._cfg.learn.plan_batch_size self.gradient_steps = 1 self.update_target_freq = self._cfg.learn.update_target_freq self.step_start_update_target = self._cfg.learn.step_start_update_target @@ -182,6 +178,9 @@ def _init_learn(self) -> None: self.include_returns = self._cfg.learn.include_returns self.eval_batch_size = self._cfg.learn.eval_batch_size self.warm_batch_size = self._cfg.learn.warm_batch_size + self.test_num = self._cfg.learn.test_num + self.have_train = False + self._forward_learn_cnt = 0 self._plan_optimizer = Adam( self._model.diffuser.model.parameters(), @@ -202,38 +201,69 @@ def _init_learn(self) -> None: self._learn_model.reset() def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: + self.have_train = True loss_dict = {} if self._cuda: data = to_device(data, self._device) - timesteps, obs, acts, rewards, rtg, masks, cond_id, cond_vals = data - obs, next_obs = obs[:-1], obs[1:] - acts = acts[:-1] - rewards = rewards[:-1] - conds = {cond_id: cond_vals} + + obs, acts, rewards, cond_ids, cond_vals = [], [], [], [], [] + for d in data: + timesteps, ob, act, reward, rtg, masks, cond_id, cond_val = d + obs.append(ob) + acts.append(act) + rewards.append(reward) + cond_ids.append(cond_id) + cond_vals.append(cond_val) + + obs = torch.stack(obs, dim=0) + acts = torch.stack(acts, dim=0) + rewards = torch.stack(rewards, dim=0) + cond_vals = torch.stack(cond_vals, dim=0) + + obs, next_obs = obs[:,:-1], obs[:,1:] + acts = acts[:,:-1] + rewards = rewards[:,:-1] + conds = {cond_ids[0]: cond_vals} self._learn_model.train() - pre_traj = torch.cat([acts, obs, rewards, next_obs], dim=1) - target = torch.cat([next_obs, rewards], dim=1) - traj = 
torch.cat([acts, obs], dim=1) + pre_traj = torch.cat([acts, obs, rewards, next_obs], dim=-1).to(self._device) + target = torch.cat([next_obs, rewards], dim=-1).to(self._device) + traj = torch.cat([acts, obs], dim=-1).to(self._device) batch_size = len(traj) t = torch.randint(0, self.n_timesteps, (batch_size, ), device=traj.device).long() - state_loss, reward_loss = self._learn_model.pre_train_loss(pre_traj, target, t, conds) - loss_dict = {'state_loss': state_loss, 'reward_loss': reward_loss} - total_loss = state_loss + reward_loss - - self._pre_train_optimizer.zero() + state_loss, reward_loss, reward_log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + loss_dict = {'dynamic_loss': state_loss, 'reward_loss': reward_loss} + total_loss = (state_loss + reward_loss) / self.gradient_accumulate_every total_loss.backward() - self._pre_train_optimizer.step() - self.update_model_average(self._target_model, self._learn_model) + + if self.gradient_steps >= self.gradient_accumulate_every: + self._pre_train_optimizer.step() + self._pre_train_optimizer.zero_grad() + + task_id = self._learn_model.get_task_id(pre_traj) - diffuser_loss = self._learn_model.diffuser_loss(traj, conds, t) - self._plan_optimizer.zero() + diffuser_loss, a0_loss = self._learn_model.diffuser_loss(traj, conds, t, task_id) + loss_dict['diffuser_loss'] = diffuser_loss + loss_dict['a0_loss'] = a0_loss + diffuser_loss = diffuser_loss / self.gradient_accumulate_every diffuser_loss.backward() - self._plan_optimizer.step() - self.update_model_average(self._target_model, self._learn_model) + + if self.gradient_steps >= self.gradient_accumulate_every: + self._plan_optimizer.step() + self._plan_optimizer.zero_grad() + self.gradient_steps = 1 + else: + self.gradient_steps += 1 + + self._forward_learn_cnt += 1 + if self._forward_learn_cnt % self.update_target_freq == 0: + if self._forward_learn_cnt < self.step_start_update_target: + self._target_model.load_state_dict(self._model.state_dict()) + else: + self.update_model_average(self._target_model, self._learn_model) return loss_dict @@ -252,12 +282,9 @@ def init_dataprocess_func(self, dataloader: torch.utils.data.Dataset): def _monitor_vars_learn(self) -> List[str]: return [ - 'diffuse_loss', + 'diffuser_loss', 'reward_loss', 'dynamic_loss', - 'max_return', - 'min_return', - 'mean_return', 'a0_loss', ] @@ -272,31 +299,33 @@ def _state_dict_learn(self) -> Dict[str, Any]: def _init_eval(self): self._eval_model = model_wrap(self._target_model, wrapper_name='base') self._eval_model.reset() - self.task_id = [0] * self.eval_batch_size + self.task_id = None + self.test_task_id = [[] for _ in range(self.eval_batch_size)] + # self.task_id = [0] * self.eval_batch_size - obs, acts, rewards, cond_ids, cond_vals = \ - self.dataloader.get_pretrain_data(self.task_id[0], self.warm_batch_size * self.eval_batch_size) - obs = to_device(obs, self._device) - acts = to_device(acts, self._device) - rewards = to_device(rewards, self._device) - cond_vals = to_device(cond_vals, self._device) + # obs, acts, rewards, cond_ids, cond_vals = \ + # self.dataloader.get_pretrain_data(self.task_id[0], self.warm_batch_size * self.eval_batch_size) + # obs = to_device(obs, self._device) + # acts = to_device(acts, self._device) + # rewards = to_device(rewards, self._device) + # cond_vals = to_device(cond_vals, self._device) - obs, next_obs = obs[:-1], obs[1:] - acts = acts[:-1] - rewards = rewards[:-1] - pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=1) - target = torch.cat([next_obs, rewards], dim=1) - 
batch_size = len(pre_traj) - conds = {cond_ids: cond_vals} - - t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() - state_loss, reward_loss = self._learn_model.pre_train_loss(pre_traj, target, t, conds) - total_loss = state_loss + reward_loss - self._pre_train_optimizer.zero() - total_loss.backward() - self._pre_train_optimizer.step() - self.update_model_average(self._target_model, self._learn_model) + # obs, next_obs = obs[:-1], obs[1:] + # acts = acts[:-1] + # rewards = rewards[:-1] + # pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=1) + # target = torch.cat([next_obs, rewards], dim=1) + # batch_size = len(pre_traj) + # conds = {cond_ids: cond_vals} + + # t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() + # state_loss, reward_loss = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + # total_loss = state_loss + reward_loss + # self._pre_train_optimizer.zero() + # total_loss.backward() + # self._pre_train_optimizer.step() + # self.update_model_average(self._target_model, self._learn_model) def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: data_id = list(data.keys()) @@ -305,45 +334,70 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: self._eval_model.eval() obs = [] for i in range(self.eval_batch_size): - obs.append(self.dataloader.normalize(data, 'observations', self.task_id[i])) - + if not self._cfg.no_state_normalize: + obs.append(self.dataloader.normalize(data[i], 'obs', self.task_id[i])) + with torch.no_grad(): - obs = torch.tensor(obs) + obs = torch.stack(obs, dim=0) if self._cuda: obs = to_device(obs, self._device) conditions = {0: obs} - action = self._eval_model.get_eval(conditions, self.plan_batch_size) + action = self._eval_model.get_eval(conditions, self.test_task_id) if self._cuda: action = to_device(action, 'cpu') for i in range(self.eval_batch_size): - action[i] = self.dataloader.unnormalize(action, 'actions', self.task_id[i]) + if not self._cfg.no_action_normalize: + action[i] = self.dataloader.unnormalize(action[i], 'actions', self.task_id[i]) action = torch.tensor(action).to('cpu') output = {'action': action} output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: - self.task_id[data_id] += 1 - - obs, acts, rewards, cond_ids, cond_vals = \ - self.dataloader.get_pretrain_data(self.task_id[data_id], self.warm_batch_size) - obs = to_device(obs, self._device) - acts = to_device(acts, self._device) - rewards = to_device(rewards, self._device) - cond_vals = to_device(cond_vals, self._device) - - obs, next_obs = obs[:-1], obs[1:] - acts = acts[:-1] - rewards = rewards[:-1] - pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=1) - target = torch.cat([next_obs, rewards], dim=1) - batch_size = len(pre_traj) - conds = {cond_ids: cond_vals} - - t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() - state_loss, reward_loss = self._learn_model.pre_train_loss(pre_traj, target, t, conds) - total_loss = state_loss + reward_loss - self._pre_train_optimizer.zero() - total_loss.backward() - self._pre_train_optimizer.step() - self.update_model_average(self._target_model, self._learn_model) \ No newline at end of file + if self.have_train: + if data_id is None: + data_id = list(range(self.eval_batch_size)) + if self.task_id is not None: + for id in data_id: + self.task_id[id] = (self.task_id[id] + 1) % self.test_num + else: + self.task_id 
= [0] * self.eval_batch_size + + for id in data_id: + obs, acts, rewards, cond_ids, cond_vals = \ + self.dataloader.get_pretrain_data(self.task_id[id], self.warm_batch_size) + obs = to_device(obs, self._device) + acts = to_device(acts, self._device) + rewards = to_device(rewards, self._device) + cond_vals = to_device(cond_vals, self._device) + + obs, next_obs = obs[:, :-1], obs[:, 1:] + acts = acts[:, :-1] + rewards = rewards[:, :-1] + + pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=-1) + target = torch.cat([next_obs, rewards], dim=-1) + batch_size = len(pre_traj) + conds = {cond_ids: cond_vals} + + t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() + state_loss, reward_loss, log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + total_loss = state_loss + reward_loss + self._pre_train_optimizer.zero_grad() + total_loss.backward() + self._pre_train_optimizer.step() + self.update_model_average(self._target_model, self._learn_model) + + self.test_task_id[id] = self._target_model.get_task_id(pre_traj)[0] + + def _init_collect(self) -> None: + pass + + def _forward_collect(self, data: dict, **kwargs) -> dict: + pass + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + pass + + def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + pass diff --git a/ding/policy/prompt_dt.py b/ding/policy/prompt_dt.py index 32dee63659..91dd4dd8d6 100755 --- a/ding/policy/prompt_dt.py +++ b/ding/policy/prompt_dt.py @@ -7,6 +7,7 @@ from ding.utils import POLICY_REGISTRY from ding.utils.data import default_decollate from ding.policy.dt import DTPolicy +from ding.model import model_wrap @POLICY_REGISTRY.register('promptdt') class PDTPolicy(DTPolicy): @@ -15,6 +16,9 @@ class PDTPolicy(DTPolicy): Policy class of Decision Transformer algorithm in discrete environments. Paper link: https://arxiv.org/pdf/2206.13499. 
""" + def default_model(self) -> Tuple[str, List[str]]: + return 'dt', ['ding.model.template.decision_transformer'] + def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: """ Overview: @@ -37,8 +41,41 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: """ self._learn_model.train() + self.have_train = True - prompt, timesteps, states, actions, rewards, returns_to_go, traj_mask = data + if self._cuda: + data = to_device(data, self._device) + + p_s, p_a, p_rtg, p_t, p_mask, timesteps, states, actions, rewards, returns_to_go, \ + traj_mask = [], [], [], [], [], [], [], [], [], [], [] + + for d in data: + p, timestep, s, a, r, rtg, mask = d + timesteps.append(timestep) + states.append(s) + actions.append(a) + rewards.append(r) + returns_to_go.append(rtg) + traj_mask.append(mask) + ps, pa, prtg, pt, pm = p + p_s.append(ps) + p_a.append(pa) + p_rtg.append(prtg) + p_mask.append(pm) + p_t.append(pt) + + timesteps = torch.stack(timesteps, dim=0) + states = torch.stack(states, dim=0) + actions = torch.stack(actions, dim=0) + rewards = torch.stack(rewards, dim=0) + returns_to_go = torch.stack(returns_to_go, dim=0) + traj_mask = torch.stack(traj_mask, dim=0) + p_s = torch.stack(p_s, dim=0) + p_a = torch.stack(p_a, dim=0) + p_rtg = torch.stack(p_rtg, dim=0) + p_mask = torch.stack(p_mask, dim=0) + p_t = torch.stack(p_t, dim=0) + prompt = (p_s, p_a, p_rtg, p_t, p_mask) # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), # and we need a 3-dim tensor @@ -88,12 +125,16 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: 'total_loss': action_loss.detach().cpu().item(), } - def get_dataloader(self, dataloader): + def init_dataprocess_func(self, dataloader): self.dataloader = dataloader def _init_eval(self) -> None: - self.task_id = [0] * self.eval_batch_size - super()._init_eval() + self.test_num = self._cfg.learn.test_num + self._eval_model = self._model + self.eval_batch_size = self._cfg.evaluator_env_num + self.task_id = None + self.test_task_id = [[] for _ in range(self.eval_batch_size)] + self.have_train =False def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: prompt = [] @@ -188,5 +229,11 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: - self.task_id[data_id] += 1 - super()._reset_eval(data_id) \ No newline at end of file + if self.have_train: + if data_id is None: + data_id = list(range(self.eval_batch_size)) + if self.task_id is not None: + for id in data_id: + self.task_id[id] = (self.task_id[id] + 1) % self.test_num + else: + self.task_id = [0] * self.eval_batch_size \ No newline at end of file diff --git a/ding/torch_utils/network/diffusion.py b/ding/torch_utils/network/diffusion.py index da295bd00e..674e8e5b76 100755 --- a/ding/torch_utils/network/diffusion.py +++ b/ding/torch_utils/network/diffusion.py @@ -44,6 +44,13 @@ def apply_conditioning(x, conditions, action_dim, mask = None): x[:, t, action_dim:] = val.clone() return x +class Mish(nn.Module): + def __init__(self): + super().__init__() + + def forward(self,x): + x = x * (torch.tanh(F.softplus(x))) + return x class DiffusionConv1d(nn.Module): @@ -197,7 +204,7 @@ def __init__( ) -> None: super().__init__() if mish: - act = nn.Mish() + act = Mish()#nn.Mish() else: act = nn.SiLU() self.blocks = nn.ModuleList( @@ -214,7 +221,8 @@ def __init__( if in_channels != out_channels else nn.Identity() def forward(self, x, t): - out = self.blocks[0](x) + 
self.time_mlp(t).unsqueeze(-1) + out = self.blocks[0](x) + out += self.time_mlp(t).unsqueeze(-1) out = self.blocks[1](out) return out + self.residual_conv(x) @@ -254,7 +262,7 @@ def __init__( act = nn.SiLU() else: mish = True - act = nn.Mish() + act = Mish()#nn.Mish() self.time_dim = dim self.returns_dim = dim @@ -272,8 +280,6 @@ def __init__( if self.returns_condition: self.returns_mlp = nn.Sequential( - nn.Linear(1, dim), - act, nn.Linear(dim, dim * 4), act, nn.Linear(dim * 4, dim), @@ -444,6 +450,7 @@ def __init__( kernel_size: int = 5, dim_mults: SequenceType = [1, 2, 4, 8], returns_condition: bool = False, + no_need_ret_sin: bool =False, ) -> None: super().__init__() dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] @@ -453,16 +460,27 @@ def __init__( self.time_mlp = nn.Sequential( SinusoidalPosEmb(dim), nn.Linear(dim, dim * 4), - nn.Mish(), + #nn.Mish(), + Mish(), nn.Linear(dim * 4, dim), ) if returns_condition: - self.returns_mlp = nn.Sequential( - SinusoidalPosEmb(dim), - nn.Linear(dim, dim * 4), - nn.Mish(), - nn.Linear(dim * 4, dim), - ) + time_dim += time_dim + if not no_need_ret_sin: + self.returns_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, dim * 4), + #nn.Mish(), + Mish(), + nn.Linear(dim * 4, dim), + ) + else: + self.returns_mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + #nn.Mish(), + Mish(), + nn.Linear(dim * 4, dim), + ) self.blocks = nn.ModuleList([]) for ind, (dim_in, dim_out) in enumerate(in_out): @@ -493,7 +511,8 @@ def __init__( fc_dim = mid_dim_3 * max(horizon, 1) self.final_block = nn.Sequential( nn.Linear(fc_dim + time_dim, fc_dim // 2), - nn.Mish(), + #nn.Mish(), + Mish(), nn.Linear(fc_dim // 2, out_dim), ) @@ -502,14 +521,18 @@ def forward(self, x, cond, time, returns=None, *args): x = x.transpose(1, 2) t = self.time_mlp(time) - if returns: + if returns is not None: returns_embed = self.returns_mlp(returns) t = torch.cat([t, returns_embed], dim=-1) for resnet, resnet2, downsample in self.blocks: + # print('x:',x) x = resnet(x, t) + # print('after res1 x:',x) x = resnet2(x, t) + # print('after res2 x:',x) x = downsample(x) + # print('after down x:',x) x = self.mid_block1(x, t) x = self.mid_down1(x) diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 2cea23a1f2..58c19a71f7 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -1106,19 +1106,21 @@ def __getitem__(self, idx, eps=1e-4): class MetaTraj(Dataset): def __init__(self, cfg): dataset_path = cfg.dataset.data_dir_prefix + env_param_path = cfg.dataset.env_param_path self.rtg_scale = cfg.dataset.rtg_scale self.context_len = cfg.dataset.context_len self.no_state_normalize = cfg.policy.no_state_normalize + self.no_action_normalize = cfg.policy.no_action_normalize self.task_num = cfg.policy.task_num - self.state_dim = cfg.policy.model.obs_dim - self.act_dim = cfg.policy.model.act_dim + self.state_dim = cfg.policy.obs_dim + self.act_dim = cfg.policy.act_dim self.max_len = cfg.policy.max_len self.max_ep_len = cfg.policy.max_ep_len self.batch_size = cfg.policy.learn.batch_size self.stochastic_prompt = cfg.dataset.stochastic_prompt self.need_prompt = cfg.dataset.need_prompt - self.task_id = 0 self.test_id = cfg.dataset.test_id + self.need_next_obs = cfg.dataset.need_next_obs self.cond = None if 'cond' in cfg.dataset: self.cond = cfg.dataset.cond @@ -1130,10 +1132,10 @@ def __init__(self, cfg): import sys logging.warning("not found h5py package, please install it trough `pip install h5py ") sys.exit(1) - data_ = collections.defaultdict(list) - 
file_paths = [dataset_path + i for i in range(1, self.task_num + 1)] + file_paths = [dataset_path + '_{}_sub_task_0.hdf5'.format(i) for i in range(0, self.task_num)] + param_paths = [env_param_path + '{}.pkl'.format(i) for i in self.test_id] # train_env_dataset self.traj = [] self.state_means = [] @@ -1145,7 +1147,7 @@ def __init__(self, cfg): self.test_state_stds = [] # for MetaDiffuser - if self.cond: + if not self.no_action_normalize: self.action_means = [] self.action_stds = [] self.test_action_means = [] @@ -1155,88 +1157,71 @@ def __init__(self, cfg): if self.need_prompt: self.returns = [] self.test_returns = [] - - id = 0 + + id = 0 for file_path in file_paths: - paths = [] - states = [] - if self.cond: - actions = [] if self.need_prompt: - retruns = [] - total_reward = 0 + returns = [] with h5py.File(file_path, 'r') as hf: - use_timeouts = False - if 'timeouts' in hf: - use_timeouts = True N = hf['rewards'].shape[0] + path = [] for i in range(N): - done_bool = bool(hf['terminals'][i]) - if use_timeouts: - final_timestep = hf['timeouts'][i] - else: - final_timestep = (episode_step == 1000 - 1) - for k in ['observations', 'actions', 'rewards', 'terminals']: + for k in ['obs', 'actions', 'rewards', 'terminals', 'mask']: data_[k].append(hf[k][i]) - if k == 'observations': - states.append[hf[k][i]] - if self.cond and k == 'actions': - actions.append(hf[k][i]) - if self.need_prompt and k == 'rewards': - total_reward += hf[k][i] - if done_bool or final_timestep: - episode_step = 0 - episode_data = {} - for k in data_: - episode_data[k] = np.array(data_[k]) - paths.append(episode_data) - data_ = collections.defaultdict(list) - - if self.need_prompt: - retruns.append(total_reward) - episode_step += 1 - states = np.array(states) - state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 - if self.cond: - action_mean, action_std = np.mean(actions, axis=0), np.std(actions, axis=0) + 1e-6 - - if id not in self.test_id: - self.traj.append(paths) - self.state_means.append(state_mean) - self.state_stds.append(state_std) - if self.cond: - self.action_means.append(action_mean) - self.action_stds.append(action_std) - if self.need_prompt: - self.returns.append(retruns) - else: - self.test_traj.append(paths) - self.test_state_means.append(state_mean) - self.test_state_stds.append(state_std) - if self.cond: - self.test_action_means.append(action_mean) - self.test_action_stds.append(action_std) + path.append(data_) if self.need_prompt: - self.test_returns.append(retruns) - - id += 1 + returns.append(hf['returns'][0][i]) + data_ = collections.defaultdict(list) + + state_mean, state_std = hf['state_mean'][:], hf['state_std'][:] + if not self.no_action_normalize: + action_mean, action_std = hf['action_mean'][:], hf['action_std'][:] + + if id not in self.test_id: + self.traj.append(path) + self.state_means.append(state_mean) + self.state_stds.append(state_std) + if not self.no_action_normalize: + self.action_means.append(action_mean) + self.action_stds.append(action_std) + if self.need_prompt: + self.returns.append(returns) + else: + self.test_traj.append(path) + self.test_state_means.append(state_mean) + self.test_state_stds.append(state_std) + if not self.no_action_normalize: + self.test_action_means.append(action_mean) + self.test_action_stds.append(action_std) + if self.need_prompt: + self.test_returns.append(returns) + id += 1 + + self.params = [] + for file in param_paths: + with open(file, 'rb') as f: + self.params.append(pickle.load(f)[0]) if self.need_prompt: 
self.prompt_trajectories = [] for i in range(len(self.traj)): idx = np.argsort(self.returns) # lowest to highest # set 10% highest traj as prompt - idx = idx[-(len(self.traj[i]) / 20) : ] + idx = idx[-(len(self.traj[i]) // 20) : ] - self.prompt_trajectories.append(self.traj[i][idx]) + self.prompt_trajectories.append(np.array(self.traj[i])[idx]) self.test_prompt_trajectories = [] for i in range(len(self.test_traj)): idx = np.argsort(self.test_returns) - idx = idx[-(len(self.test_traj[i]) / 20) : ] + idx = idx[-(len(self.test_traj[i]) // 20) : ] - self.test_prompt_trajectories.append(self.test_traj[i][idx]) + self.test_prompt_trajectories.append(np.array(self.test_traj[i])[idx]) + + self.set_task_id(0) + def __len__(self): + return len(self.traj[self.task_id]) def get_prompt(self, sample_size=1, is_test=False, id=0): if not is_test: @@ -1259,35 +1244,39 @@ def get_prompt(self, sample_size=1, is_test=False, id=0): sorted_inds = np.argsort(self.test_returns[id]) if self.stochastic_prompt: - traj = prompt_trajectories[int(batch_inds[i])] # random select traj + traj = prompt_trajectories[batch_inds[sample_size]][0] # random select traj else: - traj = prompt_trajectories[int(sorted_inds[-i])] # select the best traj with highest rewards + traj = prompt_trajectories[sorted_inds[-sample_size]][0] # select the best traj with highest rewards # traj = prompt_trajectories[i] - si = max(0, traj['rewards'].shape[0] - self.max_len -1) # select the last traj with length max_len + si = max(0, traj['rewards'][0].shape[1] - self.max_len -1) # select the last traj with length max_len # get sequences from dataset - s = traj['observations'][si:si + self.max_len] - a = traj['actions'][si:si + self.max_len] - r = traj['rewards'][si:si + self.max_len] + + s = traj['obs'][0][si:si + self.max_len] + a = traj['actions'][0][si:si + self.max_len] + mask = traj['mask'][0][si:si + self.max_len] - timesteps = np.arange(si, si + self.max_len) - rtg.append(discount_cumsum(traj['rewards'][si:], gamma=1.)[:s.shape[0] + 1]) - if rtg.shape[0] <= s.shape[0]: - rtg = np.concatenate([rtg, np.zeros((1, 1, 1))], axis=1) + timesteps = np.arange(si, si + np.array(mask).sum()) + rtg = discount_cumsum(traj['rewards'][0][si:], gamma=1.)[:s.shape[0]] + if rtg.shape[0] < s.shape[0]: + rtg = np.concatenate([rtg, np.zeros(((s.shape[0] - rtg.shape[0]), 1))], axis=1) # padding and state + reward normalization - tlen = s.shape[0] # if tlen !=args.K: # print('tlen not equal to k') - s = np.concatenate([np.zeros((self.max_len - tlen, self.state_dim)), s], axis=0) if not self.no_state_normalize: s = (s - self.state_means[self.task_id]) / self.state_stds[self.task_id] - a = np.concatenate([np.ones((self.max_len - tlen, self.act_dim)) * -10., a], axis=0) - r = np.concatenate([np.zeros((self.max_len - tlen, 1)), r], axis=0) - d = np.concatenate([np.ones((self.max_len - tlen)) * 2, d], axis=0) - rtg = np.concatenate([np.zeros((self.max_len - tlen, 1)), rtg], axis=0) / self.rtg_scale - timesteps = np.concatenate([np.zeros((self.max_len - tlen)), timesteps], axis=0) - mask = np.concatenate([np.zeros((self.max_len - tlen)), np.ones((tlen))], axis=0) + rtg = rtg/ self.rtg_scale + + t_len = int(np.array(mask).sum()) + + timesteps = np.concatenate([timesteps, np.zeros((self.max_len - t_len))], axis=0) + + s = torch.from_numpy(s).to(dtype=torch.float32) + a = torch.from_numpy(a).to(dtype=torch.float32) + rtg = torch.from_numpy(rtg).to(dtype=torch.float32) + timesteps = torch.from_numpy(timesteps).to(dtype=torch.long) + mask = 
torch.from_numpy(mask).to(dtype=torch.long) return s, a, rtg, timesteps, mask @@ -1316,21 +1305,31 @@ def get_pretrain_data(self, task_id: int, batch_size: int): size=batch_size, ) - traj = trajs[int(batch_idx)] - - si = np.random.randint(0, traj[0]['reward'].shape[0]) - traj = traj[:,si:si + self.max_len,:] + max_len = self.max_len + if self.need_next_obs: + max_len += 1 - s = traj['observations'] - a = traj['actions'] - r = traj['rewards'] - - tlen = s.shape[1] - s = np.concatenate([np.zeros((self.max_len - tlen, self.state_dim)), s], axis=1) - if not self.no_state_normalize: - s = (s - self.state_means[self.task_id]) / self.state_stds[self.task_id] - a = np.concatenate([np.ones((self.max_len - tlen, self.act_dim)) * -10., a], axis=1) - r = np.concatenate([np.zeros((self.max_len - tlen, 1)), r], axis=1) + s, a, r = [], [], [] + + for idx in batch_idx: + traj = trajs[idx] + si = np.random.randint(0, len(traj['obs'][0]) - max_len) + + state = traj['obs'][0][si:si + max_len] + action = traj['actions'][0][si:si + max_len] + state = np.array(state).squeeze() + action = np.array(action).squeeze() + if not self.no_state_normalize: + state = (state - self.test_state_means[task_id]) / self.test_state_stds[task_id] + if not self.no_action_normalize: + action = (action - self.test_action_means[task_id]) / self.test_action_stds[task_id] + s.append(state) + a.append(action) + r.append(traj['rewards'][0][si:si + max_len]) + + s = np.array(s) + a = np.array(a) + r = np.array(r) s = torch.from_numpy(s).to(dtype=torch.float32) a = torch.from_numpy(a).to(dtype=torch.float32) @@ -1342,30 +1341,48 @@ def get_pretrain_data(self, task_id: int, batch_size: int): def __getitem__(self, index): traj = self.traj[self.task_id][index] - si = np.random.randint(0, traj['rewards'].shape[0]) - - s = traj['observations'][si:si + self.max_len] - a = traj['actions'][si:si + self.max_len] - r = traj['rewards'][si:si + self.max_len] - timesteps = np.arange(si, si + self.max_len) - rtg = discount_cumsum(traj['rewards'][si:], gamma=1.)[:s.shape[0] + 1] / self.rtg_scale - if rtg.shape[0] <= s.shape[0]: - rtg = np.concatenate([rtg, np.zeros((1, 1, 1))], axis=1) - - tlen = s.shape[0] - s = np.concatenate([np.zeros((self.max_len - tlen, self.state_dim)), s], axis=0) + si = np.random.randint(0, len(traj['rewards'][0]) - self.max_len) + + max_len = self.max_len + if self.need_next_obs: + max_len += 1 + + s = traj['obs'][0][si:si + max_len] + a = traj['actions'][0][si:si + max_len] + r = traj['rewards'][0][si:si + max_len] + mask = np.array(traj['mask'][0][si:si + max_len]) + # mask = np.ones((s.shape[0])) + timesteps = np.arange(si, si + mask.sum()) + rtg = discount_cumsum(traj['rewards'][0][si:], gamma=1.)[:s.shape[0]] / self.rtg_scale + if rtg.shape[0] < s.shape[0]: + rtg = np.concatenate([rtg, np.zeros(((s.shape[0] - rtg.shape[0]), 1))], axis=1) + + if not self.no_state_normalize: s = (s - self.state_means[self.task_id]) / self.state_stds[self.task_id] - a = np.concatenate([np.ones((self.max_len - tlen, self.act_dim)) * -10., a], axis=0) - r = np.concatenate([np.zeros((self.max_len - tlen, 1)), r], axis=0) - d = np.concatenate([np.ones((self.max_len - tlen)) * 2, d], axis=0) - rtg = np.concatenate([np.zeros((self.max_len - tlen, 1)), rtg], axis=0) / self.rtg_scale - timesteps = np.concatenate([np.zeros((self.max_len - tlen)), timesteps], axis=0) + if not self.no_action_normalize: + a = (a - self.action_means[self.task_id]) / self.action_stds[self.task_id] + + s = np.array(s) + a = np.array(a) + r = np.array(r) + + tlen = 
int(mask.sum()) - mask = np.concatenate([np.zeros((self.max_len - tlen)), np.ones((tlen))], axis=0) + s = torch.from_numpy(s).to(dtype=torch.float32) + a = torch.from_numpy(a).to(dtype=torch.float32) + r = torch.from_numpy(r).to(dtype=torch.float32) + rtg = rtg / self.rtg_scale + timesteps = np.concatenate([timesteps, np.zeros((max_len - tlen))], axis=0) + + + rtg = torch.from_numpy(rtg).to(dtype=torch.float32) + timesteps = torch.from_numpy(timesteps).to(dtype=torch.long) + mask = torch.from_numpy(mask).to(dtype=torch.long) + if self.need_prompt: - prompt = self.get_prompt() + prompt = self.get_prompt(self.task_id) return prompt, timesteps, s, a, r, rtg, mask elif self.cond: cond_id = 0 diff --git a/ding/worker/collector/interaction_serial_meta_evaluator.py b/ding/worker/collector/interaction_serial_meta_evaluator.py old mode 100644 new mode 100755 index aa098cb1af..23269a1d50 --- a/ding/worker/collector/interaction_serial_meta_evaluator.py +++ b/ding/worker/collector/interaction_serial_meta_evaluator.py @@ -33,8 +33,8 @@ class InteractionSerialMetaEvaluator(InteractionSerialEvaluator): ), # (str) File path for visualize environment information. figure_path=None, - # test env list - test_env_list=None, + # test env num + test_env_num=10, ) def __init__( @@ -46,8 +46,12 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'evaluator', ) -> None: - super()._init_eval(cfg, env, policy, tb_logger, exp_name, instance_name) - self.test_env_num = len(cfg.test_env_list) + super().__init__(cfg, env, policy, tb_logger, exp_name, instance_name) + self.test_env_num = cfg.test_env_num + + def init_params(self, params): + self.params = params + self._env.set_all_goals(params) def eval( self, @@ -58,18 +62,156 @@ def eval( force_render: bool = False, policy_kwargs: Optional[Dict] = {}, ) -> Tuple[bool, Dict[str, List]]: - top_flags, episode_infos = [], defaultdict(list) + infos = defaultdict(list) for i in range(self.test_env_num): - self._env.reset_task(self._cfg.test_env_list[i]) - top_flag, episode_info = super().eval(save_ckpt_fn, train_iter, envstep, n_episode, \ - force_render, policy_kwargs) - top_flags.append(top_flag) - for key, val in episode_info.items(): + print('-----------------------------start task ', i) + self._env.reset_task(i) + info = self.sub_eval(save_ckpt_fn, train_iter, envstep, n_episode, \ + force_render, policy_kwargs, i) + for key, val in info.items(): if i == 0: - episode_infos[key] = [] - episode_infos[key].append(val) + info[key] = [] + infos[key].append(val) meta_infos = defaultdict(list) - for key, val in episode_infos.items(): - meta_infos[key] = episode_infos[key].mean() - return top_flags, meta_infos \ No newline at end of file + for key, val in info.items(): + meta_infos[key] = np.array(val).mean() + + episode_return = meta_infos['reward_mean'] + meta_infos['train_iter'] = train_iter + meta_infos['ckpt_name'] = 'iteration_{}.pth.tar'.format(train_iter) + + self._logger.info(self._logger.get_tabulate_vars_hor(meta_infos)) + # self._logger.info(self._logger.get_tabulate_vars(info)) + for k, v in meta_infos.items(): + if k in ['train_iter', 'ckpt_name', 'each_reward']: + continue + if not np.isscalar(v): + continue + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + + if episode_return > self._max_episode_return: + if save_ckpt_fn: + save_ckpt_fn('ckpt_best.pth.tar') + self._max_episode_return = 
episode_return + + stop_flag = episode_return >= self._stop_value and train_iter > 0 + if stop_flag: + self._logger.info( + "[DI-engine serial pipeline] " + "Current episode_return: {:.4f} is greater than stop_value: {}". + format(episode_return, self._stop_value) + ", so your RL agent is converged, you can refer to " + + "'log/evaluator/evaluator_logger.txt' for details." + ) + + return stop_flag, meta_infos + + def sub_eval( + self, + save_ckpt_fn: Callable = None, + train_iter: int = -1, + envstep: int = -1, + n_episode: Optional[int] = None, + force_render: bool = False, + policy_kwargs: Optional[Dict] = {}, + task_id: int = 0, + ) -> Tuple[bool, Dict[str, List]]: + ''' + Overview: + Evaluate policy and store the best policy based on whether it reaches the highest historical reward. + Arguments: + - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward. + - train_iter (:obj:`int`): Current training iteration. + - envstep (:obj:`int`): Current env interaction step. + - n_episode (:obj:`int`): Number of evaluation episodes. + Returns: + - stop_flag (:obj:`bool`): Whether this training program can be ended. + - episode_info (:obj:`Dict[str, List]`): Current evaluation episode information. + ''' + # evaluator only work on rank0 + stop_flag = False + if get_rank() == 0: + if n_episode is None: + n_episode = self._default_n_episode + assert n_episode is not None, "please indicate eval n_episode" + envstep_count = 0 + info = {} + eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode) + self._env.reset() + self._policy.reset() + + # force_render overwrite frequency constraint + render = force_render or self._should_render(envstep, train_iter) + + with self._timer: + while not eval_monitor.is_finished(): + obs = self._env.ready_obs + obs = to_tensor(obs, dtype=torch.float32) + + # update videos + if render: + eval_monitor.update_video(self._env.ready_imgs) + + if self._policy_cfg.type == 'dreamer_command': + policy_output = self._policy.forward( + obs, **policy_kwargs, reset=self._resets, state=self._states + ) + #self._states = {env_id: output['state'] for env_id, output in policy_output.items()} + self._states = [output['state'] for output in policy_output.values()] + else: + policy_output = self._policy.forward(obs, **policy_kwargs) + actions = {i: a['action'] for i, a in policy_output.items()} + actions = to_ndarray(actions) + timesteps = self._env.step(actions) + timesteps = to_tensor(timesteps, dtype=torch.float32) + for env_id, t in timesteps.items(): + if t.info.get('abnormal', False): + # If there is an abnormal timestep, reset all the related variables(including this env). + self._policy.reset([env_id]) + continue + if self._policy_cfg.type == 'dreamer_command': + self._resets[env_id] = t.done + if t.done: + # Env reset is done by env_manager automatically. 
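Conceptually, the per-task infos gathered by the outer `eval` are meant to be averaged key-by-key across tasks before logging; a standalone sketch of that reduction (plain dicts instead of the evaluator internals):

```python
from collections import defaultdict

import numpy as np


def aggregate_task_infos(per_task_infos):
    """Average scalar metrics (e.g. reward_mean) over all evaluated tasks."""
    merged = defaultdict(list)
    for info in per_task_infos:
        for key, val in info.items():
            merged[key].append(val)
    return {key: float(np.mean(vals)) for key, vals in merged.items()}


# e.g. two tasks with different returns
print(aggregate_task_infos([
    {'reward_mean': 900.0, 'reward_std': 30.0},
    {'reward_mean': 1100.0, 'reward_std': 50.0},
]))
# -> {'reward_mean': 1000.0, 'reward_std': 40.0}
```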
+ if 'figure_path' in self._cfg and self._cfg.figure_path is not None: + self._env.enable_save_figure(env_id, self._cfg.figure_path) + self._policy.reset([env_id]) + reward = t.info['eval_episode_return'] + saved_info = {'eval_episode_return': t.info['eval_episode_return']} + if 'episode_info' in t.info: + saved_info.update(t.info['episode_info']) + eval_monitor.update_info(env_id, saved_info) + eval_monitor.update_reward(env_id, reward) + self._logger.info( + "[EVALUATOR]env {} finish task {} episode, final reward: {:.4f}, current episode: {}".format( + env_id, task_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() + ) + ) + envstep_count += 1 + duration = self._timer.value + episode_return = eval_monitor.get_episode_return() + info = { + 'episode_count': n_episode, + 'envstep_count': envstep_count, + 'avg_envstep_per_episode': envstep_count / n_episode, + 'evaluate_time': duration, + 'avg_envstep_per_sec': envstep_count / duration, + 'avg_time_per_episode': n_episode / duration, + 'reward_mean': np.mean(episode_return), + 'reward_std': np.std(episode_return), + 'reward_max': np.max(episode_return), + 'reward_min': np.min(episode_return), + # 'each_reward': episode_return, + } + episode_info = eval_monitor.get_episode_info() + if episode_info is not None: + info.update(episode_info) + + if render: + video_title = '{}_{}/'.format(self._instance_name, self._render.mode) + videos = eval_monitor.get_video() + render_iter = envstep if self._render.mode == 'envstep' else train_iter + from ding.utils import fps + self._tb_logger.add_video(video_title, videos, render_iter, fps(self._env)) + + return info diff --git a/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py old mode 100644 new mode 100755 index efba4bf2e0..f017bd8419 --- a/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py +++ b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py @@ -1,9 +1,9 @@ from easydict import EasyDict main_config = dict( - exp_name="walker2d_medium_expert_pd_seed0", + exp_name="walker_params_md_seed0", env=dict( - env_id='walker2d-medium-expert-v2', + env_id='walker_params', collector_env_num=1, evaluator_env_num=8, use_act_scale=True, @@ -18,22 +18,27 @@ horizon=32, obs_dim=17, action_dim=6, + test_num=10, ), policy=dict( cuda=True, - max_len=1000, - max_ep_len=1000, + max_len=32, + max_ep_len=200, task_num=40, + train_num=1, obs_dim=17, - action_dim=6, + act_dim=6, + no_state_normalize=False, + no_action_normalize=False, + need_init_dataprocess=True, model=dict( diffuser_model_cfg=dict( model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, - dim=32, - dim_mults=[1, 2, 4, 8], - returns_condition=False, + dim=64, + dim_mults=[1, 4, 8], + returns_condition=True, kernel_size=5, attention=False, ), @@ -42,17 +47,20 @@ action_dim=6, n_timesteps=20, predict_epsilon=False, + condition_guidance_w=1.2, loss_discount=1, - action_weight=10, ), reward_cfg=dict( model='TemporalValue', model_cfg=dict( horizon = 32, transition_dim=23, - dim=32, - dim_mults=[1, 2, 4, 8], + dim=64, + out_dim=32, + dim_mults=[1, 4, 8], kernel_size=5, + returns_condition=True, + no_need_ret_sin=True, ), horizon=32, obs_dim=17, @@ -61,7 +69,7 @@ predict_epsilon=True, loss_discount=1, ), - horizon=80, + horizon=32, n_guide_steps=2, scale=0.1, t_stopgrad=2, @@ -76,26 +84,31 @@ learning_rate=2e-4, discount_factor=0.99, learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )), + eval_batch_size=8, + warm_batch_size=640, + test_num=10, ), - 
collect=dict(data_type='diffuser_traj', ), + collect=dict(data_type='meta_traj', ), eval=dict( evaluator=dict( eval_freq=500, - test_env_list=[5,10,22,31,18,1,12,9,25,38], + test_env_num=10, ), test_ret=0.9, ), - dateset=dict( - data_dir_prefix=1, - rtg_scale=1, - context_len=1, - stochastic_prompt=False, - need_prompt=False, - test_id=[5,10,22,31,18,1,12,9,25,38], - cond=True - ), other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ), ), + dataset=dict( + data_dir_prefix='/mnt/nfs/share/meta/walker_traj/buffers_walker_param_train', + rtg_scale=1, + context_len=1, + stochastic_prompt=False, + need_prompt=False, + test_id=[5,10,22,31,18,1,12,9,25,38], + cond=True, + env_param_path='/mnt/nfs/share/meta/walker/env_walker_param_train_task', + need_next_obs=True, + ), ) main_config = EasyDict(main_config) @@ -103,12 +116,12 @@ create_config = dict( env=dict( - type='d4rl', - import_names=['dizoo.d4rl.envs.d4rl_env'], + type='meta', + import_names=['dizoo.meta_mujoco.envs.meta_env'], ), - env_manager=dict(type='subprocess'), + env_manager=dict(type='meta_subprocess'), policy=dict( - type='pd', + type='metadiffuser', ), replay_buffer=dict(type='naive', ), ) diff --git a/dizoo/meta_mujoco/config/walker2d_promptdt_config.py b/dizoo/meta_mujoco/config/walker2d_promptdt_config.py new file mode 100755 index 0000000000..2c25d50e32 --- /dev/null +++ b/dizoo/meta_mujoco/config/walker2d_promptdt_config.py @@ -0,0 +1,107 @@ +from easydict import EasyDict +from copy import deepcopy + +main_config = dict( + exp_name='walker_params_promptdt_seed0', + env=dict( + env_id='walker_params', + collector_env_num=1, + evaluator_env_num=8, + use_act_scale=True, + n_evaluator_episode=8, + returns_scale=1.0, + termination_penalty=-100, + max_path_length=1000, + use_padding=True, + include_returns=True, + normed=False, + stop_value=8000, + horizon=32, + obs_dim=17, + action_dim=6, + test_num=1, + ), + dataset=dict( + data_dir_prefix='/mnt/nfs/share/meta/walker_traj/buffers_walker_param_train', + rtg_scale=1, + context_len=1, + stochastic_prompt=False, + need_prompt=True, + test_id=[1],#[5,10,22,31,18,1,12,9,25,38], + cond=False, + env_param_path='/mnt/nfs/share/meta/walker/env_walker_param_train_task', + need_next_obs=False, + ), + policy=dict( + cuda=True, + stop_value=5000, + max_len=20, + max_ep_len=200, + task_num=3, + train_num=1, + obs_dim=17, + act_dim=6, + state_mean=None, + state_std=None, + no_state_normalize=False, + no_action_normalize=True, + need_init_dataprocess=True, + evaluator_env_num=8, + rtg_target=5000, # max target return to go + max_eval_ep_len=1000, # max lenght of one episode + wt_decay=1e-4, + warmup_steps=10000, + context_len=20, + weight_decay=0.1, + clip_grad_norm_p=0.25, + model=dict( + state_dim=17, + act_dim=6, + n_blocks=3, + h_dim=128, + context_len=20, + n_heads=1, + drop_p=0.1, + continuous=True, + use_prompt=True, + ), + batch_size=32, + learning_rate=1e-4, + collect=dict(data_type='meta_traj', ), + learn=dict( + data_path=None, + train_epoch=60000, + gradient_accumulate_every=2, + batch_size=32, + learning_rate=1e-4, + discount_factor=0.99, + learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )), + eval_batch_size=8, + test_num=1, + ), + eval=dict( + evaluator=dict( + eval_freq=500, + test_env_num=1, + ), + test_ret=0.9, + ), + ), +) + +main_config = EasyDict(main_config) +main_config = main_config + +create_config = dict( + env=dict( + type='meta', + import_names=['dizoo.meta_mujoco.envs.meta_env'], + ), + env_manager=dict(type='meta_subprocess'), + policy=dict( + 
type='promptdt', + ), + replay_buffer=dict(type='naive', ), +) +create_config = EasyDict(create_config) +create_config = create_config \ No newline at end of file diff --git a/dizoo/meta_mujoco/envs/meta_env.py b/dizoo/meta_mujoco/envs/meta_env.py index 577363dde1..32bb3c25b4 100755 --- a/dizoo/meta_mujoco/envs/meta_env.py +++ b/dizoo/meta_mujoco/envs/meta_env.py @@ -1,9 +1,12 @@ from typing import Any, Union, List import copy import gym +import numpy as np from easydict import EasyDict -from CORRO.environments.make_env import make_env +#from CORRO.environments.make_env import make_env + +from rand_param_envs.make_env import make_env from ding.torch_utils import to_ndarray, to_list from ding.envs import BaseEnv, BaseEnvTimestep @@ -20,7 +23,7 @@ def __init__(self, cfg: dict) -> None: def reset(self) -> Any: if not self._init_flag: - self._env = make_env(self._cfg.env_id, 1, seed=self._cfg.seed) + self._env = make_env(self._cfg.env_id, 1, seed=self._cfg.seed, n_tasks=self._cfg.test_num) self._env.observation_space.dtype = np.float32 self._observation_space = self._env.observation_space self._action_space = self._env.action_space @@ -54,6 +57,12 @@ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep: def __repr__(self) -> str: return "DI-engine D4RL Env({})".format(self._cfg.env_id) + def set_all_goals(self, params): + self._env.set_all_goals(params) + + def reset_task(self, id): + self._env.reset_task(id) + @staticmethod def create_collector_env_cfg(cfg: dict) -> List[dict]: collector_cfg = copy.deepcopy(cfg) @@ -66,6 +75,11 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]: evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1) evaluator_cfg.get('norm_reward', EasyDict(use_norm=False, )).use_norm = False return [evaluator_cfg for _ in range(evaluator_env_num)] + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) @property def observation_space(self) -> gym.spaces.Space: From 16e81448f86b3e605fde1723f150815e858c0ee7 Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Sat, 20 Jan 2024 16:23:31 +0800 Subject: [PATCH 06/16] change --- ding/envs/env_manager/__init__.py | 3 ++- .../env_manager/subprocess_env_manager.py | 20 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) mode change 100644 => 100755 ding/envs/env_manager/__init__.py mode change 100644 => 100755 ding/envs/env_manager/subprocess_env_manager.py diff --git a/ding/envs/env_manager/__init__.py b/ding/envs/env_manager/__init__.py old mode 100644 new mode 100755 index 62d45baf27..415f63bc3d --- a/ding/envs/env_manager/__init__.py +++ b/ding/envs/env_manager/__init__.py @@ -1,5 +1,6 @@ from .base_env_manager import BaseEnvManager, BaseEnvManagerV2, create_env_manager, get_env_manager_cls -from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2 +from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2,\ + MetaSyncSubprocessEnvManager from .gym_vector_env_manager import GymVectorEnvManager # Do not import PoolEnvManager here, because it depends on installation of `envpool` from .env_supervisor import EnvSupervisor diff --git a/ding/envs/env_manager/subprocess_env_manager.py b/ding/envs/env_manager/subprocess_env_manager.py old mode 100644 new mode 100755 index 5a391f3932..2bf2219bad --- a/ding/envs/env_manager/subprocess_env_manager.py +++ 
b/ding/envs/env_manager/subprocess_env_manager.py @@ -832,3 +832,23 @@ def step(self, actions: Union[List[tnp.ndarray], tnp.ndarray]) -> List[tnp.ndarr info = remove_illegal_item(info) new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id})) return new_data + +@ENV_MANAGER_REGISTRY.register('meta_subprocess') +class MetaSyncSubprocessEnvManager(SyncSubprocessEnvManager): + + @property + def method_name_list(self) -> list: + return [ + 'reset', 'step', 'seed', 'close', 'enable_save_replay', 'render', 'reward_shaping', 'enable_save_figure', + 'set_all_goals', 'reset_task' + ] + + def set_all_goals(self, params): + for p in self._pipe_parents.values(): + p.send(['set_all_goals', [params], {}]) + data = {i: p.recv() for i, p in self._pipe_parents.items()} + + def reset_task(self, id): + for p in self._pipe_parents.values(): + p.send(['reset_task', [id], {}]) + data = {i: p.recv() for i, p in self._pipe_parents.items()} From 3524c720f15d6f5f1216e646abbc32123aac06e9 Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Sun, 21 Jan 2024 14:04:48 +0800 Subject: [PATCH 07/16] add init --- ding/worker/collector/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ding/worker/collector/__init__.py b/ding/worker/collector/__init__.py index 8ccfb17260..1b06f20aa8 100644 --- a/ding/worker/collector/__init__.py +++ b/ding/worker/collector/__init__.py @@ -16,3 +16,4 @@ from .zergling_parallel_collector import ZerglingParallelCollector from .marine_parallel_collector import MarineParallelCollector from .comm import BaseCommCollector, FlaskFileSystemCollector, create_comm_collector, NaiveCollector +from .interaction_serial_meta_evaluator import InteractionSerialMetaEvaluator From b0e727490919681087e0b605c5329a00fe5285bd Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Sun, 21 Jan 2024 14:42:11 +0800 Subject: [PATCH 08/16] add init --- ding/entry/__init__.py | 1 + dizoo/meta_mujoco/entry/meta_entry.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) mode change 100644 => 100755 ding/entry/__init__.py mode change 100644 => 100755 dizoo/meta_mujoco/entry/meta_entry.py diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py old mode 100644 new mode 100755 index 11cccf0e13..bd4b6baa09 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -26,3 +26,4 @@ from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream, serial_pipeline_dreamer from .serial_entry_bco import serial_pipeline_bco from .serial_entry_pc import serial_pipeline_pc +from .serial_entry_meta_offline import serial_pipeline_meta_offline diff --git a/dizoo/meta_mujoco/entry/meta_entry.py b/dizoo/meta_mujoco/entry/meta_entry.py old mode 100644 new mode 100755 index d11844521b..18ede61c16 --- a/dizoo/meta_mujoco/entry/meta_entry.py +++ b/dizoo/meta_mujoco/entry/meta_entry.py @@ -1,4 +1,4 @@ -from ding.entry import serial_entry_meta_offline +from ding.entry import serial_pipeline_meta_offline from ding.config import read_config from pathlib import Path @@ -8,7 +8,7 @@ def train(args): config = Path(__file__).absolute().parent.parent / 'config' / args.config config = read_config(str(config)) config[0].exp_name = config[0].exp_name.replace('0', str(args.seed)) - serial_entry_meta_offline(config, seed=args.seed) + serial_pipeline_meta_offline(config, seed=args.seed) if __name__ == "__main__": From 6be59206e65244d399e8f8d34339978437115d75 Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Sun, 21 Jan 2024 22:10:12 +0800 
Subject: [PATCH 09/16] add --- ding/model/template/diffusion.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/ding/model/template/diffusion.py b/ding/model/template/diffusion.py index cc6479f79c..c6be84fa81 100755 --- a/ding/model/template/diffusion.py +++ b/ding/model/template/diffusion.py @@ -93,9 +93,6 @@ def free_guidance_sample( ): weight = extract(model.sqrt_one_minus_alphas_cumprod, t, x.shape) - model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) - model_std = torch.exp(0.5 * model_log_variance) - model_var = torch.exp(model_log_variance) for _ in range(n_guide_steps): with torch.enable_grad(): @@ -117,7 +114,16 @@ def free_guidance_sample( epsilon = model.model(x, cond, t) epsilon += grad - model_mean, _, model_log_variance = model.p_mean_variance(x=x, cond=cond, t=t, epsilon=epsilon) + t = t.detach().to(torch.int64) + x_recon = model.predict_start_from_noise(x, t=t, noise=epsilon) + + if model.clip_denoised: + x_recon.clamp_(-1., 1.) + else: + assert RuntimeError() + + model_mean, _, model_log_variance = model.p_mean_variance(x=x_recon, cond=cond, t=t, epsilon=epsilon) + model_std = torch.exp(0.5 * model_log_variance) noise = torch.randn_like(x) noise[t == 0] = 0 From 7519400d816b181986a22f8000ed634e20d3e36e Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Mon, 22 Jan 2024 15:07:48 +0800 Subject: [PATCH 10/16] debug --- ding/model/template/diffusion.py | 29 ++++++++++--------- ding/policy/meta_diffuser.py | 2 +- .../config/walker2d_metadiffuser_config.py | 3 +- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/ding/model/template/diffusion.py b/ding/model/template/diffusion.py index c6be84fa81..91c8352b0f 100755 --- a/ding/model/template/diffusion.py +++ b/ding/model/template/diffusion.py @@ -93,6 +93,8 @@ def free_guidance_sample( ): weight = extract(model.sqrt_one_minus_alphas_cumprod, t, x.shape) + # model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) + # model_std = torch.exp(0.5 * model_log_variance) for _ in range(n_guide_steps): with torch.enable_grad(): @@ -114,20 +116,12 @@ def free_guidance_sample( epsilon = model.model(x, cond, t) epsilon += grad - t = t.detach().to(torch.int64) - x_recon = model.predict_start_from_noise(x, t=t, noise=epsilon) - - if model.clip_denoised: - x_recon.clamp_(-1., 1.) 
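        # Annotation: clamping the reconstructed x_0 keeps the denoised trajectory inside the
        # normalized data range before the posterior mean/variance are derived from it. This
        # routine mixes classifier-style guidance (gradients from the reward and dynamics guides
        # added to epsilon) with classifier-free conditioning on the task embedding
        # (epsilon_cond vs. epsilon_uncond weighted by condition_guidance_w).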
- else: - assert RuntimeError() - - model_mean, _, model_log_variance = model.p_mean_variance(x=x_recon, cond=cond, t=t, epsilon=epsilon) + model_mean, _, model_log_variance = model.p_mean_variance(x=x, cond=cond, t=t, epsilon=epsilon) model_std = torch.exp(0.5 * model_log_variance) noise = torch.randn_like(x) noise[t == 0] = 0 - return model_mean + model_std * noise + return model_mean + model_std * noise, y1 class GaussianDiffusion(nn.Module): """ @@ -740,10 +734,16 @@ def p_sample_loop(self, shape, cond, sample_fn=None, plan_size=1, **sample_kwarg assert sample_fn != None for i in reversed(range(0, self.n_timesteps)): t = torch.full((batch_size, ), i, device=device, dtype=torch.long) - x = sample_fn(self, x, cond, t, **sample_kwargs) + x, values = sample_fn(self, x, cond, t, **sample_kwargs) x = apply_conditioning(x, cond, self.action_dim) - return x + values = values.reshape(-1, plan_size, *values.shape[1:]) + x = x.reshape(-1, plan_size, *x.shape[1:]) + if plan_size > 1: + inds = torch.argsort(values, dim=1, descending=True) + inds = inds.unsqueeze(-1).expand_as(x) + x = x.gather(1, inds) + return x[:,0] def conditional_sample(self, cond, horizon=None, **sample_kwargs): @@ -861,9 +861,12 @@ def pre_train_loss(self, traj, target, t, cond): return state_loss, reward_loss, reward_log def get_eval(self, cond, id, batch_size = 1): + id = torch.stack(id, dim=0) if batch_size > 1: cond = self.repeat_cond(cond, batch_size) - id = torch.stack(id, dim=0) + id = id.unsqueeze(1).repeat_interleave(batch_size, dim=1) + id = id.reshape(-1, id.shape[-1]) + samples = self.diffuser(cond, returns=id, sample_fn=free_guidance_sample, plan_size=batch_size, guide1=self.reward_model, guide2=self.dynamic_model, **self.sample_kwargs) return samples[:, 0, :self.action_dim] diff --git a/ding/policy/meta_diffuser.py b/ding/policy/meta_diffuser.py index af01363b64..cf1b9be665 100755 --- a/ding/policy/meta_diffuser.py +++ b/ding/policy/meta_diffuser.py @@ -342,7 +342,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: if self._cuda: obs = to_device(obs, self._device) conditions = {0: obs} - action = self._eval_model.get_eval(conditions, self.test_task_id) + action = self._eval_model.get_eval(conditions, self.test_task_id, self._cfg.learn.plan_batch_size) if self._cuda: action = to_device(action, 'cpu') for i in range(self.eval_batch_size): diff --git a/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py index f017bd8419..f73e0d854f 100755 --- a/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py +++ b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py @@ -85,8 +85,9 @@ discount_factor=0.99, learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )), eval_batch_size=8, - warm_batch_size=640, + warm_batch_size=32, test_num=10, + plan_batch_size=1, ), collect=dict(data_type='meta_traj', ), eval=dict( From 3bafbf10ea5561c93eb58564b13e57926c621f9a Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Mon, 22 Jan 2024 16:37:42 +0800 Subject: [PATCH 11/16] change pdt --- ding/model/template/decision_transformer.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py index 4eea2ca01e..9807493a52 100755 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -303,9 +303,7 @@ def forward( # time embeddings are treated similar to positional embeddings 
state_embeddings = self.embed_state(states) + time_embeddings action_embeddings = self.embed_action(actions) + time_embeddings - returns_embeddings = self.embed_rtg(returns_to_go) - returns_embeddings += time_embeddings - + returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings # stack rtg, states and actions and reshape sequence as # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...) @@ -331,18 +329,18 @@ def forward( prompt_stacked_inputs = torch.stack( (prompt_returns_embeddings, prompt_state_embeddings, prompt_action_embeddings), dim=1 ).permute(0, 2, 1, 3).reshape(prompt_states.shape[0], 3 * prompt_seq_length, self.h_dim) - prompt_stacked_attention_mask = torch.stack( - (prompt_attention_mask, prompt_attention_mask, prompt_attention_mask), dim=1 - ).permute(0, 2, 1).reshape(prompt_states.shape[0], 3 * prompt_seq_length) + # prompt_stacked_attention_mask = torch.stack( + # (prompt_attention_mask, prompt_attention_mask, prompt_attention_mask), dim=1 + # ).permute(0, 2, 1).reshape(prompt_states.shape[0], 3 * prompt_seq_length) if prompt_stacked_inputs.shape[1] == 3 * T: # if only smaple one prompt prompt_stacked_inputs = prompt_stacked_inputs.reshape(1, -1, self.h_dim) - prompt_stacked_attention_mask = prompt_stacked_attention_mask.reshape(1, -1) + #prompt_stacked_attention_mask = prompt_stacked_attention_mask.reshape(1, -1) h = torch.cat((prompt_stacked_inputs.repeat(B, 1, 1), h), dim=1) - stacked_attention_mask = torch.cat((prompt_stacked_attention_mask.repeat(B, 1), stacked_attention_mask), dim=1) + # stacked_attention_mask = torch.cat((prompt_stacked_attention_mask.repeat(B, 1), stacked_attention_mask), dim=1) else: # if sample one prompt for each traj in batch h = torch.cat((prompt_stacked_inputs, h), dim=1) - stacked_attention_mask = torch.cat((prompt_stacked_attention_mask, stacked_attention_mask), dim=1) + # stacked_attention_mask = torch.cat((prompt_stacked_attention_mask, stacked_attention_mask), dim=1) # transformer and prediction h = self.transformer(h) From 2b1bdaacc36cbfee80e83a3c8f5a2fe392aefd7b Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Mon, 22 Jan 2024 20:13:28 +0800 Subject: [PATCH 12/16] add comman --- ding/policy/command_mode_policy_instance.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py index 2d5e3271dd..305c757901 100644 --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -51,6 +51,8 @@ from .edac import EDACPolicy from .prompt_pg import PromptPGPolicy from .plan_diffuser import PDPolicy +from .meta_diffuser import MDPolicy +from .prompt_dt import PDTPolicy class EpsCommandModePolicy(CommandModePolicy): @@ -449,3 +451,11 @@ def _get_setting_eval(self, command_info: dict) -> dict: @POLICY_REGISTRY.register('prompt_pg_command') class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy): pass + +@POLICY_REGISTRY.register('metadiffuser_command') +class MDCommandModePolicy(MDPolicy, DummyCommandModePolicy): + pass + +@POLICY_REGISTRY.register('promptdt_command') +class PDTCommandModePolicy(PDTPolicy, DummyCommandModePolicy): + pass From c8d9c7f168309960b4d415c5e7f5bd2fd852d9b7 Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Wed, 24 Jan 2024 10:35:31 +0800 Subject: [PATCH 13/16] metadiffuser --- ding/entry/serial_entry_meta_offline.py | 6 +- ding/model/template/diffusion.py | 8 +- ding/policy/meta_diffuser.py | 103 ++++++++++++------ 
ding/policy/prompt_dt.py | 7 +- ding/utils/data/dataset.py | 1 + .../interaction_serial_meta_evaluator.py | 5 +- 6 files changed, 86 insertions(+), 44 deletions(-) diff --git a/ding/entry/serial_entry_meta_offline.py b/ding/entry/serial_entry_meta_offline.py index 89bc43a6ed..544b425b24 100755 --- a/ding/entry/serial_entry_meta_offline.py +++ b/ding/entry/serial_entry_meta_offline.py @@ -103,7 +103,11 @@ def serial_pipeline_meta_offline( # Evaluate policy at most once per epoch. if evaluator.should_eval(learner.train_iter): - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) + if hasattr(policy, 'warm_train'): + # if algorithm need warm train + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, policy_warm_func=policy.warm_train) + else: + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) if stop or learner.train_iter >= max_train_iter: stop = True diff --git a/ding/model/template/diffusion.py b/ding/model/template/diffusion.py index 91c8352b0f..5b6bae54ba 100755 --- a/ding/model/template/diffusion.py +++ b/ding/model/template/diffusion.py @@ -93,7 +93,7 @@ def free_guidance_sample( ): weight = extract(model.sqrt_one_minus_alphas_cumprod, t, x.shape) - # model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) + model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) # model_std = torch.exp(0.5 * model_log_variance) for _ in range(n_guide_steps): @@ -794,6 +794,7 @@ def __init__( reward_cfg: dict, diffuser_model_cfg: dict, horizon: int, + encoder_horizon: int, **sample_kwargs, ): super().__init__() @@ -802,9 +803,10 @@ def __init__( self.action_dim = action_dim self.horizon = horizon self.sample_kwargs = sample_kwargs + self.encoder_horizon = encoder_horizon self.embed = nn.Sequential( - nn.Linear((obs_dim * 2 + action_dim + 1) * horizon, dim * 4), + nn.Linear((obs_dim * 2 + action_dim + 1) * encoder_horizon, dim * 4), #nn.Mish(), Mish(), nn.Linear(dim * 4, dim * 4), @@ -850,7 +852,7 @@ def pre_train_loss(self, traj, target, t, cond): reward_loss, reward_log = self.reward_model.p_losses(input, cond, target_reward, t, task_idx) - task_idxs = task_idx.unsqueeze(1).repeat_interleave(self.horizon,dim=1) + task_idxs = task_idx.unsqueeze(1).repeat_interleave(self.encoder_horizon, dim=1) input = torch.cat([input, task_idxs], dim=-1) input = input.reshape(-1, input.shape[-1]) diff --git a/ding/policy/meta_diffuser.py b/ding/policy/meta_diffuser.py index cf1b9be665..ef3670f321 100755 --- a/ding/policy/meta_diffuser.py +++ b/ding/policy/meta_diffuser.py @@ -181,6 +181,7 @@ def _init_learn(self) -> None: self.test_num = self._cfg.learn.test_num self.have_train = False self._forward_learn_cnt = 0 + self.encoder_len = self._cfg.learn.encoder_len self._plan_optimizer = Adam( self._model.diffuser.model.parameters(), @@ -231,6 +232,8 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: pre_traj = torch.cat([acts, obs, rewards, next_obs], dim=-1).to(self._device) target = torch.cat([next_obs, rewards], dim=-1).to(self._device) traj = torch.cat([acts, obs], dim=-1).to(self._device) + pre_traj = pre_traj[:, :self.encoder_len] + target = pre_traj[:, :self.encoder_len] batch_size = len(traj) t = torch.randint(0, self.n_timesteps, (batch_size, ), device=traj.device).long() @@ -333,7 +336,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: self._eval_model.eval() obs = [] - for i in range(self.eval_batch_size): + for i in range(len(data)): if not 
self._cfg.no_state_normalize: obs.append(self.dataloader.normalize(data[i], 'obs', self.task_id[i])) @@ -342,53 +345,87 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: if self._cuda: obs = to_device(obs, self._device) conditions = {0: obs} - action = self._eval_model.get_eval(conditions, self.test_task_id, self._cfg.learn.plan_batch_size) + action = self._eval_model.get_eval(conditions, self.test_task_id[:len(data)], self._cfg.learn.plan_batch_size) if self._cuda: action = to_device(action, 'cpu') - for i in range(self.eval_batch_size): + for i in range(len(data)): if not self._cfg.no_action_normalize: action[i] = self.dataloader.unnormalize(action[i], 'actions', self.task_id[i]) action = torch.tensor(action).to('cpu') output = {'action': action} output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} + + def warm_train(self, id: int): + self.task_id = [id] * self.eval_batch_size + obs, acts, rewards, cond_ids, cond_vals = \ + self.dataloader.get_pretrain_data(id, self.warm_batch_size) + obs = to_device(obs, self._device) + acts = to_device(acts, self._device) + rewards = to_device(rewards, self._device) + cond_vals = to_device(cond_vals, self._device) + + obs, next_obs = obs[:, :-1], obs[:, 1:] + acts = acts[:, :-1] + rewards = rewards[:, :-1] + + pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=-1) + target = torch.cat([next_obs, rewards], dim=-1) + batch_size = len(pre_traj) + conds = {cond_ids: cond_vals} + pre_traj = pre_traj[:, :self.encoder_len] + target = pre_traj[:, :self.encoder_len] + + t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() + state_loss, reward_loss, log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + total_loss = state_loss + reward_loss + self._pre_train_optimizer.zero_grad() + total_loss.backward() + self._pre_train_optimizer.step() + self.update_model_average(self._target_model, self._learn_model) + self.test_task_id = [self._target_model.get_task_id(pre_traj)[0]] * self.eval_batch_size + def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: if self.have_train: - if data_id is None: - data_id = list(range(self.eval_batch_size)) - if self.task_id is not None: - for id in data_id: - self.task_id[id] = (self.task_id[id] + 1) % self.test_num - else: + if self.task_id is None: self.task_id = [0] * self.eval_batch_size + # if data_id is None: + # data_id = list(range(self.eval_batch_size)) + # if self.task_id is not None: + # for id in data_id: + # self.task_id[id] = (self.task_id[id] + 1) % self.test_num + # else: + # self.task_id = [0] * self.eval_batch_size - for id in data_id: - obs, acts, rewards, cond_ids, cond_vals = \ - self.dataloader.get_pretrain_data(self.task_id[id], self.warm_batch_size) - obs = to_device(obs, self._device) - acts = to_device(acts, self._device) - rewards = to_device(rewards, self._device) - cond_vals = to_device(cond_vals, self._device) - - obs, next_obs = obs[:, :-1], obs[:, 1:] - acts = acts[:, :-1] - rewards = rewards[:, :-1] + # for id in data_id: + # obs, acts, rewards, cond_ids, cond_vals = \ + # self.dataloader.get_pretrain_data(self.task_id[id], self.warm_batch_size) + # obs = to_device(obs, self._device) + # acts = to_device(acts, self._device) + # rewards = to_device(rewards, self._device) + # cond_vals = to_device(cond_vals, self._device) + + # obs, next_obs = obs[:, :-1], obs[:, 1:] + # acts = acts[:, :-1] + # rewards = rewards[:, :-1] - pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=-1) - target = 
torch.cat([next_obs, rewards], dim=-1) - batch_size = len(pre_traj) - conds = {cond_ids: cond_vals} - - t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() - state_loss, reward_loss, log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) - total_loss = state_loss + reward_loss - self._pre_train_optimizer.zero_grad() - total_loss.backward() - self._pre_train_optimizer.step() - self.update_model_average(self._target_model, self._learn_model) + # pre_traj = torch.cat([acts, obs, next_obs, rewards], dim=-1) + # target = torch.cat([next_obs, rewards], dim=-1) + # batch_size = len(pre_traj) + # conds = {cond_ids: cond_vals} + # pre_traj = pre_traj[:, :self.encoder_len] + # target = pre_traj[:, :self.encoder_len] + + # t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() + # state_loss, reward_loss, log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + # total_loss = state_loss + reward_loss + # self._pre_train_optimizer.zero_grad() + # total_loss.backward() + # self._pre_train_optimizer.step() + # self.update_model_average(self._target_model, self._learn_model) - self.test_task_id[id] = self._target_model.get_task_id(pre_traj)[0] + # self.test_task_id[id] = self._target_model.get_task_id(pre_traj)[0] def _init_collect(self) -> None: pass diff --git a/ding/policy/prompt_dt.py b/ding/policy/prompt_dt.py index 91dd4dd8d6..b485e984a7 100755 --- a/ding/policy/prompt_dt.py +++ b/ding/policy/prompt_dt.py @@ -230,10 +230,5 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: if self.have_train: - if data_id is None: - data_id = list(range(self.eval_batch_size)) - if self.task_id is not None: - for id in data_id: - self.task_id[id] = (self.task_id[id] + 1) % self.test_num - else: + if self.task_id is None: self.task_id = [0] * self.eval_batch_size \ No newline at end of file diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index 58c19a71f7..c006d68c51 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -1299,6 +1299,7 @@ def unnormalize(self, data: np.array, type: str, task_id: int): # get warm start data def get_pretrain_data(self, task_id: int, batch_size: int): # get warm data + print('task_id:',task_id) trajs = self.test_traj[task_id] batch_idx = np.random.choice( np.arange(len(trajs)), diff --git a/ding/worker/collector/interaction_serial_meta_evaluator.py b/ding/worker/collector/interaction_serial_meta_evaluator.py index 23269a1d50..37bd312835 100755 --- a/ding/worker/collector/interaction_serial_meta_evaluator.py +++ b/ding/worker/collector/interaction_serial_meta_evaluator.py @@ -61,11 +61,14 @@ def eval( n_episode: Optional[int] = None, force_render: bool = False, policy_kwargs: Optional[Dict] = {}, + policy_warm_func: namedtuple = None, ) -> Tuple[bool, Dict[str, List]]: infos = defaultdict(list) for i in range(self.test_env_num): print('-----------------------------start task ', i) self._env.reset_task(i) + if policy_warm_func is not None: + policy_warm_func(i) info = self.sub_eval(save_ckpt_fn, train_iter, envstep, n_episode, \ force_render, policy_kwargs, i) for key, val in info.items(): @@ -74,7 +77,7 @@ def eval( infos[key].append(val) meta_infos = defaultdict(list) - for key, val in info.items(): + for key, val in infos.items(): meta_infos[key] = np.array(val).mean() episode_return = meta_infos['reward_mean'] From fd2896c69c2fa1688d4a265ff020b1f26447bdb4 Mon Sep 17 
00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Mon, 29 Jan 2024 11:36:23 +0800 Subject: [PATCH 14/16] debug --- ding/entry/serial_entry_meta_offline.py | 6 +- ding/model/template/decision_transformer.py | 30 +-- ding/model/template/diffusion.py | 93 ++++--- ding/policy/meta_diffuser.py | 40 +-- ding/policy/prompt_dt.py | 233 +++++++++++------- ding/torch_utils/network/diffusion.py | 8 +- .../interaction_serial_meta_evaluator.py | 17 +- .../config/walker2d_metadiffuser_config.py | 25 +- .../config/walker2d_promptdt_config.py | 14 +- 9 files changed, 293 insertions(+), 173 deletions(-) diff --git a/ding/entry/serial_entry_meta_offline.py b/ding/entry/serial_entry_meta_offline.py index 544b425b24..1759e5a9c1 100755 --- a/ding/entry/serial_entry_meta_offline.py +++ b/ding/entry/serial_entry_meta_offline.py @@ -105,9 +105,11 @@ def serial_pipeline_meta_offline( if evaluator.should_eval(learner.train_iter): if hasattr(policy, 'warm_train'): # if algorithm need warm train - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, policy_warm_func=policy.warm_train) + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, + policy_warm_func=policy.warm_train, need_reward=cfg.policy.need_reward) else: - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, + need_reward=cfg.policy.need_reward) if stop or learner.train_iter >= max_train_iter: stop = True diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py index 9807493a52..9c6ed8019e 100755 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -209,6 +209,7 @@ def __init__( self.embed_timestep = nn.Embedding(max_timestep, h_dim) if use_prompt: self.prompt_embed_timestep = nn.Embedding(max_timestep, h_dim) + input_seq_len *= 2 self.drop = nn.Dropout(drop_p) self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim)) @@ -310,7 +311,7 @@ def forward( t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim) h = self.embed_ln(t_p) - + if prompt is not None: prompt_states, prompt_actions, prompt_returns_to_go,\ prompt_timesteps, prompt_attention_mask = prompt @@ -329,20 +330,15 @@ def forward( prompt_stacked_inputs = torch.stack( (prompt_returns_embeddings, prompt_state_embeddings, prompt_action_embeddings), dim=1 ).permute(0, 2, 1, 3).reshape(prompt_states.shape[0], 3 * prompt_seq_length, self.h_dim) + # prompt_stacked_attention_mask = torch.stack( # (prompt_attention_mask, prompt_attention_mask, prompt_attention_mask), dim=1 - # ).permute(0, 2, 1).reshape(prompt_states.shape[0], 3 * prompt_seq_length) - - if prompt_stacked_inputs.shape[1] == 3 * T: # if only smaple one prompt - prompt_stacked_inputs = prompt_stacked_inputs.reshape(1, -1, self.h_dim) - #prompt_stacked_attention_mask = prompt_stacked_attention_mask.reshape(1, -1) - h = torch.cat((prompt_stacked_inputs.repeat(B, 1, 1), h), dim=1) - # stacked_attention_mask = torch.cat((prompt_stacked_attention_mask.repeat(B, 1), stacked_attention_mask), dim=1) - else: # if sample one prompt for each traj in batch - h = torch.cat((prompt_stacked_inputs, h), dim=1) - # stacked_attention_mask = torch.cat((prompt_stacked_attention_mask, stacked_attention_mask), dim=1) + # ).permute(0, 2, 1).reshape(prompt_states.shape[0], 3 * prompt_seq_length + h = 
torch.cat((prompt_stacked_inputs, h), dim=1) + # stacked_attention_mask = torch.cat((prompt_stacked_attention_mask, stacked_attention_mask), dim=1) # transformer and prediction + h = self.transformer(h) # get h reshaped such that its size = (B x 3 x T x h_dim) and # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t @@ -351,11 +347,15 @@ def forward( # that is, for each timestep (t) we have 3 output embeddings from the transformer, # each conditioned on all previous timesteps plus # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence. - h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) + if prompt is None: + h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3) + else: + h = h.reshape(B, -1, 3, self.h_dim).permute(0, 2, 1, 3) + + return_preds = self.predict_rtg(h[:, 2])[:, -T:, :] # predict next rtg given r, s, a + state_preds = self.predict_state(h[:, 2])[:, -T:, :] # predict next state given r, s, a + action_preds = self.predict_action(h[:, 1])[:, -T:, :] # predict action given r, s - return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a - state_preds = self.predict_state(h[:, 2]) # predict next state given r, s, a - action_preds = self.predict_action(h[:, 1]) # predict action given r, s else: state_embeddings = self.state_encoder( states.reshape(-1, *self.state_dim).type(torch.float32).contiguous() diff --git a/ding/model/template/diffusion.py b/ding/model/template/diffusion.py index 5b6bae54ba..ff703a92cb 100755 --- a/ding/model/template/diffusion.py +++ b/ding/model/template/diffusion.py @@ -26,7 +26,7 @@ def default_sample_fn(model, x, cond, t): return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, values -def get_guide_output(guide, x, cond, t, returns=None, is_dynamic=False): +def get_guide_output(guide, x, cond, t, returns=None, is_dynamic=False, act_dim=6): x.requires_grad_() if returns is not None: if not is_dynamic: @@ -36,6 +36,8 @@ def get_guide_output(guide, x, cond, t, returns=None, is_dynamic=False): input = torch.cat([x, returns], dim=-1) input = input.reshape(-1, input.shape[-1]) y = guide(input) + y = y.reshape(x.shape[0], x.shape[1], -1) + y = F.mse_loss(x[:, 1:, act_dim:], y[:, :-1], reduction='none') else: y = guide(x, cond, t).squeeze(dim=-1) grad = torch.autograd.grad([y.sum()], [x])[0] @@ -94,30 +96,32 @@ def free_guidance_sample( ): weight = extract(model.sqrt_one_minus_alphas_cumprod, t, x.shape) model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) - # model_std = torch.exp(0.5 * model_log_variance) + model_std = torch.exp(0.5 * model_log_variance) + model_var = torch.exp(model_log_variance) for _ in range(n_guide_steps): with torch.enable_grad(): y1, grad1 = get_guide_output(guide1, x, cond, t, returns) # get reward - y2, grad2 = get_guide_output(guide2, x, cond, t, returns, is_dynamic=True) # get state + y2, grad2 = get_guide_output(guide2, x, cond, t, returns, is_dynamic=True, + act_dim=model.action_dim) # get state grad = grad1 + scale * grad2 if scale_grad_by_std: - grad = weight * grad + grad = model_var * grad grad[t < t_stopgrad] = 0 - if returns is not None: + if model.returns_condition: # epsilon could be epsilon or x0 itself epsilon_cond = model.model(x, cond, t, returns, use_dropout=False) epsilon_uncond = model.model(x, cond, t, returns, force_dropout=True) epsilon = epsilon_uncond + model.condition_guidance_w * (epsilon_cond - epsilon_uncond) else: epsilon = model.model(x, cond, t) - epsilon += grad + epsilon -= weight * grad model_mean, _, 
model_log_variance = model.p_mean_variance(x=x, cond=cond, t=t, epsilon=epsilon) - model_std = torch.exp(0.5 * model_log_variance) + # model_std = torch.exp(0.5 * model_log_variance) noise = torch.randn_like(x) noise[t == 0] = 0 @@ -359,11 +363,15 @@ def p_losses(self, x_start, cond, target, t, returns=None): pred = self.model(x_noisy, cond, t, returns) loss = F.mse_loss(pred, target, reduction='none').mean() + with torch.no_grad(): + r0_loss = F.mse_loss(pred[:, 0], target[:,0]) log = { 'mean_pred': pred.mean().item(), 'max_pred': pred.max().item(), 'min_pred': pred.min().item(), + 'r0_loss': r0_loss.mean().item(), } + return loss, log @@ -697,21 +705,42 @@ def p_losses(self, x_start, cond, t, returns=None): def forward(self, cond, *args, **kwargs): return self.conditional_sample(cond=cond, *args, **kwargs) -class GuidenceFreeDifffuser(GaussianInvDynDiffusion): +class GuidenceFreeDifffuser(GaussianDiffusion): - def get_loss_weights(self, discount: int): - self.action_weight = 1 - dim_weights = torch.ones(self.transition_dim, dtype=torch.float32) + def __init__( + self, + model: str, + model_cfg: dict, + horizon: int, + obs_dim: Union[int, SequenceType], + action_dim: Union[int, SequenceType], + n_timesteps: int = 1000, + predict_epsilon: bool = True, + loss_discount: float = 1.0, + clip_denoised: bool = False, + action_weight: float = 1.0, + loss_weights: dict = None, + returns_condition: bool = False, + condition_guidance_w: float = 0.1, + ): + super().__init__(model, model_cfg, horizon, obs_dim, action_dim, n_timesteps, predict_epsilon, + loss_discount, clip_denoised, action_weight, loss_weights,) + self.returns_condition = returns_condition + self.condition_guidance_w = condition_guidance_w - # decay loss with trajectory timestep: discount**t - discounts = discount ** torch.arange(self.horizon, dtype=torch.float) - discounts = discounts / discounts.mean() - loss_weights = torch.einsum('h,t->ht', discounts, dim_weights) - # Cause things are conditioned on t=0 - if self.predict_epsilon: - loss_weights[0, :] = 0 + # def get_loss_weights(self, discount: int): + # self.action_weight = 1 + # dim_weights = torch.ones(self.transition_dim, dtype=torch.float32) - return loss_weights + # # decay loss with trajectory timestep: discount**t + # discounts = discount ** torch.arange(self.horizon, dtype=torch.float) + # discounts = discounts / discounts.mean() + # loss_weights = torch.einsum('h,t->ht', discounts, dim_weights) + # # Cause things are conditioned on t=0 + # if self.predict_epsilon: + # loss_weights[0, :] = 0 + + # return loss_weights def p_mean_variance(self, x, cond, t, epsilon): x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon) @@ -745,7 +774,7 @@ def p_sample_loop(self, shape, cond, sample_fn=None, plan_size=1, **sample_kwarg x = x.gather(1, inds) return x[:,0] - + @torch.no_grad() def conditional_sample(self, cond, horizon=None, **sample_kwargs): device = self.betas.device batch_size = len(cond[0]) @@ -754,13 +783,7 @@ def conditional_sample(self, cond, horizon=None, **sample_kwargs): return self.p_sample_loop(shape, cond, **sample_kwargs) def p_losses(self, x_start, cond, t, returns=None): - noise = torch.randn_like(x_start) - - - batch_size = len(cond[0]) - mask_rand = torch.rand((batch_size,1)) - mask = torch.bernoulli(mask_rand, 0.7).to(returns.device) - returns = returns * mask + noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) x_noisy = apply_conditioning(x_noisy, cond, self.action_dim) @@ -807,14 +830,11 @@ def __init__( 
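        # Annotation: the MLP below is the task/context encoder. It flattens an encoder_horizon-long
        # window of (action, obs, next_obs, reward) transitions into a single vector and maps it to a
        # compact task embedding, which is then used to condition the reward, dynamics and diffusion
        # models; this patch switches its input width from horizon to the new encoder_horizon.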
self.embed = nn.Sequential( nn.Linear((obs_dim * 2 + action_dim + 1) * encoder_horizon, dim * 4), - #nn.Mish(), - Mish(), + Mish(),#nn.Mish(), nn.Linear(dim * 4, dim * 4), - #nn.Mish(), - Mish(), + Mish(),#nn.Mish(), nn.Linear(dim * 4, dim * 4), - #nn.Mish(), - Mish(), + Mish(),#nn.Mish(), nn.Linear(dim * 4, dim) ) @@ -839,7 +859,8 @@ def diffuser_loss(self, x_start, cond, t, returns=None): return self.diffuser.p_losses(x_start, cond, t, returns) def pre_train_loss(self, traj, target, t, cond): - input_emb = traj.reshape(target.shape[0], -1) + encoder_traj = traj[:, :self.encoder_horizon] + input_emb = encoder_traj.reshape(target.shape[0], -1) task_idx = self.embed(input_emb) states = traj[:, :, self.action_dim:self.action_dim + self.obs_dim] @@ -852,7 +873,7 @@ def pre_train_loss(self, traj, target, t, cond): reward_loss, reward_log = self.reward_model.p_losses(input, cond, target_reward, t, task_idx) - task_idxs = task_idx.unsqueeze(1).repeat_interleave(self.encoder_horizon, dim=1) + task_idxs = task_idx.unsqueeze(1).repeat_interleave(self.horizon, dim=1) input = torch.cat([input, task_idxs], dim=-1) input = input.reshape(-1, input.shape[-1]) @@ -862,7 +883,7 @@ def pre_train_loss(self, traj, target, t, cond): return state_loss, reward_loss, reward_log - def get_eval(self, cond, id, batch_size = 1): + def get_eval(self, cond, id = None, batch_size = 1): id = torch.stack(id, dim=0) if batch_size > 1: cond = self.repeat_cond(cond, batch_size) diff --git a/ding/policy/meta_diffuser.py b/ding/policy/meta_diffuser.py index ef3670f321..cabeaa2f5d 100755 --- a/ding/policy/meta_diffuser.py +++ b/ding/policy/meta_diffuser.py @@ -143,7 +143,7 @@ class MDPolicy(Policy): update_target_freq=10, # update weight of target net target_weight=0.995, - value_step=200e3, + value_step=2e3, # dataset weight include returns include_returns=True, @@ -232,31 +232,32 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: pre_traj = torch.cat([acts, obs, rewards, next_obs], dim=-1).to(self._device) target = torch.cat([next_obs, rewards], dim=-1).to(self._device) traj = torch.cat([acts, obs], dim=-1).to(self._device) - pre_traj = pre_traj[:, :self.encoder_len] - target = pre_traj[:, :self.encoder_len] - + batch_size = len(traj) t = torch.randint(0, self.n_timesteps, (batch_size, ), device=traj.device).long() - state_loss, reward_loss, reward_log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) - loss_dict = {'dynamic_loss': state_loss, 'reward_loss': reward_loss} - total_loss = (state_loss + reward_loss) / self.gradient_accumulate_every - total_loss.backward() - - if self.gradient_steps >= self.gradient_accumulate_every: - self._pre_train_optimizer.step() - self._pre_train_optimizer.zero_grad() + if self._forward_learn_cnt < self.value_step: + state_loss, reward_loss, reward_log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) + loss_dict = {'dynamic_loss': state_loss, 'reward_loss': reward_loss} + loss_dict.update(reward_log) + total_loss = (state_loss + reward_loss) / self.gradient_accumulate_every + total_loss.backward() - task_id = self._learn_model.get_task_id(pre_traj) + task_id = self._learn_model.get_task_id(pre_traj[:, :self.encoder_len]) diffuser_loss, a0_loss = self._learn_model.diffuser_loss(traj, conds, t, task_id) loss_dict['diffuser_loss'] = diffuser_loss loss_dict['a0_loss'] = a0_loss diffuser_loss = diffuser_loss / self.gradient_accumulate_every diffuser_loss.backward() - + loss_dict['max_return'] = reward.max().item() + loss_dict['min_return'] = 
reward.min().item() + loss_dict['mean_return'] = reward.mean().item() if self.gradient_steps >= self.gradient_accumulate_every: self._plan_optimizer.step() self._plan_optimizer.zero_grad() + if self._forward_learn_cnt < self.value_step: + self._pre_train_optimizer.step() + self._pre_train_optimizer.zero_grad() self.gradient_steps = 1 else: self.gradient_steps += 1 @@ -289,6 +290,13 @@ def _monitor_vars_learn(self) -> List[str]: 'reward_loss', 'dynamic_loss', 'a0_loss', + 'max_return', + 'min_return', + 'mean_return', + 'mean_pred', + 'max_pred', + 'min_pred', + 'r0_loss', ] def _state_dict_learn(self) -> Dict[str, Any]: @@ -373,8 +381,6 @@ def warm_train(self, id: int): target = torch.cat([next_obs, rewards], dim=-1) batch_size = len(pre_traj) conds = {cond_ids: cond_vals} - pre_traj = pre_traj[:, :self.encoder_len] - target = pre_traj[:, :self.encoder_len] t = torch.randint(0, self.n_timesteps, (batch_size, ), device=pre_traj.device).long() state_loss, reward_loss, log = self._learn_model.pre_train_loss(pre_traj, target, t, conds) @@ -384,7 +390,7 @@ def warm_train(self, id: int): self._pre_train_optimizer.step() self.update_model_average(self._target_model, self._learn_model) - self.test_task_id = [self._target_model.get_task_id(pre_traj)[0]] * self.eval_batch_size + self.test_task_id = [self._target_model.get_task_id(pre_traj[:, :self.encoder_len])[0]] * self.eval_batch_size def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: if self.have_train: diff --git a/ding/policy/prompt_dt.py b/ding/policy/prompt_dt.py index b485e984a7..4306b01b4b 100755 --- a/ding/policy/prompt_dt.py +++ b/ding/policy/prompt_dt.py @@ -18,6 +18,10 @@ class PDTPolicy(DTPolicy): """ def default_model(self) -> Tuple[str, List[str]]: return 'dt', ['ding.model.template.decision_transformer'] + + def _init_learn(self) -> None: + super()._init_learn() + self.need_prompt = self._cfg.need_prompt def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: """ @@ -50,19 +54,28 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: traj_mask = [], [], [], [], [], [], [], [], [], [], [] for d in data: - p, timestep, s, a, r, rtg, mask = d - timesteps.append(timestep) - states.append(s) - actions.append(a) - rewards.append(r) - returns_to_go.append(rtg) - traj_mask.append(mask) - ps, pa, prtg, pt, pm = p - p_s.append(ps) - p_a.append(pa) - p_rtg.append(prtg) - p_mask.append(pm) - p_t.append(pt) + if self.need_prompt: + p, timestep, s, a, r, rtg, mask = d + timesteps.append(timestep) + states.append(s) + actions.append(a) + rewards.append(r) + returns_to_go.append(rtg) + traj_mask.append(mask) + ps, pa, prtg, pt, pm = p + p_s.append(ps) + p_a.append(pa) + p_rtg.append(prtg) + p_mask.append(pm) + p_t.append(pt) + else: + timestep, s, a, r, rtg, mask = d + timesteps.append(timestep) + states.append(s) + actions.append(a) + rewards.append(r) + returns_to_go.append(rtg) + traj_mask.append(mask) timesteps = torch.stack(timesteps, dim=0) states = torch.stack(states, dim=0) @@ -70,46 +83,32 @@ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]: rewards = torch.stack(rewards, dim=0) returns_to_go = torch.stack(returns_to_go, dim=0) traj_mask = torch.stack(traj_mask, dim=0) - p_s = torch.stack(p_s, dim=0) - p_a = torch.stack(p_a, dim=0) - p_rtg = torch.stack(p_rtg, dim=0) - p_mask = torch.stack(p_mask, dim=0) - p_t = torch.stack(p_t, dim=0) - prompt = (p_s, p_a, p_rtg, p_t, p_mask) + if self.need_prompt: + p_s = torch.stack(p_s, dim=0) + p_a = torch.stack(p_a, dim=0) + p_rtg 
= torch.stack(p_rtg, dim=0) + p_mask = torch.stack(p_mask, dim=0) + p_t = torch.stack(p_t, dim=0) + prompt = (p_s, p_a, p_rtg, p_t, p_mask) + else: + prompt = None # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1), # and we need a 3-dim tensor if len(returns_to_go.shape) == 2: returns_to_go = returns_to_go.unsqueeze(-1) - if self._basic_discrete_env: - actions = actions.to(torch.long) - actions = actions.squeeze(-1) - action_target = torch.clone(actions).detach().to(self._device) - - if self._atari_env: - state_preds, action_preds, return_preds = self._learn_model.forward( - timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1, prompt=prompt - ) - else: - state_preds, action_preds, return_preds = self._learn_model.forward( - timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, prompt=prompt - ) + state_preds, action_preds, return_preds = self._learn_model.forward( + timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, prompt=prompt + ) - if self._atari_env: - action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1)) - else: - traj_mask = traj_mask.view(-1, ) + traj_mask = traj_mask.view(-1, ) - # only consider non padded elements - action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0] + # only consider non padded elements + action_preds = action_preds.reshape(-1, self.act_dim)[traj_mask > 0] - if self._cfg.model.continuous: - action_target = action_target.view(-1, self.act_dim)[traj_mask > 0] - action_loss = F.mse_loss(action_preds, action_target) - else: - action_target = action_target.view(-1)[traj_mask > 0] - action_loss = F.cross_entropy(action_preds, action_target) + action_target = actions.reshape(-1, self.act_dim)[traj_mask > 0] + action_loss = F.mse_loss(action_preds, action_target) self._optimizer.zero_grad() action_loss.backward() @@ -132,36 +131,57 @@ def _init_eval(self) -> None: self.test_num = self._cfg.learn.test_num self._eval_model = self._model self.eval_batch_size = self._cfg.evaluator_env_num + self.rtg_target = self._cfg.rtg_target self.task_id = None self.test_task_id = [[] for _ in range(self.eval_batch_size)] self.have_train =False + if self._cfg.model.continuous: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device + ) + else: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device + ) + + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + self.states = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device + ) + self.timesteps = torch.arange( + start=0, end=self.max_eval_ep_len, step=1 + ).repeat(self.eval_batch_size, 1).to(self._device) + self.rewards_to_go = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device + ) def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: - prompt = [] - for i in range(self.eval_batch_size): - prompt.append(self.dataloader.get_prompt(is_test=True, id=self.task_id[i])) + if self.need_prompt: + p_s, p_a, p_rtg, p_t, p_mask = [], [], [], [], [] + for i in range(self.eval_batch_size): + ps, pa, prtg, pt, pm = self.dataloader.get_prompt(is_test=True, id=self.task_id[i]) + p_s.append(ps) + p_a.append(pa) + p_rtg.append(prtg) + p_mask.append(pm) + 
p_t.append(pt) + p_s = torch.stack(p_s, dim=0).to(self._device) + p_a = torch.stack(p_a, dim=0).to(self._device) + p_rtg = torch.stack(p_rtg, dim=0).to(self._device) + p_mask = torch.stack(p_mask, dim=0).to(self._device) + p_t = torch.stack(p_t, dim=0).to(self._device) + prompt = (p_s, p_a, p_rtg, p_t, p_mask) + else: + prompt = None - prompt = torch.tensor(prompt, device=self._device) - data_id = list(data.keys()) self._eval_model.eval() with torch.no_grad(): - if self._atari_env: - states = torch.zeros( - ( - self.eval_batch_size, - self.context_len, - ) + tuple(self.state_dim), - dtype=torch.float32, - device=self._device - ) - timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self._device) - else: - states = torch.zeros( - (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device - ) - timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device) + states = torch.zeros( + (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device + ) + timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device) if not self._cfg.model.continuous: actions = torch.zeros( (self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device @@ -174,10 +194,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device ) for i in data_id: - if self._atari_env: - self.states[i, self.t[i]] = data[i]['obs'].to(self._device) - else: - self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std + self.states[i, self.t[i]] = self.dataloader.normalize(data[i]['obs'], 'obs', self.task_id[i]) self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device) self.rewards_to_go[i, self.t[i]] = self.running_rtg[i] @@ -192,12 +209,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: actions[i] = self.actions[i, :self.context_len] rewards_to_go[i] = self.rewards_to_go[i, :self.context_len] else: - if self._atari_env: - timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones( - (1, 1), dtype=torch.int64 - ).to(self._device) - else: - timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] + timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1] states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1] actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1] rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1] @@ -209,13 +221,9 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: logits = act_preds[:, -1, :] if not self._cfg.model.continuous: - if self._atari_env: - probs = F.softmax(logits, dim=-1) - act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device) - for i in data_id: - act[i] = torch.multinomial(probs[i], num_samples=1) - else: - act = torch.argmax(logits, axis=1).unsqueeze(1) + act = torch.argmax(logits, axis=1).unsqueeze(1) + else: + act = logits for i in data_id: self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when exceed max_t self.t[i] += 1 @@ -226,9 +234,68 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} - + def warm_train(self, id: 
int): + self.task_id = [id] * self.eval_batch_size def _reset_eval(self, data_id: Optional[List[int]] = None) -> None: if self.have_train: if self.task_id is None: - self.task_id = [0] * self.eval_batch_size \ No newline at end of file + self.task_id = [0] * self.eval_batch_size + + if data_id is None: + self.t = [0 for _ in range(self.eval_batch_size)] + self.timesteps = torch.arange( + start=0, end=self.max_eval_ep_len, step=1 + ).repeat(self.eval_batch_size, 1).to(self._device) + if not self._cfg.model.continuous: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device + ) + else: + self.actions = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), + dtype=torch.float32, + device=self._device + ) + if self._atari_env: + self.states = torch.zeros( + ( + self.eval_batch_size, + self.max_eval_ep_len, + ) + tuple(self.state_dim), + dtype=torch.float32, + device=self._device + ) + self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] + else: + self.states = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), + dtype=torch.float32, + device=self._device + ) + self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] + + self.rewards_to_go = torch.zeros( + (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device + ) + else: + for i in data_id: + self.t[i] = 0 + if not self._cfg.model.continuous: + self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self._device) + else: + self.actions[i] = torch.zeros( + (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device + ) + if self._atari_env: + self.states[i] = torch.zeros( + (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device + ) + self.running_rtg[i] = self.rtg_target + else: + self.states[i] = torch.zeros( + (self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device + ) + self.running_rtg[i] = self.rtg_target / self.rtg_scale + self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device) + self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device) diff --git a/ding/torch_utils/network/diffusion.py b/ding/torch_utils/network/diffusion.py index 674e8e5b76..bab53283f1 100755 --- a/ding/torch_utils/network/diffusion.py +++ b/ding/torch_utils/network/diffusion.py @@ -139,7 +139,6 @@ def __init__(self, dim, eps=1e-5) -> None: self.b = nn.Parameter(torch.zeros(1, dim, 1)) def forward(self, x): - print('x.shape:', x.shape) var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.g + self.b @@ -233,6 +232,7 @@ def __init__( self, transition_dim: int, dim: int = 32, + returns_dim: int = 1, dim_mults: SequenceType = [1, 2, 4, 8], returns_condition: bool = False, condition_dropout: float = 0.1, @@ -265,7 +265,7 @@ def __init__( act = Mish()#nn.Mish() self.time_dim = dim - self.returns_dim = dim + self.returns_dim = returns_dim self.time_mlp = nn.Sequential( SinusoidalPosEmb(dim), @@ -329,7 +329,7 @@ def __init__( nn.Conv1d(dim, transition_dim, 1), ) - def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False): + def forward(self, x, cond, time, returns = None, use_dropout: bool = True, force_dropout: bool = False): """ Arguments: x 
(:obj:'tensor'): noise trajectory @@ -388,7 +388,7 @@ def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_d else: return x - def get_pred(self, x, cond, time, returns: bool = None, use_dropout: bool = True, force_dropout: bool = False): + def get_pred(self, x, cond, time, returns = None, use_dropout: bool = True, force_dropout: bool = False): # [batch, horizon, transition ] -> [batch, transition , horizon] x = x.transpose(1, 2) t = self.time_mlp(time) diff --git a/ding/worker/collector/interaction_serial_meta_evaluator.py b/ding/worker/collector/interaction_serial_meta_evaluator.py index 37bd312835..cce9cb6c3b 100755 --- a/ding/worker/collector/interaction_serial_meta_evaluator.py +++ b/ding/worker/collector/interaction_serial_meta_evaluator.py @@ -62,6 +62,7 @@ def eval( force_render: bool = False, policy_kwargs: Optional[Dict] = {}, policy_warm_func: namedtuple = None, + need_reward: bool = False, ) -> Tuple[bool, Dict[str, List]]: infos = defaultdict(list) for i in range(self.test_env_num): @@ -70,7 +71,7 @@ def eval( if policy_warm_func is not None: policy_warm_func(i) info = self.sub_eval(save_ckpt_fn, train_iter, envstep, n_episode, \ - force_render, policy_kwargs, i) + force_render, policy_kwargs, i, need_reward) for key, val in info.items(): if i == 0: info[key] = [] @@ -118,6 +119,7 @@ def sub_eval( force_render: bool = False, policy_kwargs: Optional[Dict] = {}, task_id: int = 0, + need_reward: bool = False, ) -> Tuple[bool, Dict[str, List]]: ''' Overview: @@ -146,11 +148,22 @@ def sub_eval( # force_render overwrite frequency constraint render = force_render or self._should_render(envstep, train_iter) + rewards = None + with self._timer: while not eval_monitor.is_finished(): obs = self._env.ready_obs obs = to_tensor(obs, dtype=torch.float32) + if need_reward: + for id,val in obs.items(): + if rewards is None: + reward = torch.zeros((1)) + else: + reward = torch.tensor(rewards[id], dtype=torch.float32) + obs[id] = {'obs':val, 'reward':reward} + + # update videos if render: eval_monitor.update_video(self._env.ready_imgs) @@ -167,7 +180,9 @@ def sub_eval( actions = to_ndarray(actions) timesteps = self._env.step(actions) timesteps = to_tensor(timesteps, dtype=torch.float32) + rewards = [] for env_id, t in timesteps.items(): + rewards.append(t.reward) if t.info.get('abnormal', False): # If there is an abnormal timestep, reset all the related variables(including this env). 
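The need_reward branch above packages each ready observation together with the reward from the previous step, which is what lets a return-to-go conditioned policy such as Prompt-DT decrement its running RTG during evaluation. Below is a minimal sketch of that data flow; the helper name wrap_obs_with_reward and the dummy tensor shapes are illustrative assumptions, not part of the patch.

import torch

# Illustrative sketch (not part of the patch): pair each env's observation with the
# reward produced by the previous step, as the evaluator's need_reward branch does.
def wrap_obs_with_reward(obs, rewards):
    wrapped = {}
    for env_id, val in obs.items():
        if rewards is None:
            # first step of an episode: no reward has been observed yet
            reward = torch.zeros(1)
        else:
            reward = torch.tensor(rewards[env_id], dtype=torch.float32)
        wrapped[env_id] = {'obs': val, 'reward': reward}
    return wrapped

obs = {0: torch.zeros(17)}                      # one ready env with a 17-dim walker2d observation
data = wrap_obs_with_reward(obs, rewards=None)  # -> {0: {'obs': ..., 'reward': tensor([0.])}}
running_rtg = 400.0                             # e.g. rtg_target / rtg_scale from the config below
running_rtg -= data[0]['reward'].item()         # an RTG-conditioned policy subtracts the reward each step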
self._policy.reset([env_id]) diff --git a/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py index f73e0d854f..6e0b5a7701 100755 --- a/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py +++ b/dizoo/meta_mujoco/config/walker2d_metadiffuser_config.py @@ -18,27 +18,30 @@ horizon=32, obs_dim=17, action_dim=6, - test_num=10, + test_num=1,#10, ), policy=dict( cuda=True, max_len=32, max_ep_len=200, - task_num=40, - train_num=1, + task_num=1,#40, + train_num=1,#30, obs_dim=17, act_dim=6, no_state_normalize=False, no_action_normalize=False, need_init_dataprocess=True, + need_reward=False, model=dict( diffuser_model_cfg=dict( model='DiffusionUNet1d', model_cfg=dict( transition_dim=23, dim=64, + returns_dim=1, dim_mults=[1, 4, 8], returns_condition=True, + condition_dropout=0.3, kernel_size=5, attention=False, ), @@ -47,13 +50,15 @@ action_dim=6, n_timesteps=20, predict_epsilon=False, - condition_guidance_w=1.2, + condition_guidance_w=1.6, + action_weight=10, loss_discount=1, + returns_condition=True, ), reward_cfg=dict( model='TemporalValue', model_cfg=dict( - horizon = 32, + horizon=32, transition_dim=23, dim=64, out_dim=32, @@ -70,8 +75,9 @@ loss_discount=1, ), horizon=32, + encoder_horizon=20, n_guide_steps=2, - scale=0.1, + scale=2, t_stopgrad=2, scale_grad_by_std=True, ), @@ -81,19 +87,20 @@ train_epoch=60000, gradient_accumulate_every=2, batch_size=32, + encoder_len=20, learning_rate=2e-4, discount_factor=0.99, learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )), eval_batch_size=8, warm_batch_size=32, - test_num=10, + test_num=1,#10, plan_batch_size=1, ), collect=dict(data_type='meta_traj', ), eval=dict( evaluator=dict( eval_freq=500, - test_env_num=10, + test_env_num=1, ), test_ret=0.9, ), @@ -105,7 +112,7 @@ context_len=1, stochastic_prompt=False, need_prompt=False, - test_id=[5,10,22,31,18,1,12,9,25,38], + test_id=[0],#[2,12,22,28,31,4,15,10,18,38], cond=True, env_param_path='/mnt/nfs/share/meta/walker/env_walker_param_train_task', need_next_obs=True, diff --git a/dizoo/meta_mujoco/config/walker2d_promptdt_config.py b/dizoo/meta_mujoco/config/walker2d_promptdt_config.py index 2c25d50e32..24ba95bc7c 100755 --- a/dizoo/meta_mujoco/config/walker2d_promptdt_config.py +++ b/dizoo/meta_mujoco/config/walker2d_promptdt_config.py @@ -26,8 +26,8 @@ rtg_scale=1, context_len=1, stochastic_prompt=False, - need_prompt=True, - test_id=[1],#[5,10,22,31,18,1,12,9,25,38], + need_prompt=False, + test_id=[0],#[2,12,22,28,31,4,15,10,18,38], cond=False, env_param_path='/mnt/nfs/share/meta/walker/env_walker_param_train_task', need_next_obs=False, @@ -37,17 +37,19 @@ stop_value=5000, max_len=20, max_ep_len=200, - task_num=3, - train_num=1, + task_num=1,#40, + train_num=1,#30, obs_dim=17, act_dim=6, + need_prompt=False, state_mean=None, state_std=None, no_state_normalize=False, no_action_normalize=True, need_init_dataprocess=True, + need_reward=True, evaluator_env_num=8, - rtg_target=5000, # max target return to go + rtg_target=400, # max target return to go max_eval_ep_len=1000, # max lenght of one episode wt_decay=1e-4, warmup_steps=10000, @@ -63,7 +65,7 @@ n_heads=1, drop_p=0.1, continuous=True, - use_prompt=True, + use_prompt=False, ), batch_size=32, learning_rate=1e-4, From 35e8e77b3ac6e01e51245843ea302ca3d22364bb Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Tue, 6 Feb 2024 18:42:29 +0800 Subject: [PATCH 15/16] change --- ding/entry/serial_entry_meta_offline.py | 5 +- ding/model/template/decision_transformer.py | 
4 - ding/model/template/diffusion.py | 97 +++++++++++++++---- ding/torch_utils/network/diffusion.py | 24 ++--- .../interaction_serial_meta_evaluator.py | 2 +- 5 files changed, 88 insertions(+), 44 deletions(-) diff --git a/ding/entry/serial_entry_meta_offline.py b/ding/entry/serial_entry_meta_offline.py index 1759e5a9c1..46a59e0845 100755 --- a/ding/entry/serial_entry_meta_offline.py +++ b/ding/entry/serial_entry_meta_offline.py @@ -23,7 +23,8 @@ def serial_pipeline_meta_offline( ) -> 'Policy': # noqa """ Overview: - Serial pipeline entry. + Serial pipeline entry. In the meta pipeline, the policy is trained on multiple tasks \ + and evaluated on the specified test tasks. The evaluation value is the mean over all tasks. Arguments: - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ ``str`` type means config file path. \ @@ -59,7 +60,6 @@ def serial_pipeline_meta_offline( # use the original batch size per gpu and increase learning rate # correspondingly. cfg.policy.learn.batch_size // get_world_size(), - # cfg.policy.learn.batch_size shuffle=shuffle, sampler=sampler, collate_fn=lambda x: x, @@ -96,6 +96,7 @@ def serial_pipeline_meta_offline( for epoch in range(cfg.policy.learn.train_epoch): if get_world_size() > 1: dataloader.sampler.set_epoch(epoch) + # for every training task, train the policy on its own dataset for i in range(cfg.policy.train_num): dataset.set_task_id(i) for train_data in dataloader: diff --git a/ding/model/template/decision_transformer.py b/ding/model/template/decision_transformer.py index 9c6ed8019e..d1cb9133a1 100755 --- a/ding/model/template/decision_transformer.py +++ b/ding/model/template/decision_transformer.py @@ -331,11 +331,7 @@ def forward( (prompt_returns_embeddings, prompt_state_embeddings, prompt_action_embeddings), dim=1 ).permute(0, 2, 1, 3).reshape(prompt_states.shape[0], 3 * prompt_seq_length, self.h_dim) - # prompt_stacked_attention_mask = torch.stack( - # (prompt_attention_mask, prompt_attention_mask, prompt_attention_mask), dim=1 - # ).permute(0, 2, 1).reshape(prompt_states.shape[0], 3 * prompt_seq_length h = torch.cat((prompt_stacked_inputs, h), dim=1) - # stacked_attention_mask = torch.cat((prompt_stacked_attention_mask, stacked_attention_mask), dim=1) # transformer and prediction diff --git a/ding/model/template/diffusion.py b/ding/model/template/diffusion.py index ff703a92cb..0d06f94796 100755 --- a/ding/model/template/diffusion.py +++ b/ding/model/template/diffusion.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType from ding.torch_utils.network.diffusion import extract, cosine_beta_schedule, apply_conditioning, \ - DiffusionUNet1d, TemporalValue, Mish + DiffusionUNet1d, TemporalValue Sample = namedtuple('Sample', 'trajectories values chains') @@ -56,6 +56,15 @@ def n_step_guided_p_sample( n_guide_steps=1, scale_grad_by_std=True, ): + """ + Overview: + Guidance function for guided diffusion sampling. + Arguments: + - model (obj: 'class') diffusion model + - x (obj: 'tensor') input for guidance + - cond (obj: 'tensor') condition of the input + - guide (obj: 'class') guide function + """ model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) model_std = torch.exp(0.5 * model_log_variance) model_var = torch.exp(model_log_variance) @@ -94,6 +103,18 @@ def free_guidance_sample( scale_grad_by_std=True, ): + """ + Overview: + Guidance function for MetaDiffuser. + Arguments: + - model (obj: 'class') diffusion model + - x (obj: 'tensor') input for guidance + - cond (obj: 'tensor') condition of the input + - 
guide1 (obj: 'class') guide function. In MetaDiffuser this is the reward model + - guide2 (obj: 'class') guide function. In MetaDiffuser this is the dynamics model + - returns (obj: 'tensor') for MetaDiffuser, this is the task id. + + """ weight = extract(model.sqrt_one_minus_alphas_cumprod, t, x.shape) model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape) model_std = torch.exp(0.5 * model_log_variance) @@ -706,6 +727,24 @@ def forward(self, cond, *args, **kwargs): return self.conditional_sample(cond=cond, *args, **kwargs) class GuidenceFreeDifffuser(GaussianDiffusion): + """ + Overview: + Gaussian diffusion model with guidance + Arguments: + - model (:obj:`str`): type of model + - model_cfg (:obj:'dict') config of model + - horizon (:obj:`int`): horizon of trajectory + - obs_dim (:obj:`int`): Dim of the observation + - action_dim (:obj:`int`): Dim of the action + - n_timesteps (:obj:`int`): Number of timesteps + - predict_epsilon (:obj:'bool'): Whether to predict epsilon + - loss_discount (:obj:'float'): discount factor of the loss + - clip_denoised (:obj:'bool'): Whether to use clip_denoised + - action_weight (:obj:'float'): loss weight of the action + - loss_weights (:obj:'dict'): weights of the loss + - returns_condition (:obj:'bool') whether to condition on returns + - condition_guidance_w (:obj:'float') guidance weight + """ def __init__( self, @@ -728,20 +767,6 @@ def __init__( self.returns_condition = returns_condition self.condition_guidance_w = condition_guidance_w - # def get_loss_weights(self, discount: int): - # self.action_weight = 1 - # dim_weights = torch.ones(self.transition_dim, dtype=torch.float32) - - # # decay loss with trajectory timestep: discount**t - # discounts = discount ** torch.arange(self.horizon, dtype=torch.float) - # discounts = discounts / discounts.mean() - # loss_weights = torch.einsum('h,t->ht', discounts, dim_weights) - # # Cause things are conditioned on t=0 - # if self.predict_epsilon: - # loss_weights[0, :] = 0 - - # return loss_weights - def p_mean_variance(self, x, cond, t, epsilon): x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon) @@ -808,7 +833,19 @@ def p_losses(self, x_start, cond, t, returns=None): @MODEL_REGISTRY.register('metadiffuser') class MetaDiffuser(nn.Module): - + """ + Overview: + MetaDiffuser model + Arguments: + - dim (:obj:`int`): dim of the embedding and dynamics model + - obs_dim (:obj:`int`): Dim of the observation + - action_dim (:obj:`int`): Dim of the action + - reward_cfg (:obj:'dict') config of reward model + - diffuser_model_cfg (:obj:'dict') config of diffuser_model + - horizon (:obj:`int`): horizon of trajectory + - encoder_horizon (:obj:`int`): horizon of the embedding model + - sample_kwargs : config of sample function + """ def __init__( self, dim: int, @@ -830,11 +867,11 @@ def __init__( self.embed = nn.Sequential( nn.Linear((obs_dim * 2 + action_dim + 1) * encoder_horizon, dim * 4), - Mish(),#nn.Mish(), + nn.Mish(), nn.Linear(dim * 4, dim * 4), - Mish(),#nn.Mish(), + nn.Mish(), nn.Linear(dim * 4, dim * 4), - Mish(),#nn.Mish(), + nn.Mish(), nn.Linear(dim * 4, dim) ) @@ -851,6 +888,12 @@ def __init__( self.diffuser = GuidenceFreeDifffuser(**diffuser_model_cfg) def get_task_id(self, traj): + """ + Overview: + Get the task id for a trajectory + Arguments: + - traj (:obj:'tensor') trajectory from the environment + """ input_emb = traj.reshape(traj.shape[0], -1) task_idx = self.embed(input_emb) return task_idx @@ -859,6 +902,15 @@ def get_task_id(self, traj): def diffuser_loss(self, x_start, cond, t, returns=None): return self.diffuser.p_losses(x_start, cond, t, returns) def
pre_train_loss(self, traj, target, t, cond): + """ + Overview: + Train the dynamics, reward and embedding models. + Arguments: + - traj (:obj:'tensor') trajectory from the dataset, including: obs, next_obs, action, reward + - target (:obj:'tensor') target obs and reward + - t (:obj:'int') step + - cond (:obj:'tensor') condition of the input + """ encoder_traj = traj[:, :self.encoder_horizon] input_emb = encoder_traj.reshape(target.shape[0], -1) task_idx = self.embed(input_emb) @@ -884,6 +936,13 @@ def pre_train_loss(self, traj, target, t, cond): return state_loss, reward_loss, reward_log def get_eval(self, cond, id = None, batch_size = 1): + """ + Overview: + Get the action for evaluation + Arguments: + - cond (:obj:'tensor') condition for sampling + - id (:obj:'tensor') task id. + """ id = torch.stack(id, dim=0) if batch_size > 1: cond = self.repeat_cond(cond, batch_size) diff --git a/ding/torch_utils/network/diffusion.py b/ding/torch_utils/network/diffusion.py index bab53283f1..c26e3f240d 100755 --- a/ding/torch_utils/network/diffusion.py +++ b/ding/torch_utils/network/diffusion.py @@ -44,14 +44,6 @@ def apply_conditioning(x, conditions, action_dim, mask = None): x[:, t, action_dim:] = val.clone() return x -class Mish(nn.Module): - def __init__(self): - super().__init__() - - def forward(self,x): - x = x * (torch.tanh(F.softplus(x))) - return x - class DiffusionConv1d(nn.Module): def __init__( @@ -203,7 +195,7 @@ def __init__( ) -> None: super().__init__() if mish: - act = Mish()#nn.Mish() + act = nn.Mish() else: act = nn.SiLU() self.blocks = nn.ModuleList( @@ -262,7 +254,7 @@ def __init__( act = nn.SiLU() else: mish = True - act = Mish()#nn.Mish() + act = nn.Mish() self.time_dim = dim self.returns_dim = returns_dim @@ -460,8 +452,7 @@ def __init__( self.time_mlp = nn.Sequential( SinusoidalPosEmb(dim), nn.Linear(dim, dim * 4), - #nn.Mish(), - Mish(), + nn.Mish(), nn.Linear(dim * 4, dim), ) if returns_condition: @@ -470,15 +461,13 @@ def __init__( self.returns_mlp = nn.Sequential( SinusoidalPosEmb(dim), nn.Linear(dim, dim * 4), - #nn.Mish(), - Mish(), + nn.Mish(), nn.Linear(dim * 4, dim), ) else: self.returns_mlp = nn.Sequential( nn.Linear(dim, dim * 4), - #nn.Mish(), - Mish(), + nn.Mish(), nn.Linear(dim * 4, dim), ) self.blocks = nn.ModuleList([]) @@ -511,8 +500,7 @@ def __init__( fc_dim = mid_dim_3 * max(horizon, 1) self.final_block = nn.Sequential( nn.Linear(fc_dim + time_dim, fc_dim // 2), - #nn.Mish(), - Mish(), + nn.Mish(), nn.Linear(fc_dim // 2, out_dim), ) diff --git a/ding/worker/collector/interaction_serial_meta_evaluator.py b/ding/worker/collector/interaction_serial_meta_evaluator.py index cce9cb6c3b..0cc3beb738 100755 --- a/ding/worker/collector/interaction_serial_meta_evaluator.py +++ b/ding/worker/collector/interaction_serial_meta_evaluator.py @@ -19,7 +19,7 @@ class InteractionSerialMetaEvaluator(InteractionSerialEvaluator): Interaction serial evaluator class, policy interacts with env. This class evaluator algorithm with test environment list.
Interfaces: - __init__, reset, reset_policy, reset_env, close, should_eval, eval + ``__init__``, reset, reset_policy, reset_env, close, should_eval, eval Property: env, policy """ From 9b611db4037c34cb3dde21a1c14b51111a2b25fa Mon Sep 17 00:00:00 2001 From: Super1ce <278042904@qq.com> Date: Wed, 7 Feb 2024 17:21:54 +0800 Subject: [PATCH 16/16] add notion --- ding/utils/data/dataset.py | 102 +++++++++++++++++++++++++------------ 1 file changed, 69 insertions(+), 33 deletions(-) diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index c006d68c51..aa6a16368f 100755 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -1104,6 +1104,28 @@ def __getitem__(self, idx, eps=1e-4): @DATASET_REGISTRY.register('meta_traj') class MetaTraj(Dataset): + """ + Overview: + Dataset for meta policies + Arguments: + - cfg (:obj:'dict'): config of the policy + Key: + - dataset.data_dir_prefix (:obj:'str'): dataset path + - dataset.env_param_path (:obj:'str'): environment params path + - dataset.rtg_scale (:obj:'float'): return to go scale + - dataset.context_len (:obj:'int'): context len + - no_state_normalize (:obj:'bool'): whether to normalize states + - no_action_normalize (:obj:'bool'): whether to normalize actions + - task_num (:obj:'int'): number of meta tasks + - policy.max_len (:obj:'int'): max len of trajectory + - dataset.stochastic_prompt (:obj:'bool'): if True, select a random prompt; otherwise select the max-return prompt + - dataset.need_prompt (:obj:'bool'): whether a prompt is needed + - dataset.test_id (:obj:'list'): ids of the test environments + - dataset.need_next_obs (:obj:'bool'): whether next_obs is needed; if so, traj len = max_len + 1 + - dataset.cond (:obj:'bool'): whether to add the condition + Returns: + A trajectory dataset for the meta policy. + """ def __init__(self, cfg): dataset_path = cfg.dataset.data_dir_prefix env_param_path = cfg.dataset.env_param_path @@ -1160,8 +1182,6 @@ def __init__(self, cfg): id = 0 for file_path in file_paths: - if self.need_prompt: - returns = [] with h5py.File(file_path, 'r') as hf: N = hf['rewards'].shape[0] path = [] @@ -1169,32 +1189,51 @@ def __init__(self, cfg): for k in ['obs', 'actions', 'rewards', 'terminals', 'mask']: data_[k].append(hf[k][i]) path.append(data_) - if self.need_prompt: - returns.append(hf['returns'][0][i]) data_ = collections.defaultdict(list) + + if self.need_prompt: + returns = np.sum(np.array(hf['rewards']), axis=1) state_mean, state_std = hf['state_mean'][:], hf['state_std'][:] if not self.no_action_normalize: action_mean, action_std = hf['action_mean'][:], hf['action_std'][:] - if id not in self.test_id: - self.traj.append(path) - self.state_means.append(state_mean) - self.state_stds.append(state_std) - if not self.no_action_normalize: - self.action_means.append(action_mean) - self.action_stds.append(action_std) - if self.need_prompt: - self.returns.append(returns) - else: - self.test_traj.append(path) - self.test_state_means.append(state_mean) - self.test_state_stds.append(state_std) - if not self.no_action_normalize: - self.test_action_means.append(action_mean) - self.test_action_stds.append(action_std) - if self.need_prompt: - self.test_returns.append(returns) + # if id not in self.test_id: + # self.traj.append(path) + # self.state_means.append(state_mean) + # self.state_stds.append(state_std) + # if not self.no_action_normalize: + # self.action_means.append(action_mean) + # self.action_stds.append(action_std) + # if self.need_prompt: + # self.returns.append(returns) + # else: + # self.test_traj.append(path) + # 
self.test_state_means.append(state_mean) + # self.test_state_stds.append(state_std) + # if not self.no_action_normalize: + # self.test_action_means.append(action_mean) + # self.test_action_stds.append(action_std) + # if self.need_prompt: + # self.test_returns.append(returns) + + self.traj.append(path) + self.state_means.append(state_mean) + self.state_stds.append(state_std) + if not self.no_action_normalize: + self.action_means.append(action_mean) + self.action_stds.append(action_std) + if self.need_prompt: + self.returns.append(returns) + + self.test_traj.append(path) + self.test_state_means.append(state_mean) + self.test_state_stds.append(state_std) + if not self.no_action_normalize: + self.test_action_means.append(action_mean) + self.test_action_stds.append(action_std) + if self.need_prompt: + self.test_returns.append(returns) id += 1 self.params = [] @@ -1205,17 +1244,15 @@ def __init__(self, cfg): if self.need_prompt: self.prompt_trajectories = [] for i in range(len(self.traj)): - idx = np.argsort(self.returns) # lowest to highest + idx = np.argsort(self.returns[i]) # lowest to highest # set 10% highest traj as prompt idx = idx[-(len(self.traj[i]) // 20) : ] - self.prompt_trajectories.append(np.array(self.traj[i])[idx]) self.test_prompt_trajectories = [] for i in range(len(self.test_traj)): - idx = np.argsort(self.test_returns) + idx = np.argsort(self.test_returns[i]) idx = idx[-(len(self.test_traj[i]) // 20) : ] - self.test_prompt_trajectories.append(np.array(self.test_traj[i])[idx]) self.set_task_id(0) @@ -1226,7 +1263,7 @@ def __len__(self): def get_prompt(self, sample_size=1, is_test=False, id=0): if not is_test: batch_inds = np.random.choice( - np.arange(len(self.prompt_trajectories[self.task_id])), + np.arange(len(self.prompt_trajectories[id])), size=sample_size, replace=True, # p=p_sample, # reweights so we sample according to timesteps @@ -1242,13 +1279,13 @@ def get_prompt(self, sample_size=1, is_test=False, id=0): ) prompt_trajectories = self.test_prompt_trajectories[id] sorted_inds = np.argsort(self.test_returns[id]) - + if self.stochastic_prompt: - traj = prompt_trajectories[batch_inds[sample_size]][0] # random select traj + traj = prompt_trajectories[batch_inds[sample_size]][0,0] # random select traj else: - traj = prompt_trajectories[sorted_inds[-sample_size]][0] # select the best traj with highest rewards + traj = prompt_trajectories[sorted_inds[-sample_size]][0,0] # select the best traj with highest rewards # traj = prompt_trajectories[i] - si = max(0, traj['rewards'][0].shape[1] - self.max_len -1) # select the last traj with length max_len + si = max(0, len(traj['rewards'][0]) - self.max_len -1) # select the last traj with length max_len # get sequences from dataset @@ -1299,7 +1336,6 @@ def unnormalize(self, data: np.array, type: str, task_id: int): # get warm start data def get_pretrain_data(self, task_id: int, batch_size: int): # get warm data - print('task_id:',task_id) trajs = self.test_traj[task_id] batch_idx = np.random.choice( np.arange(len(trajs)), @@ -1383,7 +1419,7 @@ def __getitem__(self, index): mask = torch.from_numpy(mask).to(dtype=torch.long) if self.need_prompt: - prompt = self.get_prompt(self.task_id) + prompt = self.get_prompt(id=self.task_id) return prompt, timesteps, s, a, r, rtg, mask elif self.cond: cond_id = 0