diff --git a/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py b/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py index d29da269b..7afec16a5 100644 --- a/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py +++ b/opensora/models/diffusion/diffusion/gaussian_diffusion_t2v.py @@ -19,6 +19,12 @@ def mean_flat(tensor): """ return tensor.mean(dim=list(range(1, len(tensor.shape)))) +def mean_flat_reweight(tensor, weights): + """ + Take the mean over all non-batch dimensions, then multiply the per-sample means by `weights` (broadcast against the 1-D per-sample result). + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) * weights + class ModelMeanType(enum.Enum): """ @@ -726,7 +732,7 @@ def _vb_terms_bpd( output = th.where((t == 0), decoder_nll, kl) return {"output": output, "pred_xstart": out["pred_xstart"]} - def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mse_loss_weights=None): """ Compute training losses for a single timestep. :param model: the model to evaluate loss on. 
@@ -801,7 +807,10 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): ModelMeanType.EPSILON: noise, }[self.model_mean_type] assert model_output.shape == target.shape == x_start.shape - terms["mse"] = mean_flat((target - model_output) ** 2) + if mse_loss_weights is not None: + terms["mse"] = mean_flat_reweight((target - model_output) ** 2, mse_loss_weights) + else: + terms["mse"] = mean_flat((target - model_output) ** 2) if "vb" in terms: terms["loss"] = terms["mse"] + terms["vb"] else: diff --git a/opensora/train/train_t2v_t5_feature.py b/opensora/train/train_t2v_t5_feature.py index a8f3543e0..e5420b347 100644 --- a/opensora/train/train_t2v_t5_feature.py +++ b/opensora/train/train_t2v_t5_feature.py @@ -98,6 +98,30 @@ def generate_timestep_weights(args, num_timesteps): return weights +def compute_snr(timesteps, alphas_cumprod): + """ + Computes the per-timestep signal-to-noise ratio alphas_cumprod[t] / (1 - alphas_cumprod[t]) for each t in `timesteps` (an integer index tensor); `alphas_cumprod` may be a torch tensor or anything `torch.as_tensor` accepts, e.g. a NumPy array. Returns a float32 tensor shaped like `timesteps`, on its device. As per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + + sqrt_alphas_cumprod = torch.as_tensor(alphas_cumprod) ** 0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - torch.as_tensor(alphas_cumprod)) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. 
+ snr = (alpha / sigma) ** 2 + return snr + ################################################################################# # Training Loop # @@ -456,7 +480,16 @@ def load_model_hook(models, input_dir): model_kwargs = dict(encoder_hidden_states=cond, attention_mask=attn_mask, encoder_attention_mask=cond_mask, use_image_num=args.use_image_num) t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=accelerator.device) - loss_dict = diffusion.training_losses(model, x, t, model_kwargs) + if args.snr_gamma is not None: + snr = compute_snr(t, diffusion.alphas_cumprod) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(t)], dim=1).min(dim=1)[0] / snr + ) + + loss_dict = diffusion.training_losses(model, x, t, model_kwargs, mse_loss_weights=mse_loss_weights) + else: + loss_dict = diffusion.training_losses(model, x, t, model_kwargs) + loss = loss_dict["loss"].mean() # Gather the losses across all processes for logging (if we use distributed training). @@ -746,7 +779,7 @@ def load_model_hook(models, input_dir): parser.add_argument( "--snr_gamma", type=float, - default=None, + default=5.0, help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " "More details here: https://arxiv.org/abs/2303.09556.", )