feat: add cogvideox i2v finetune
YingqingHe committed Jan 16, 2025
1 parent 14a4128 commit 91dde0e
Showing 11 changed files with 1,008 additions and 156 deletions.
125 changes: 125 additions & 0 deletions configs/004_cogvideox/cogvideo5b-i2v.yaml
@@ -0,0 +1,125 @@
model:
  base_learning_rate: 6e-6
  target: src.cogvideo_hf.cogvideo_i2v.CogVideoXI2V
  params:
    noised_image_input: True
    noised_image_dropout: 0.05
    # VAE of CogVideoX
    first_stage_config:
      target: diffusers.AutoencoderKLCogVideoX
      params:
        pretrained_model_name_or_path: checkpoints/cogvideo/CogVideoX-5b-I2V
        subfolder: "vae"

    # Text encoder (T5) of CogVideoX
    cond_stage_config:
      target: src.lvdm.modules.encoders.condition.FrozenT5Embedder
      params:
        version: "DeepFloyd/t5-v1_1-xxl"
        device: "cuda"
        max_length: 226
        freeze: True

    # Denoiser model
    denoiser_config:
      target: diffusers.CogVideoXTransformer3DModel
      params:
        pretrained_model_name_or_path: checkpoints/cogvideo/CogVideoX-5b-I2V
        subfolder: "transformer"
        load_dtype: fp16 # bf16 for 5B / fp16 for 2B
        # revision: null
        # variant: null

    # LoRA module
    adapter_config:
      target: peft.LoraConfig
      params:
        r: 4
        lora_alpha: 1.0
        init_lora_weights: True
        target_modules: ["to_k", "to_q", "to_v", "to_out.0"]

    # Diffusion sampling scheduler
    scheduler_config:
      target: diffusers.CogVideoXDPMScheduler
      params:
        pretrained_model_name_or_path: checkpoints/cogvideo/CogVideoX-5b-I2V
        subfolder: scheduler

# data configs
# data:
#   target: src.data.lightning_data.DataModuleFromConfig
#   params:
#     batch_size: 2
#     num_workers: 16
#     wrap: false
#     train:
#       target: src.data.cogvideo_dataset.VideoDataset
#       params:
#         instance_data_root: inputs/data-cartoon-talk # "inputs/t2v/cogvideo/elon_musk_video"
#         dataset_name: null
#         dataset_config_name: null
#         caption_column: "labels.txt"
#         video_column: "videos.txt"
#         height: 480
#         width: 720
#         fps: 28
#         max_num_frames: 2
#         skip_frames_start: 0
#         skip_frames_end: 0
#         cache_dir: ~/.cache
#         id_token: null
#         image_to_video: true
data:
  target: src.data.lightning_data.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 16
    wrap: false
    train:
      target: src.data.datasets.DatasetFromCSV
      params:
        csv_path: temp/apply_lipstick.csv
        height: 480
        width: 720
        video_length: 49
        frame_interval: 1
        train: True
        image_to_video: true
    validation:
      target: src.data.datasets.DatasetFromCSV
      params:
        csv_path: temp/apply_lipstick.csv
        height: 480
        width: 720
        video_length: 49
        frame_interval: 1
        train: False
        image_to_video: true

# training configs
lightning:
  trainer:
    benchmark: True
    num_nodes: 1
    accumulate_grad_batches: 2
    max_epochs: 2000
    precision: 32
  callbacks:
    image_logger:
      target: src.utils.callbacks.ImageLogger
      params:
        batch_frequency: 100000
        max_images: 2
        to_local: True # save videos to local files
        log_images_kwargs:
          unconditional_guidance_scale: 6
    metrics_over_trainsteps_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        filename: "{epoch:06}-{step:09}"
        save_weights_only: False
        # every_n_epochs: 300
        every_n_train_steps: 10
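
The config above follows the repository's standard experiment-config layout, so it can be sanity-checked by instantiating it directly. A minimal sketch (assuming the OmegaConf + instantiate_from_config pattern that scripts/inference_cogvideo.py in this commit also uses, and that instantiate_from_config reads only the target/params keys of each section):

import os, sys
sys.path.insert(0, os.getcwd())

from omegaconf import OmegaConf
from src.utils.common_utils import instantiate_from_config

# Build the CogVideoXI2V module described by the `model` section
# (VAE, frozen T5 encoder, CogVideoX transformer, LoRA adapter, scheduler).
config = OmegaConf.load("configs/004_cogvideox/cogvideo5b-i2v.yaml")
model = instantiate_from_config(config.model)

# The `data` section builds the Lightning DataModule the same way.
data = instantiate_from_config(config.data)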


158 changes: 63 additions & 95 deletions scripts/inference_cogvideo.py
@@ -1,31 +1,25 @@
import os
import sys
import time
import argparse
import json
import numpy as np
from functools import partial
from tqdm import trange, tqdm
import argparse
from einops import repeat
from PIL import Image
from typing import List
from tqdm import trange
from omegaconf import OmegaConf
from einops import rearrange, repeat

import torch
from pytorch_lightning import seed_everything
from typing import List, Union
from omegaconf import ListConfig
import torchvision.transforms as transforms

sys.path.insert(0, os.getcwd())
sys.path.insert(1, f'{os.getcwd()}/src')
from src.base.ddim import DDIMSampler
from src.utils.common_utils import instantiate_from_config
from src.base.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond
from src.utils.inference_utils import (
    get_target_filelist,
    load_model_checkpoint,
    load_prompts,
    load_inputs_i2v,
    load_image_batch,
    sample_batch_t2v,
    sample_batch_i2v,
    save_videos,
    save_videos_vbench,
)
@@ -45,15 +39,15 @@ def get_parser():
    #
    parser.add_argument("--height", type=int, default=512, help="video height, in pixel space")
    parser.add_argument("--width", type=int, default=512, help="video width, in pixel space")
    parser.add_argument("--frames", type=int, default=None, help="video frame number, in pixel space")
    parser.add_argument("--frames", type=int, default=49, help="video frame number, in pixel space")
    parser.add_argument("--fps", type=int, default=24, help="video motion speed. 512 or 1024 model: large value -> slow motion; 256 model: large value -> large motion;")
    parser.add_argument("--n_samples_prompt", type=int, default=1, help="num of samples per prompt",)
    #
    parser.add_argument("--bs", type=int, default=1, help="batch size for inference")
    parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
    parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
    parser.add_argument("--uncond_prompt", type=str, default="", help="unconditional prompts, or negative prompts")
    parser.add_argument("--unconditional_guidance_scale", type=float, default=12.0, help="prompt classifier-free guidance")
    parser.add_argument("--unconditional_guidance_scale", type=float, default=6.0, help="prompt classifier-free guidance")
    parser.add_argument("--unconditional_guidance_scale_temporal", type=float, default=None, help="temporal consistency guidance")
    # dc args
    parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="i2v: use multi-condition cfg or not")
@@ -83,10 +77,8 @@ def load_model(args, cuda_idx=0):
    assert os.path.exists(args.ckpt_path), f"Error: checkpoint [{args.ckpt_path}] Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    # load lora weights
    # notice lora args
    if hasattr(model, "lora_args") and len(model.lora_args) != 0:
        model.inject_lora()

    model.eval()
    return model

@@ -115,45 +107,48 @@ def load_inputs(args):
    )
    return prompt_list, image_list, filename_list

def get_batch(keys, value_dict, N: Union[List, ListConfig], T=None, device="cuda"):
    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            # import pdb;pdb.set_trace()
            batch["txt"] = np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
            batch_uc["txt"] = np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc

def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))

def save_video_as_grid_and_mp4(video_batch: torch.Tensor,
                               save_path: str,
                               filenames=None,
                               fps: int = 5
                               ):
    for i, vid in enumerate(video_batch):
        gif_frames = []
        for frame in vid:
            frame = rearrange(frame, "c h w -> h w c")
            frame = (255.0 * frame).cpu().numpy().astype(np.uint8)
            gif_frames.append(frame)
        now_save_path = os.path.join(save_path, filenames[i] + f"-{i}.mp4")
        print(now_save_path)
        with imageio.get_writer(now_save_path, fps=fps) as writer:
            for frame in gif_frames:
                writer.append_data(frame)
def load_inputs_i2v(input_dir, video_size=(480, 720), video_frames=49):
    """
    Load prompt list and conditional images for i2v from input_dir.
    """
    # load prompt files
    prompt_files = get_target_filelist(input_dir, ext='txt')
    if len(prompt_files) > 1:
        # only use the first one (sorted by name) if multiple exist
        print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_files[0])[1]} is used.")
        prompt_file = prompt_files[0]
    elif len(prompt_files) == 1:
        prompt_file = prompt_files[0]
    elif len(prompt_files) == 0:
        print(prompt_files)
        raise ValueError(f"Error: found NO prompt file in {input_dir}")
    prompt_list = load_prompts(prompt_file)
    n_samples = len(prompt_list)

    ## load images
    img_list = get_target_filelist(input_dir, ext='[mpj][pne][4gj]')
    # img_list = get_target_filelist(input_dir, ext='[mpjw][pne][4gjb][p]')
    print(f"Found {n_samples} prompts and {len(img_list)} images in {input_dir}")
    # image transforms
    transform = transforms.Compose([
        transforms.Resize(min(video_size)),
        transforms.CenterCrop(video_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

    image_list = []
    filename_list = []
    for idx in range(n_samples):
        image = Image.open(img_list[idx]).convert('RGB')
        # image_tensor = transform(image).unsqueeze(0)  # [c,h,w]
        # frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)
        image_list.append(image)

        _, filename = os.path.split(img_list[idx])
        filename_list.append(filename.split(".")[0])

    return filename_list, image_list, prompt_list
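
# Illustrative usage (not part of this diff): for a directory holding one .txt
# prompt file plus one image per prompt line (paired by sorted filename),
#     filenames, images, prompts = load_inputs_i2v("inputs/i2v", video_size=(480, 720))
# returns PIL images and prompt strings, which run_inference_cogvideo below
# slices into batches for model.sample; "inputs/i2v" is a hypothetical path.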

def run_inference_cogvideo(args, gpu_num=1, rank=0, **kwargs):
    """
@@ -187,62 +182,35 @@ def run_inference_cogvideo(args, gpu_num=1, rank=0, **kwargs):
    n_iters = len(prompt_list_rank) // args.bs + (1 if len(prompt_list_rank) % args.bs else 0)
    with torch.no_grad():
        for idx in trange(0, n_iters, desc="Sample Iters"):
            # print(f'[rank:{rank}] batch {idx}: prompt bs {args.bs}) x nsamples_per_prompt {args.n_samples_prompt} ...')

            # split batch
            prompts = prompt_list_rank[idx*args.bs:(idx+1)*args.bs]
            filenames = filename_list_rank[idx*args.bs:(idx+1)*args.bs]

            if args.mode == 'i2v':
                images = image_list_rank[idx*args.bs:(idx+1)*args.bs]
                if isinstance(images, list):
                    images = torch.stack(images, dim=0).to("cuda")
                else:
                    images = images.unsqueeze(0).to("cuda")
            # idx_s = idx*args.bs
            # idx_e = min(idx_s+args.bs, len(prompt_list_rank))
            # batch_size = idx_e - idx_s
            # filenames = filename_list_rank[idx_s:idx_e]

            # prompts = prompt_list_rank[idx_s:idx_e]
            # if isinstance(prompts, str):
            #     prompts = [prompts]
            # prompts = batch_size * [""]

            # if args.mode == 't2v':
            #     cond = {"c_crossattn": [text_emb], "fps": fps}

            # TODO
            # elif args.mode == 'i2v':
            #     cond_images = load_image_batch(image_list_rank[idx_s:idx_e], (args.height, args.width))
            #     cond_images = cond_images.to(model.device)
            #     img_emb = model.get_image_embeds(cond_images)
            #     imtext_cond = torch.cat([text_emb, img_emb], dim=1)
            #     cond = {"c_crossattn": [imtext_cond], "fps": fps}
            # else:
            #     raise NotImplementedError

            ## inference
            bs = args.bs if args.bs == len(prompts) else len(prompts)
            # noise_shape = [bs, channels, frames, h, w]
            if args.mode == 't2v':
                # batch_samples = sample_batch_t2v(model, ddim_sampler, prompts, noise_shape, args.fps,
                #     args.n_samples_prompt, args.ddim_steps, args.ddim_eta,
                #     args.unconditional_guidance_scale, args.unconditional_guidance_scale_temporal,
                #     args.uncond_prompt,
                # )
                batch_samples = model.sample(
                    prompts,
                    None,
                    height = args.height,
                    width = args.width,
                    num_frames = 12,
                    num_frames = 49,
                    num_videos_per_prompt = args.n_samples_prompt,
                    guidance_scale = args.unconditional_guidance_scale,
                    # args.unconditional_guidance_scale_temporal,
                    # args.uncond_prompt
                )
            elif args.mode == 'i2v':
                raise NotImplementedError
                batch_samples = model.sample(
                    images,
                    prompts,
                    None,
                    height = args.height,
                    width = args.width,
                    num_frames = 49,
                    num_videos_per_prompt = args.n_samples_prompt,
                    guidance_scale = args.unconditional_guidance_scale,
                )
            else:
                raise ValueError

16 changes: 16 additions & 0 deletions shscripts/inference_cogvideo_i2v_lora.sh
@@ -0,0 +1,16 @@
config=configs/004_cogvideox/cogvideo5b-i2v.yaml
ckpt=results/train/cogvideox_i2v_5b/$YOUR_CKPT_PATH.ckpt
prompt_dir=$YOUR_PROMPT_DIR

current_time=$(date +%Y%m%d%H%M%S)
savedir="results/inference/i2v/cogvideox-i2v-lora-$current_time"

python3 scripts/inference_cogvideo.py \
    --config $config \
    --ckpt_path $ckpt \
    --prompt_dir $prompt_dir \
    --savedir $savedir \
    --bs 1 --height 480 --width 720 \
    --fps 16 \
    --seed 6666 \
    --mode i2v
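
$YOUR_CKPT_PATH and $YOUR_PROMPT_DIR are placeholders: point the former at the LoRA checkpoint produced by the training script below, and the latter at a directory containing the prompt .txt file plus one conditioning image per prompt, as load_inputs_i2v in scripts/inference_cogvideo.py expects.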
19 changes: 19 additions & 0 deletions shscripts/train_cogvideox_i2v_lora.sh
@@ -0,0 +1,19 @@
export TOKENIZERS_PARALLELISM=false

# dependencies
CONFIG="configs/004_cogvideox/cogvideo5b-i2v.yaml" # experiment config

# exp saving directory: ${RESROOT}/${CURRENT_TIME}_${EXPNAME}
RESROOT="results/train" # experiment saving directory
EXPNAME="cogvideox_i2v_5b" # experiment name
CURRENT_TIME=$(date +%Y%m%d%H%M%S) # current time

# run
python scripts/train.py \
    -t \
    --base $CONFIG \
    --logdir $RESROOT \
    --name "$CURRENT_TIME"_$EXPNAME \
    --devices '0,' \
    lightning.trainer.num_nodes=1 \
    --auto_resume
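
As noted in the script's comments, outputs are saved under ${RESROOT}/${CURRENT_TIME}_${EXPNAME}; the checkpoint written there is what the inference script's ckpt variable should point at.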