From e98170694784b5073b39ac1872d94c13d6088ff6 Mon Sep 17 00:00:00 2001
From: FeSens
Date: Fri, 23 Feb 2024 00:08:51 -0300
Subject: [PATCH 1/5] add frame dimension

---
 models.py | 32 +++++++++++++++++++-------------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/models.py b/models.py
index c90eeba7..186c51e8 100644
--- a/models.py
+++ b/models.py
@@ -144,11 +144,12 @@ def forward(self, x, c):
 
 class DiT(nn.Module):
     """
-    Diffusion model with a Transformer backbone.
-    """
+    Diffusion model with a Transformer backbone. 
+    """ 
     def __init__(
         self,
         input_size=32,
+        num_frames=16,  # Number of frames in each video sample
         patch_size=2,
         in_channels=4,
         hidden_size=1152,
@@ -165,13 +166,14 @@ def __init__(
         self.out_channels = in_channels * 2 if learn_sigma else in_channels
         self.patch_size = patch_size
         self.num_heads = num_heads
+        self.num_frames = num_frames
 
         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
         self.t_embedder = TimestepEmbedder(hidden_size)
         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
         num_patches = self.x_embedder.num_patches
         # Will use fixed sin-cos embedding:
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches * num_frames, hidden_size), requires_grad=False)
 
         self.blocks = nn.ModuleList([
             DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
@@ -189,7 +191,7 @@ def _basic_init(module):
         self.apply(_basic_init)
 
         # Initialize (and freeze) pos_embed by sin-cos embedding:
-        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
+        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int((self.x_embedder.num_patches * self.num_frames) ** 0.5))
         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
 
         # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
@@ -217,33 +219,37 @@ def _basic_init(module):
 
     def unpatchify(self, x):
         """
-        x: (N, T, patch_size**2 * C)
+        x: (N, T * num_frames, patch_size**2 * C)
         imgs: (N, H, W, C)
         """
         c = self.out_channels
         p = self.x_embedder.patch_size[0]
-        h = w = int(x.shape[1] ** 0.5)
-        assert h * w == x.shape[1]
+        h = w = int((x.shape[1] / self.num_frames) ** 0.5)
+        assert h * w == (x.shape[1] / self.num_frames)
 
-        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
-        x = torch.einsum('nhwpqc->nchpwq', x)
-        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
+        x = x.reshape(shape=(x.shape[0], self.num_frames, h, w, p, p, c))
+        x = torch.einsum('nfhwpqc->nfchpwq', x)
+        imgs = x.reshape(shape=(x.shape[0], self.num_frames, c, h * p, h * p))
         return imgs
 
     def forward(self, x, t, y):
         """
         Forward pass of DiT. 
-        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+        x: (N, F, C, H, W) tensor of spatial inputs (images or latent representations of images)
         t: (N,) tensor of diffusion timesteps
         y: (N,) tensor of class labels
         """
-        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
+        N, F, C, H, W = x.shape
+
+        x = self.x_embedder(x.reshape(N*F, C, H, W)).reshape(N, int(H * W / self.patch_size ** 2) * F, -1) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
+        # x should have 4096 tokens for 256x256 images, and 16 frames
+
         t = self.t_embedder(t)                   # (N, D)
         y = self.y_embedder(y, self.training)    # (N, D)
         c = t + y                                # (N, D)
         for block in self.blocks:
             x = block(x, c)                      # (N, T, D)
-        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
+        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
         x = self.unpatchify(x)                   # (N, out_channels, H, W)
         return x
 
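
Note on PATCH 1: with the repo defaults (input_size=32, patch_size=2, num_frames=16) the model sees (32/2)**2 = 256 patches per frame and 256 * 16 = 4096 tokens per sample, and 4096 happens to be a perfect square (64 * 64), so get_2d_sincos_pos_embed still receives a valid grid side. Any other configuration must keep num_patches * num_frames a perfect square, or the pos_embed copy will fail with a shape mismatch. A minimal standalone check (illustrative Python, not part of the patch):

# Sanity check of the token / pos_embed bookkeeping introduced in PATCH 1.
# Assumes the repo defaults; the variable names are illustrative only.
input_size, patch_size, num_frames = 32, 2, 16

num_patches = (input_size // patch_size) ** 2  # patches per frame: 256
num_tokens = num_patches * num_frames          # tokens per sample: 4096

grid_size = int(num_tokens ** 0.5)             # side length passed to get_2d_sincos_pos_embed
assert grid_size * grid_size == num_tokens, \
    "num_patches * num_frames must be a perfect square for the 2D sin-cos grid"
print(num_patches, num_tokens, grid_size)      # 256 4096 64
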
From 28d47c947e31b28e8bc8f98340decb72a93375c5 Mon Sep 17 00:00:00 2001
From: FeSens
Date: Fri, 23 Feb 2024 00:11:46 -0300
Subject: [PATCH 2/5] fix trailing spaces

---
 models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/models.py b/models.py
index 186c51e8..634339a7 100644
--- a/models.py
+++ b/models.py
@@ -144,8 +144,8 @@ def forward(self, x, c):
 
 class DiT(nn.Module):
     """
-    Diffusion model with a Transformer backbone. 
-    """ 
+    Diffusion model with a Transformer backbone.
+    """
     def __init__(
         self,
         input_size=32,

From 3a8432d0ce0af2f3a38a5702489820dc15d8c268 Mon Sep 17 00:00:00 2001
From: FeSens
Date: Fri, 23 Feb 2024 00:15:55 -0300
Subject: [PATCH 3/5] remove comments

---
 models.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/models.py b/models.py
index 634339a7..3e1d115a 100644
--- a/models.py
+++ b/models.py
@@ -242,7 +242,6 @@ def forward(self, x, t, y):
         N, F, C, H, W = x.shape
 
         x = self.x_embedder(x.reshape(N*F, C, H, W)).reshape(N, int(H * W / self.patch_size ** 2) * F, -1) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
-        # x should have 4096 tokens for 256x256 images, and 16 frames
 
         t = self.t_embedder(t)                   # (N, D)
         y = self.y_embedder(y, self.training)    # (N, D)

From 03f26be8ac43f7990b48f352ad9f11bcb04b3ff1 Mon Sep 17 00:00:00 2001
From: FeSens
Date: Wed, 28 Feb 2024 01:09:25 -0300
Subject: [PATCH 4/5] add time dimension in a simple way

---
 .gitignore |  1 +
 models.py  | 22 ++++++++++++++--------
 2 files changed, 15 insertions(+), 8 deletions(-)
 create mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..5e512399
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+ucf101/
\ No newline at end of file

diff --git a/models.py b/models.py
index 3e1d115a..12877978 100644
--- a/models.py
+++ b/models.py
@@ -14,7 +14,7 @@
 import numpy as np
 import math
 from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
-
+from einops import rearrange
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -219,17 +219,17 @@ def _basic_init(module):
 
     def unpatchify(self, x):
         """
-        x: (N, T * num_frames, patch_size**2 * C)
+        x: (N, T, patch_size**2 * C)
         imgs: (N, H, W, C)
         """
         c = self.out_channels
         p = self.x_embedder.patch_size[0]
-        h = w = int((x.shape[1] / self.num_frames) ** 0.5)
-        assert h * w == (x.shape[1] / self.num_frames)
+        h = w = int(x.shape[1] ** 0.5)
+        assert h * w == x.shape[1]
 
-        x = x.reshape(shape=(x.shape[0], self.num_frames, h, w, p, p, c))
-        x = torch.einsum('nfhwpqc->nfchpwq', x)
-        imgs = x.reshape(shape=(x.shape[0], self.num_frames, c, h * p, h * p))
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+        x = torch.einsum('nhwpqc->nchpwq', x)
+        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
         return imgs
 
     def forward(self, x, t, y):
@@ -241,15 +241,21 @@ def forward(self, x, t, y):
         """
         N, F, C, H, W = x.shape
 
-        x = self.x_embedder(x.reshape(N*F, C, H, W)).reshape(N, int(H * W / self.patch_size ** 2) * F, -1) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
+        x = rearrange(x, 'n f c h w -> (n f) c h w')  # Put frames in batch dimension
+        x = self.x_embedder(x)  # (N, T, D), where T = H * W / patch_size ** 2
+        x = rearrange(x, '(n f) t d -> n (f t) d', f=F) + self.pos_embed  # unpack frames from batch dimension and put them in token dimension
 
         t = self.t_embedder(t)                   # (N, D)
         y = self.y_embedder(y, self.training)    # (N, D)
         c = t + y                                # (N, D)
         for block in self.blocks:
             x = block(x, c)                      # (N, T, D)
         x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
+
+        x = rearrange(x, 'n (f t) l -> (n f) t l', f=F)  # Put frames in batch dimension again to unpatchify
         x = self.unpatchify(x)                   # (N, out_channels, H, W)
+        x = rearrange(x, '(n f) c h w -> n f c h w', f=F)  # after unpatchify, move frames back out of the batch dimension
+
         return x
 
     def forward_with_cfg(self, x, t, y, cfg_scale):
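
Note on PATCH 4: the refactor folds frames into the batch dimension around the stock PatchEmbed and unpatchify, then moves them into the token dimension so attention runs over all frames jointly. The fold/unfold round trip can be verified in isolation; a minimal sketch (standalone, with a flatten/transpose stand-in for PatchEmbed, since the real module is not needed for the shape logic):

import torch
from einops import rearrange

N, F, C, H, W = 2, 16, 4, 32, 32
x = torch.randn(N, F, C, H, W)

# Fold frames into the batch dimension, as the patch does before PatchEmbed.
folded = rearrange(x, 'n f c h w -> (n f) c h w')        # (32, 4, 32, 32)

# Stand-in for PatchEmbed: any (N*F, C, H, W) -> (N*F, T, D) map works here.
tokens = folded.flatten(2).transpose(1, 2)               # (32, 1024, 4)

# Unfold frames out of the batch dimension into the token dimension.
seq = rearrange(tokens, '(n f) t d -> n (f t) d', f=F)   # (2, 16384, 4)

# The inverse used before unpatchify recovers the per-frame layout exactly.
back = rearrange(seq, 'n (f t) d -> (n f) t d', f=F)
assert torch.equal(back, tokens)
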

From f3f06a34ff5ff4353824f5b93b1edca58fe3ed77 Mon Sep 17 00:00:00 2001
From: FeSens
Date: Wed, 28 Feb 2024 01:10:08 -0300
Subject: [PATCH 5/5] add time dimension to training

---
 diffusion/gaussian_diffusion.py |  8 ++++----
 train.py                        | 40 ++++++++++++++++++++++++----------
 2 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/diffusion/gaussian_diffusion.py b/diffusion/gaussian_diffusion.py
index ccbcefec..ad604f71 100644
--- a/diffusion/gaussian_diffusion.py
+++ b/diffusion/gaussian_diffusion.py
@@ -750,9 +750,9 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
                 ModelVarType.LEARNED,
                 ModelVarType.LEARNED_RANGE,
             ]:
-                B, C = x_t.shape[:2]
-                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
-                model_output, model_var_values = th.split(model_output, C, dim=1)
+                B, F, C = x_t.shape[:3]
+                assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
+                model_output, model_var_values = th.split(model_output, C, dim=2)
                 # Learn the variance using the variational bound, but don't let
                 # it affect our mean prediction.
-                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+                frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)  # channel axis is dim=2 now that frames occupy dim=1
diff --git a/train.py b/train.py
index 7cfee808..82720d7d 100644
--- a/train.py
+++ b/train.py
@@ -30,7 +30,8 @@
 from models import DiT_models
 from diffusion import create_diffusion
 from diffusers.models import AutoencoderKL
-
+from einops import rearrange
+import torchvision
 
 #################################################################################
 #                             Training Helper Functions                         #
@@ -155,13 +156,28 @@ def main(args):
     opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
 
     # Setup data:
-    transform = transforms.Compose([
-        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
-        transforms.RandomHorizontalFlip(),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
-    ])
-    dataset = ImageFolder(args.data_path, transform=transform)
+    # transform = transforms.Compose([
+    #     transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
+    #     transforms.RandomHorizontalFlip(),
+    #     transforms.ToTensor(),
+    #     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+    # ])
+
+    transform = torchvision.transforms.Compose([
+        transforms.Resize((args.image_size, args.image_size)),
+        transforms.ToTensor()
+    ])
+
+    #dataset = ImageFolder(args.data_path, transform=transform)
+    dataset = torchvision.datasets.UCF101(
+        args.data_path,
+        f"{args.data_path}/ucfTrainTestlist",
+        16,  # frames_per_clip
+        5,   # step_between_clips
+        output_format="TCHW",  # frames-first clips: (F, C, H, W)
+        transform=transform
+    )
+
     sampler = DistributedSampler(
         dataset,
         num_replicas=dist.get_world_size(),
@@ -195,12 +211,16 @@ def main(args):
     for epoch in range(args.epochs):
         sampler.set_epoch(epoch)
         logger.info(f"Beginning epoch {epoch}...")
-        for x, y in loader:
+        for x, _, y in loader:
             x = x.to(device)
             y = y.to(device)
             with torch.no_grad():
                 # Map input images to latent space + normalize latents:
+                b = x.size(0)
+                x = rearrange(x, 'b f c h w -> (b f) c h w').contiguous()  # this borrows from https://github.com/Vchitect/Latte/blob/main/train.py
                 x = vae.encode(x).latent_dist.sample().mul_(0.18215)
+                x = rearrange(x, '(b f) c h w -> b f c h w', b=b).contiguous()  # this borrows from https://github.com/Vchitect/Latte/blob/main/train.py
+
             t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
             model_kwargs = dict(y=y)
             loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
@@ -259,7 +279,7 @@ def main(args):
     parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
     parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=1400)
-    parser.add_argument("--global-batch-size", type=int, default=256)
+    parser.add_argument("--global-batch-size", type=int, default=8)
     parser.add_argument("--global-seed", type=int, default=0)
     parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema")  # Choice doesn't affect training
     parser.add_argument("--num-workers", type=int, default=4)
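
Note on PATCH 5: with frames at dim=1, channels move to dim=2, so a learn_sigma model's output must be split and re-concatenated along dim=2, as in the gaussian_diffusion.py hunk above; concatenating along dim=1 would silently stack frames instead of channels. A quick shape check (standalone sketch, not repo code):

import torch

B, F, C, H, W = 2, 16, 4, 32, 32
model_output = torch.randn(B, F, C * 2, H, W)  # learn_sigma: mean and variance stacked on channels

# Channel axis is dim=2 once frames occupy dim=1.
mean, var_values = torch.split(model_output, C, dim=2)
frozen_out = torch.cat([mean.detach(), var_values], dim=2)

assert mean.shape == (B, F, C, H, W)
assert frozen_out.shape == model_output.shape
# cat along dim=1 would instead produce (B, 2*F, C, H, W) and corrupt the frame axis.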