Add frame dimension #65

Draft · wants to merge 5 commits into main
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
ucf101/
6 changes: 3 additions & 3 deletions diffusion/gaussian_diffusion.py
@@ -750,9 +750,9 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
ModelVarType.LEARNED,
ModelVarType.LEARNED_RANGE,
]:
B, C = x_t.shape[:2]
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
model_output, model_var_values = th.split(model_output, C, dim=1)
B, F, C = x_t.shape[:3]
assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
model_output, model_var_values = th.split(model_output, C, dim=2)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
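For context, a minimal shape sketch (not part of the diff) of the layout this hunk assumes, with hypothetical sizes; note that the unchanged frozen_out = th.cat(..., dim=1) context line above still concatenates along dim=1, which would presumably also need to become dim=2 under the new (B, F, C, H, W) layout.

import torch as th

# Hypothetical sizes: 2 clips of 16 frames, 4 latent channels, 32x32 latents.
B, F, C, H, W = 2, 16, 4, 32, 32
x_t = th.randn(B, F, C, H, W)                 # noisy latents with an explicit frame axis
model_output = th.randn(B, F, C * 2, H, W)    # predicted mean and variance stacked on the channel axis

B, F, C = x_t.shape[:3]
assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
mean, var_values = th.split(model_output, C, dim=2)   # channels now live on dim=2, not dim=1
assert mean.shape == x_t.shape and var_values.shape == x_t.shape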
23 changes: 17 additions & 6 deletions models.py
@@ -14,7 +14,7 @@
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp

from einops import rearrange

def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -149,6 +149,7 @@ class DiT(nn.Module):
def __init__(
self,
input_size=32,
num_frames=16, # Number of video frames per sample (clip length)
patch_size=2,
in_channels=4,
hidden_size=1152,
@@ -165,13 +166,14 @@ def __init__(
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.num_frames = num_frames

self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
num_patches = self.x_embedder.num_patches
# Will use fixed sin-cos embedding:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches * num_frames, hidden_size), requires_grad=False)

self.blocks = nn.ModuleList([
DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
@@ -189,7 +191,7 @@ def _basic_init(module):
self.apply(_basic_init)

# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int((self.x_embedder.num_patches * self.num_frames) ** 0.5))
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
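As a sanity check (not part of the diff): the square-grid sin-cos initialization above only covers every token when num_patches * num_frames is a perfect square, e.g. 256 patches per frame (32x32 latent, patch_size=2) times 16 frames gives 4096 = 64^2. A hedged sketch of that assumption:

# Hypothetical sizes matching the 256px DiT config: 32x32 latent, patch_size=2, 16 frames.
num_patches, num_frames = (32 // 2) ** 2, 16          # 256 patches per frame
grid_side = int((num_patches * num_frames) ** 0.5)    # 64
assert grid_side * grid_side == num_patches * num_frames, \
    "sin-cos init covers all tokens only when patches * frames is a perfect square"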
@@ -233,18 +235,27 @@ def unpatchify(self, x):
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
x: (N, F, C, H, W) tensor of frame sequences (video frames or their latent representations)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
N, F, C, H, W = x.shape

x = rearrange(x, 'n f c h w -> (n f) c h w') # Put frames in batch dimension
x = self.x_embedder(x) # (N * F, T, D), where T = H * W / patch_size ** 2
x = rearrange(x, '(n f) t d -> n (f t) d', f=F) + self.pos_embed # unpack frames from batch dimension and put them in token dimension

t = self.t_embedder(t) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
c = t + y # (N, D)
for block in self.blocks:
x = block(x, c) # (N, T, D)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)

x = rearrange(x, 'n (f t) l -> (n f) t l', f=F) # Put frames in the batch dimension again for unpatchify
x = self.unpatchify(x) # (N * F, out_channels, H, W)
x = rearrange(x, '(n f) c h w -> n f c h w', f=F) # after unpatchifying, move frames out of the batch dimension

return x

def forward_with_cfg(self, x, t, y, cfg_scale):
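A hypothetical usage sketch (not part of the diff) of the frame-aware forward pass, with assumed sizes and an assumed UCF101 class count:

import torch
from models import DiT_models

# 101 classes is an assumption for UCF101; num_frames matches the new default.
model = DiT_models["DiT-XL/2"](input_size=32, num_frames=16, num_classes=101)
x = torch.randn(2, 16, 4, 32, 32)        # (N, F, C, H, W) latent video clips
t = torch.randint(0, 1000, (2,))         # diffusion timesteps
y = torch.randint(0, 101, (2,))          # class labels
out = model(x, t, y)                     # (N, F, 2 * C, H, W) since learn_sigma=True by default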
40 changes: 30 additions & 10 deletions train.py
@@ -30,7 +30,8 @@
from models import DiT_models
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL

from einops import rearrange
import torchvision

#################################################################################
# Training Helper Functions #
@@ -155,13 +156,28 @@ def main(args):
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)

# Setup data:
transform = transforms.Compose([
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])
dataset = ImageFolder(args.data_path, transform=transform)
# transform = transforms.Compose([
# transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
# transforms.RandomHorizontalFlip(),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
# ])

transform = torchvision.transforms.Compose([
transforms.Resize((args.image_size, args.image_size)),
transforms.ToTensor()
])

#dataset = ImageFolder(args.data_path, transform=transform)
dataset = torchvision.datasets.UCF101(
args.data_path,
f"{args.data_path}/ucfTrainTestlist",
16,
5,
output_format="FCHW",
transform=transform
)

sampler = DistributedSampler(
dataset,
num_replicas=dist.get_world_size(),
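For reference, a hedged sketch (not part of the diff) of what this dataset is expected to yield: the positional 16 and 5 map to UCF101's frames_per_clip and step_between_clips, and each item is a (video, audio, label) tuple, which is why the training loop below unpacks three values. Note also that the replaced transform normalized pixels to [-1, 1] while the new one leaves them in [0, 1]; the VAE presumably still expects the former range.

# Assumed shapes with image_size=256 and the DataLoader built further down in train.py:
# video: (batch, 16, 3, 256, 256) frame clips, audio: ignored, label: (batch,) class indices.
video, _, label = next(iter(loader))
assert video.ndim == 5 and video.shape[1] == 16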
@@ -195,12 +211,16 @@ def main(args):
for epoch in range(args.epochs):
sampler.set_epoch(epoch)
logger.info(f"Beginning epoch {epoch}...")
for x, y in loader:
for x, _, y in loader:
x = x.to(device)
y = y.to(device)
with torch.no_grad():
# Map input images to latent space + normalize latents:
b = x.size(0)
x = rearrange(x, 'b f c h w -> (b f) c h w').contiguous() # this borrows from https://github.com/Vchitect/Latte/blob/main/train.py
x = vae.encode(x).latent_dist.sample().mul_(0.18215)
x = rearrange(x, '(b f) c h w -> b f c h w', b=b).contiguous() # this borrows from https://github.com/Vchitect/Latte/blob/main/train.py

t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
model_kwargs = dict(y=y)
loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
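The per-frame VAE encoding above can be read as a small fold/encode/unfold helper; a hypothetical sketch (not part of the diff) with the shapes spelled out:

import torch
from einops import rearrange

def encode_clip(vae, x):
    # x: (B, F, 3, H, W) pixel frames -> (B, F, 4, H // 8, W // 8) scaled SD latents
    b = x.size(0)
    x = rearrange(x, 'b f c h w -> (b f) c h w').contiguous()   # fold frames into the batch
    with torch.no_grad():
        x = vae.encode(x).latent_dist.sample().mul_(0.18215)    # per-image VAE applied per frame
    return rearrange(x, '(b f) c h w -> b f c h w', b=b).contiguous()  # unfold frames again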
@@ -259,7 +279,7 @@ def main(args):
parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
parser.add_argument("--num-classes", type=int, default=1000)
parser.add_argument("--epochs", type=int, default=1400)
parser.add_argument("--global-batch-size", type=int, default=256)
parser.add_argument("--global-batch-size", type=int, default=8)
parser.add_argument("--global-seed", type=int, default=0)
parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema") # Choice doesn't affect training
parser.add_argument("--num-workers", type=int, default=4)
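Rough arithmetic (assumed sizes) behind the much smaller default global batch: each sample now carries 16x the tokens of the single-image setting, on top of the per-frame VAE encodes.

tokens_per_frame = (256 // 8 // 2) ** 2   # 256px image -> 32x32 latent -> 16x16 = 256 patch tokens
tokens_per_clip = 16 * tokens_per_frame   # 4096 tokens per 16-frame clip, vs 256 per image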