Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat]: support navit #321

Open
wants to merge 1 commit into
base: navit
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 0 additions & 45 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,45 +0,0 @@
ucf101_stride4x4x4
__pycache__
*.mp4
.ipynb_checkpoints
*.pth
UCF-101/
results/
vae
build/
opensora.egg-info/
wandb/
.idea
*.ipynb
*.jpg
*.mp3
*.safetensors
*.mp4
*.png
*.gif
*.pth
*.pt
cache_dir/
wandb/
test*
sample_video*
sample_image*
512*
720*
1024*
debug*
private*
caption*
*deepspeed*
revised*
129f*
all*
read*
YSH*
*pick*
*ysh*
hw*
257f*
513f*
taming*
221hw*
484 changes: 74 additions & 410 deletions README.md

Large diffs are not rendered by default.

410 changes: 410 additions & 0 deletions README_original.md

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion opensora/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torchvision.transforms import Lambda

from .t2v_datasets import T2V_dataset
from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo
from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo, RandomResize


ae_norm = {
Expand Down Expand Up @@ -62,4 +62,12 @@ def getdataset(args):
])
tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir)
return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer)
elif args.dataset == 't2v_navit':
transform = transforms.Compose([
ToTensorVideo(),
RandomResize(),
norm_fun
])
tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir)
return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer)
raise NotImplementedError(args.dataset)
31 changes: 31 additions & 0 deletions opensora/dataset/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,37 @@ def __call__(self, clip):

def __repr__(self) -> str:
return f"{self.__class__.__name__}(p={self.p})"

class RandomResize:
    """
    Resize the video to a random spatial size.

    Height and width are sampled independently and uniformly from
    [min_size, original size]. A side that is already `min_size` or smaller
    keeps its original length instead of being resampled.

    Args:
        interpolation_mode (str): Interpolation mode forwarded to `resize`.
        min_size (int): Smallest side length that may be sampled.
    """

    def __init__(
        self,
        interpolation_mode="bilinear",
        min_size=64,
    ):
        self.interpolation_mode = interpolation_mode
        self.min_size = min_size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be resized. Size is (T, C, H, W)
        Returns:
            torch.tensor: Resized video clip.
                size is (T, C, new_h, new_w)
        """
        img_h, img_w = clip.shape[-2:]
        # Fall back to the original length for small sides so that both
        # targets are always defined (previously an UnboundLocalError).
        new_h = random.randint(self.min_size, img_h) if img_h > self.min_size else img_h
        new_w = random.randint(self.min_size, img_w) if img_w > self.min_size else img_w

        return resize(
            clip,
            target_size=(new_h, new_w),
            interpolation_mode=self.interpolation_mode,
        )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"interpolation_mode={self.interpolation_mode}, "
            f"min_size={self.min_size})"
        )


# ------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions opensora/models/ae/videobase/modules/quant.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import torch
import torch.nn as nn
import torch.distributed as dist
Expand Down
2 changes: 2 additions & 0 deletions opensora/models/diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@

from .latte.modeling_latte import Latte_models
from .latte.modeling_latte_navit import Latte_navit_models

# Combined lookup table of the Latte and Latte-NaViT model tables.
# The NaViT table is merged last, so it wins on any duplicate key.
Diffusion_models = dict(Latte_models)
Diffusion_models.update(Latte_navit_models)


41 changes: 40 additions & 1 deletion opensora/models/diffusion/diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py

from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusion_T
from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusion_T, NaViTSpacedDiffusion_T


def create_diffusion(
Expand Down Expand Up @@ -85,3 +85,42 @@ def create_diffusion_T(
loss_type=loss_type
# rescale_timesteps=rescale_timesteps,
)

def create_diffusion_navit(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=False,  # NaViT only supports learn_sigma=False
    rescale_learned_sigmas=False,
    diffusion_steps=1000
):
    """
    Build a `NaViTSpacedDiffusion_T` from high-level options.

    :param timestep_respacing: respacing spec handed to `space_timesteps`;
                               `None` or `""` means "use all
                               `diffusion_steps` steps".
    :param noise_schedule: name passed to `get_named_beta_schedule`.
    :param use_kl: select the rescaled-KL loss (takes precedence over
                   `rescale_learned_sigmas`).
    :param sigma_small: with a fixed variance, pick FIXED_SMALL instead of
                        FIXED_LARGE.
    :param predict_xstart: use START_X as the mean type instead of EPSILON.
    :param learn_sigma: select the LEARNED_RANGE variance type.
    :param rescale_learned_sigmas: select the rescaled-MSE loss.
    :param diffusion_steps: number of steps in the base schedule.
    :return: the configured `NaViTSpacedDiffusion_T` instance.
    """
    from . import gaussian_diffusion_t2v as gd

    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)

    # Loss selection: KL first, then rescaled MSE, plain MSE otherwise.
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    else:
        loss_type = gd.LossType.RESCALED_MSE if rescale_learned_sigmas else gd.LossType.MSE

    # An empty respacing spec means one section covering every step.
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]
    use_timesteps = space_timesteps(diffusion_steps, timestep_respacing)

    mean_type = gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON

    if learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return NaViTSpacedDiffusion_T(
        use_timesteps=use_timesteps,
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
    )
65 changes: 63 additions & 2 deletions opensora/models/diffusion/diffusion/respace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@
import torch as th

from .gaussian_diffusion import GaussianDiffusion
from .gaussian_diffusion_t2v import GaussianDiffusion_T

from opensora.models.diffusion.latte.modeling_latte_navit import pack_target_as
from .gaussian_diffusion_t2v import (
GaussianDiffusion_T,
LossType,
ModelMeanType,
ModelVarType,
discretized_gaussian_log_likelihood,
mean_flat,
normal_kl,
)


def space_timesteps(num_timesteps, section_counts):
Expand Down Expand Up @@ -195,4 +205,55 @@ def __call__(self, x, ts, **kwargs):
new_ts = map_tensor[ts]
# if self.rescale_timesteps:
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return self.model(x, new_ts, **kwargs)
return self.model(x, new_ts, **kwargs)


class NaViTSpacedDiffusion_T(SpacedDiffusion_T):
    """
    Spaced diffusion process whose training loss takes a list of samples
    (one tensor per video, resolutions may differ) instead of one batch tensor.
    """

    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
        """
        Compute training losses for a single timestep.
        :param model: the model to evaluate loss on.
        :param x_start: List[C x ...] for videos of different resolution.
        :param t: a batch of timestep indices, one per entry of `x_start`.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, a list (one tensor per entry of `x_start`)
            of the specific Gaussian noise to try to remove; sampled with
            `randn_like` when None.
        :return: a dict with the key "loss" containing a scalar tensor (the
                 MSE averaged over all elements).
        :raises NotImplementedError: for KL losses, learned variance types, or
                 an unknown loss type.
        """
        model = self._wrap_model(model)
        if model_kwargs is None:
            model_kwargs = {}
        # For video of different resolution case like NaViT training.
        assert isinstance(x_start, list)
        # Honour caller-supplied noise; it used to be overwritten unconditionally.
        if noise is None:
            noise = [th.randn_like(x) for x in x_start]
        # Diffuse every video separately with its own timestep and noise.
        x_t = [
            self.q_sample(x, step, eps)
            for x, step, eps in zip(x_start, t, noise)
        ]
        terms = {}

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            raise NotImplementedError("NaViT only supports `loss_type` == MSE")
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            # [b, F, T, p*p*c]
            model_output, video_ids, token_kept_ids = model(x_t, t, **model_kwargs)
            _, _, _, out_dim = model_output.shape
            # NOTE(review): the output is always split in half and the second
            # half is discarded below; this presumes the model emits twice the
            # target width even with a fixed variance - TODO confirm.
            model_output, model_var_values = th.split(model_output, out_dim // 2, dim=-1)

            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                # TODO: support vb loss
                raise NotImplementedError("NaViT only supports fixed `model_var_type`")

            assert self.model_mean_type == ModelMeanType.EPSILON, "NaViT only supports `ModelMeanType.EPSILON`"
            # [b, F, T, p*p*c]
            # `model` is the wrapper from `_wrap_model`; `.model` is presumably
            # the underlying network exposing `patch_size` - TODO confirm.
            target = pack_target_as(noise, video_ids, model.model.patch_size, token_kept_ids).to(model_output.dtype)

            assert model_output.shape == target.shape, f"{model_output.shape}, {target.shape}"

            # This module imports torch as `th`; the bare name `torch` is
            # undefined here and raised NameError.
            terms["loss"] = th.nn.functional.mse_loss(target, model_output)
        else:
            raise NotImplementedError(self.loss_type)

        return terms
Loading