diff --git a/README.md b/README.md
index ab7bf35b..9d143b0f 100644
--- a/README.md
+++ b/README.md
@@ -361,6 +361,52 @@ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 scripts/inference.py config
 :warning: **LIMITATION**: Sequence parallelism is not supported for Gradio deployment. For now, sequence parallelism is only supported when the dimension is divisible by the number of GPUs, so it may fail in some cases. We tested 4 GPUs for 720p and 2 GPUs for 480p.
 
+### Separate Inference: 720p Video with 24 GB VRAM
+
+The Open-Sora inference pipeline consists of three main components: the text encoder, the VAE, and STDiT. By running inference for each component separately, the whole process can be completed with limited VRAM.
+
+#### Step-by-Step Inference
+
+1. First, run the text encoder and save the text embeddings.
+2. (Optional) If you are using a reference image, run the VAE encoder and save the reference latents.
+3. Then, run STDiT with the saved text embeddings and save the output latents.
+4. Finally, run the VAE decoder on the saved latents to produce the video.
+
+#### All-in-One Script
+
+The basic text-to-video command is as follows:
+
+```bash
+# text to video
+./scripts/separate_inference.sh "0,1" 4s 720p "9:16" 7 "a beautiful waterfall"
+
+# Parameter explanations:
+# "0,1"                   : GPU indices to use (e.g., "0,1" uses GPUs 0 and 1)
+# 4s                      : Duration of the video (4 seconds in this case)
+# 720p                    : Resolution of the video
+# "9:16"                  : Aspect ratio of the video (9:16 vertical format)
+# 7                       : Aesthetic (aes) score appended to the prompt
+# "a beautiful waterfall" : Text prompt describing the video content
+```
+
+The basic image-to-video workflow is as follows:
+
+1. Generate an image with text-to-video (a single frame):
+
+```bash
+# text to image
+./scripts/separate_inference.sh "0,1" 1 1080p "9:16" 7 "a beautiful waterfall"
+```
+
+2. Generate a video conditioned on the reference image:
+
+```bash
+# image to video
+./scripts/separate_inference.sh "0,1" 4s 720p "9:16" 7 "a beautiful waterfall. {\"reference_path\": \"path2reference.png\",\"mask_strategy\": \"0\"}"
+```
+
+:warning: **LIMITATION**: Because the text encoder requires over 18 GB of VRAM, this script currently only supports GPUs with more than 18 GB of VRAM. To run it on GPUs with less VRAM, you need to lower the inference precision of the text encoder yourself.
+
 ### GPT-4o Prompt Refinement
 
 We find that GPT-4o can refine the prompt and improve the quality of the generated video. With this feature, you can also use other languages (e.g., Chinese) as the prompt.
To enable this feature, you need prepare your openai api key in the environment: diff --git a/opensora/schedulers/rf/__init__.py b/opensora/schedulers/rf/__init__.py index e5931ecf..dc7bd44d 100644 --- a/opensora/schedulers/rf/__init__.py +++ b/opensora/schedulers/rf/__init__.py @@ -42,6 +42,8 @@ def sample( mask=None, guidance_scale=None, progress=True, + caption_embs=None, + caption_emb_masks=None, ): # if no specific guidance scale is provided, use the default scale when initializing the scheduler if guidance_scale is None: @@ -49,9 +51,16 @@ def sample( n = len(prompts) # text encoding - model_args = text_encoder.encode(prompts) - y_null = text_encoder.null(n) - model_args["y"] = torch.cat([model_args["y"], y_null], 0) + if text_encoder is not None: + model_args = text_encoder.encode(prompts) + y_null = text_encoder.null(n) + model_args["y"] = torch.cat([model_args["y"], y_null], 0) + else: + # use pre-inference text embeddings + model_args = dict(mask=caption_emb_masks) + y_null = model.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None] + model_args["y"] = torch.cat([caption_embs, y_null], 0) + if additional_args is not None: model_args.update(additional_args) diff --git a/scripts/separate_inference.sh b/scripts/separate_inference.sh new file mode 100755 index 00000000..279b2137 --- /dev/null +++ b/scripts/separate_inference.sh @@ -0,0 +1,24 @@ +#!/bin/bash +set_default_params() { + gpus=${1:-"0,1"} + num_frames=${2:-"4s"} + resolution=${3:-"720p"} + aspect_ratio=${4:-"9:16"} + aes=${5:-"7"} + prompt=${6:-"Create a video featuring Will Smith enjoying a plate of spaghetti."} +} + +set_default_params "$@" + +export CUDA_VISIBLE_DEVICES=$gpus + +gpus="${gpus// /}" +IFS=',' read -ra gpu_array <<< "$gpus" +gpu_count=${#gpu_array[@]} + +torchrun --nproc_per_node $gpu_count --master_port=23456 scripts/separate_inference/inference_text_encoder.py configs/opensora-v1-2/inference/sample.py --aes $aes --num-frames "$num_frames" --resolution "$resolution" --aspect-ratio "$aspect_ratio" --prompt "$prompt" +if echo "$prompt" | grep -q "reference_path"; then + torchrun --nproc_per_node $gpu_count --master_port=23456 scripts/separate_inference/inference_vae_encoder.py configs/opensora-v1-2/inference/sample.py --aes $aes --num-frames "$num_frames" --resolution "$resolution" --aspect-ratio "$aspect_ratio" --prompt "$prompt" +fi +torchrun --nproc_per_node $gpu_count --master_port=23456 scripts/separate_inference/inference_stdit.py configs/opensora-v1-2/inference/sample.py --aes $aes --num-frames "$num_frames" --resolution "$resolution" --aspect-ratio "$aspect_ratio" --prompt "$prompt" +torchrun --nproc_per_node $gpu_count --master_port=23456 scripts/separate_inference/inference_vae_decoder.py configs/opensora-v1-2/inference/sample.py --aes $aes --num-frames "$num_frames" --resolution "$resolution" --aspect-ratio "$aspect_ratio" --prompt "$prompt" diff --git a/scripts/separate_inference/inference_stdit.py b/scripts/separate_inference/inference_stdit.py new file mode 100644 index 00000000..4f8a20a6 --- /dev/null +++ b/scripts/separate_inference/inference_stdit.py @@ -0,0 +1,313 @@ +import os +from datetime import date, datetime, timedelta +from pprint import pformat + +import colossalai +import torch +import torch.distributed as dist +from colossalai.cluster import DistCoordinator +from mmengine.runner import set_random_seed +from tqdm import tqdm + +from opensora.acceleration.parallel_states import set_sequence_parallel_group +from opensora.datasets.aspect import get_image_size, get_num_frames +from 
opensora.models.text_encoder.t5 import text_preprocessing +from opensora.registry import MODELS, SCHEDULERS, build_module +from opensora.utils.config_utils import parse_configs +from opensora.utils.inference_utils import ( + append_score_to_prompts, + apply_mask_strategy, + extract_json_from_prompts, + extract_prompts_loop, + get_save_path_name, + load_prompts, + merge_prompt, + prepare_multi_resolution_info, + refine_prompts_by_openai, + split_prompt, +) +from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype + + +def main(): + torch.set_grad_enabled(False) + # ====================================================== + # configs & runtime variables + # ====================================================== + # == parse configs == + cfg = parse_configs(training=False) + + # == device and dtype == + device = "cuda" if torch.cuda.is_available() else "cpu" + cfg_dtype = cfg.get("dtype", "fp32") + assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" + dtype = to_torch_dtype(cfg.get("dtype", "bf16")) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # == init distributed env == + if is_distributed(): + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + enable_sequence_parallelism = coordinator.world_size > 1 + if enable_sequence_parallelism: + set_sequence_parallel_group(dist.group.WORLD) + else: + coordinator = None + enable_sequence_parallelism = False + set_random_seed(seed=cfg.get("seed", 1024)) + + # == init logger == + logger = create_logger() + logger.info("Inference configuration:\n %s", pformat(cfg.to_dict())) + verbose = cfg.get("verbose", 1) + progress_wrap = tqdm if verbose == 1 else (lambda x: x) + + # ====================================================== + # build model & load weights + # ====================================================== + logger.info("Building models...") + + # == prepare video size == + image_size = cfg.get("image_size", None) + if image_size is None: + resolution = cfg.get("resolution", None) + aspect_ratio = cfg.get("aspect_ratio", None) + assert ( + resolution is not None and aspect_ratio is not None + ), "resolution and aspect_ratio must be provided if image_size is not provided" + image_size = get_image_size(resolution, aspect_ratio) + num_frames = get_num_frames(cfg.num_frames) + + # == get vae temporal size == + micro_frame_size = cfg.get("micro_frame_size", 17) + time_padding = 0 if micro_frame_size % 4 == 0 else 4 - micro_frame_size % 4 + lsize = (micro_frame_size + time_padding) // 4 + frame_size = lsize * (num_frames // micro_frame_size) + remain_temporal_size = num_frames % micro_frame_size + if remain_temporal_size > 0: + time_padding = 0 if remain_temporal_size % 4 == 0 else 4 - remain_temporal_size % 4 + remain_size = (remain_temporal_size + time_padding) // 4 + frame_size += remain_size + + latent_size = (frame_size, image_size[0] // 8, image_size[1] // 8) + model = ( + build_module( + cfg.model, + MODELS, + input_size=latent_size, + in_channels=4, + caption_channels=4096, + model_max_length=300, + enable_sequence_parallelism=enable_sequence_parallelism, + ) + .to(device, dtype) + .eval() + ) + # == build scheduler == + scheduler = build_module(cfg.scheduler, SCHEDULERS) + + # ====================================================== + # inference + # ====================================================== + # == load prompts == + prompts = cfg.get("prompt", None) + start_idx = cfg.get("start_index", 0) + 
if prompts is None: + if cfg.get("prompt_path", None) is not None: + prompts = load_prompts(cfg.prompt_path, start_idx, cfg.get("end_index", None)) + else: + prompts = [cfg.get("prompt_generator", "")] * 1_000_000 # endless loop + + # == prepare reference == + reference_path = cfg.get("reference_path", [""] * len(prompts)) + mask_strategy = cfg.get("mask_strategy", [""] * len(prompts)) + assert len(reference_path) == len(prompts), "Length of reference must be the same as prompts" + assert len(mask_strategy) == len(prompts), "Length of mask_strategy must be the same as prompts" + + # == prepare arguments == + fps = cfg.fps + cfg.get("save_fps", fps // cfg.get("frame_interval", 1)) + multi_resolution = cfg.get("multi_resolution", None) + batch_size = cfg.get("batch_size", 1) + num_sample = cfg.get("num_sample", 1) + loop = cfg.get("loop", 1) + cfg.get("condition_frame_length", 5) + cfg.get("condition_frame_edit", 0.0) + align = cfg.get("align", None) + + save_dir = cfg.save_dir + os.makedirs(save_dir, exist_ok=True) + sample_name = cfg.get("sample_name", None) + prompt_as_path = cfg.get("prompt_as_path", False) + + # == prepare saved dir == + cur_date = datetime.now().strftime("%Y-%m-%d") + if not os.path.exists(os.path.join(save_dir, cur_date)): + yesterday = date.today() - timedelta(days=1) + cur_date = yesterday.strftime("%Y-%m-%d") + latest_idx = sorted([int(x) for x in os.listdir(os.path.join(save_dir, cur_date))])[-1] + saved_idx = str(latest_idx).zfill(5) + + # == Iter over all samples == + for i in progress_wrap(range(0, len(prompts), batch_size)): + # == prepare batch prompts == + batch_prompts = prompts[i : i + batch_size] + ms = mask_strategy[i : i + batch_size] + refs = reference_path[i : i + batch_size] + + # == get json from prompts == + batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms) + original_batch_prompts = batch_prompts + + # == get reference for condition == + # refs = collect_references_batch(refs, vae, image_size) + refs_x = [] # refs_x: [batch, ref_num, C, T, H, W] + for reference_path in refs: + if reference_path == "": + refs_x.append([]) + continue + ref_path = reference_path.split(";") + ref = [] + for r_path in ref_path: + ref.append("placehold") + refs_x.append(ref) + refs = refs_x + + # == multi-resolution info == + model_args = prepare_multi_resolution_info( + multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype + ) + + # == Iter over number of sampling for one prompt == + for k in range(num_sample): + # == prepare save paths == + save_paths = [ + get_save_path_name( + save_dir, + sample_name=sample_name, + sample_idx=start_idx + idx, + prompt=original_batch_prompts[idx], + prompt_as_path=prompt_as_path, + num_sample=num_sample, + k=k, + ) + for idx in range(len(batch_prompts)) + ] + + # NOTE: Skip if the sample already exists + # This is useful for resuming sampling VBench + if prompt_as_path and all_exists(save_paths): + continue + + # == process prompts step by step == + # 0. split prompt + # each element in the list is [prompt_segment_list, loop_idx_list] + batched_prompt_segment_list = [] + batched_loop_idx_list = [] + for prompt in batch_prompts: + prompt_segment_list, loop_idx_list = split_prompt(prompt) + batched_prompt_segment_list.append(prompt_segment_list) + batched_loop_idx_list.append(loop_idx_list) + + # 1. refine prompt by openai + if cfg.get("llm_refine", False): + # only call openai API when + # 1. seq parallel is not enabled + # 2. 
seq parallel is enabled and the process is rank 0 + if not enable_sequence_parallelism or (enable_sequence_parallelism and is_main_process()): + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list) + + # sync the prompt if using seq parallel + if enable_sequence_parallelism: + coordinator.block_all() + prompt_segment_length = [ + len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list + ] + + # flatten the prompt segment list + batched_prompt_segment_list = [ + prompt_segment + for prompt_segment_list in batched_prompt_segment_list + for prompt_segment in prompt_segment_list + ] + + # create a list of size equal to world size + broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size + dist.broadcast_object_list(broadcast_obj_list, 0) + + # recover the prompt list + batched_prompt_segment_list = [] + segment_start_idx = 0 + all_prompts = broadcast_obj_list[0] + for num_segment in prompt_segment_length: + batched_prompt_segment_list.append( + all_prompts[segment_start_idx : segment_start_idx + num_segment] + ) + segment_start_idx += num_segment + + # 2. append score + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = append_score_to_prompts( + prompt_segment_list, + aes=cfg.get("aes", None), + flow=cfg.get("flow", None), + camera_motion=cfg.get("camera_motion", None), + ) + + # 3. clean prompt with T5 + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list] + + # 4. merge to obtain the final prompt + batch_prompts = [] + for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list): + batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list)) + + # == Iter over loop generation == + for loop_i in range(loop): + # == get prompt for loop i == + batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i) + + # == add condition frames for loop == + if os.path.exists(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_ref.pt")): + ref = torch.load(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_ref.pt")) + ref = ref.to(dtype) + ref = ref.to(device) + refs[i][loop_i] = ref + ms[i] = open(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_ms")).readlines()[0].strip() + + # == get text embedding == + caption_embs = torch.load(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_prompt.pt")) + caption_emb_masks = torch.load( + os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_prompt_masks.pt") + ) + caption_embs = caption_embs.to(device, torch.float32) + caption_emb_masks = caption_emb_masks.to(device, torch.int64) + + # == sampling == + torch.manual_seed(1024) + z = torch.randn(len(batch_prompts), 4, *latent_size, device=device, dtype=dtype) + masks = apply_mask_strategy(z, refs, ms, loop_i, align=align) + samples = scheduler.sample( + model, + text_encoder=None, + z=z, + prompts=batch_prompts_loop, + device=device, + additional_args=model_args, + progress=verbose >= 2, + mask=masks, + caption_embs=caption_embs, + caption_emb_masks=caption_emb_masks, + ) + if is_main_process(): + torch.save(samples.cpu(), os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_latents.pt")) + start_idx += len(batch_prompts) + logger.info("Inference STDiT finished.") + logger.info("Saved %s samples to 
%s/%s/%s", start_idx, save_dir, cur_date, saved_idx) + + +if __name__ == "__main__": + main() diff --git a/scripts/separate_inference/inference_text_encoder.py b/scripts/separate_inference/inference_text_encoder.py new file mode 100644 index 00000000..e103b6a4 --- /dev/null +++ b/scripts/separate_inference/inference_text_encoder.py @@ -0,0 +1,222 @@ +import os +from datetime import datetime +from pprint import pformat + +import colossalai +import torch +import torch.distributed as dist +from colossalai.cluster import DistCoordinator +from mmengine.runner import set_random_seed +from tqdm import tqdm + +from opensora.acceleration.parallel_states import set_sequence_parallel_group +from opensora.models.text_encoder.t5 import text_preprocessing +from opensora.registry import MODELS, build_module +from opensora.utils.config_utils import parse_configs +from opensora.utils.inference_utils import ( + append_score_to_prompts, + extract_json_from_prompts, + extract_prompts_loop, + get_save_path_name, + load_prompts, + merge_prompt, + refine_prompts_by_openai, + split_prompt, +) +from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype + + +def main(): + torch.set_grad_enabled(False) + # ====================================================== + # configs & runtime variables + # ====================================================== + # == parse configs == + cfg = parse_configs(training=False) + + # == device and dtype == + device = "cuda" if torch.cuda.is_available() else "cpu" + cfg_dtype = cfg.get("dtype", "fp32") + assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" + to_torch_dtype(cfg.get("dtype", "bf16")) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # == init distributed env == + if is_distributed(): + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + enable_sequence_parallelism = coordinator.world_size > 1 + if enable_sequence_parallelism: + set_sequence_parallel_group(dist.group.WORLD) + else: + coordinator = None + enable_sequence_parallelism = False + set_random_seed(seed=cfg.get("seed", 1024)) + + # == init logger == + logger = create_logger() + logger.info("Inference configuration:\n %s", pformat(cfg.to_dict())) + verbose = cfg.get("verbose", 1) + progress_wrap = tqdm if verbose == 1 else (lambda x: x) + + # ====================================================== + # build model & load weights + # ====================================================== + logger.info("Building models...") + # == build text-encoder and vae == + text_encoder = build_module(cfg.text_encoder, MODELS, device=device) + # ====================================================== + # inference + # ====================================================== + # == load prompts == + prompts = cfg.get("prompt", None) + start_idx = cfg.get("start_index", 0) + if prompts is None: + if cfg.get("prompt_path", None) is not None: + prompts = load_prompts(cfg.prompt_path, start_idx, cfg.get("end_index", None)) + else: + prompts = [cfg.get("prompt_generator", "")] * 1_000_000 # endless loop + + # == prepare reference == + reference_path = cfg.get("reference_path", [""] * len(prompts)) + mask_strategy = cfg.get("mask_strategy", [""] * len(prompts)) + assert len(reference_path) == len(prompts), "Length of reference must be the same as prompts" + assert len(mask_strategy) == len(prompts), "Length of mask_strategy must be the same as prompts" + + # == prepare arguments == + batch_size = 
cfg.get("batch_size", 1) + num_sample = cfg.get("num_sample", 1) + loop = cfg.get("loop", 1) + + save_dir = cfg.save_dir + os.makedirs(save_dir, exist_ok=True) + + # == prepare saved dir == + date = datetime.now().strftime("%Y-%m-%d") + if not os.path.exists(os.path.join(save_dir, date)): + saved_idx = str(1).zfill(5) + else: + latest_idx = sorted([int(x) for x in os.listdir(os.path.join(save_dir, date))])[-1] + saved_idx = str(latest_idx + 1).zfill(5) + if is_main_process(): + os.makedirs(os.path.join(save_dir, date, saved_idx), exist_ok=True) + + sample_name = cfg.get("sample_name", None) + prompt_as_path = cfg.get("prompt_as_path", False) + # == Iter over all samples == + for i in progress_wrap(range(0, len(prompts), batch_size)): + # == prepare batch prompts == + batch_prompts = prompts[i : i + batch_size] + ms = mask_strategy[i : i + batch_size] + refs = reference_path[i : i + batch_size] + + # == get json from prompts == + batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms) + original_batch_prompts = batch_prompts + + # == Iter over number of sampling for one prompt == + for k in range(num_sample): + # == prepare save paths == + save_paths = [ + get_save_path_name( + save_dir, + sample_name=sample_name, + sample_idx=start_idx + idx, + prompt=original_batch_prompts[idx], + prompt_as_path=prompt_as_path, + num_sample=num_sample, + k=k, + ) + for idx in range(len(batch_prompts)) + ] + + # NOTE: Skip if the sample already exists + # This is useful for resuming sampling VBench + if prompt_as_path and all_exists(save_paths): + continue + + # == process prompts step by step == + # 0. split prompt + # each element in the list is [prompt_segment_list, loop_idx_list] + batched_prompt_segment_list = [] + batched_loop_idx_list = [] + for prompt in batch_prompts: + prompt_segment_list, loop_idx_list = split_prompt(prompt) + batched_prompt_segment_list.append(prompt_segment_list) + batched_loop_idx_list.append(loop_idx_list) + + # 1. refine prompt by openai + if cfg.get("llm_refine", False): + # only call openai API when + # 1. seq parallel is not enabled + # 2. seq parallel is enabled and the process is rank 0 + if not enable_sequence_parallelism or (enable_sequence_parallelism and is_main_process()): + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list) + + # sync the prompt if using seq parallel + if enable_sequence_parallelism: + coordinator.block_all() + prompt_segment_length = [ + len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list + ] + + # flatten the prompt segment list + batched_prompt_segment_list = [ + prompt_segment + for prompt_segment_list in batched_prompt_segment_list + for prompt_segment in prompt_segment_list + ] + + # create a list of size equal to world size + broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size + dist.broadcast_object_list(broadcast_obj_list, 0) + + # recover the prompt list + batched_prompt_segment_list = [] + segment_start_idx = 0 + all_prompts = broadcast_obj_list[0] + for num_segment in prompt_segment_length: + batched_prompt_segment_list.append( + all_prompts[segment_start_idx : segment_start_idx + num_segment] + ) + segment_start_idx += num_segment + + # 2. 
append score + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = append_score_to_prompts( + prompt_segment_list, + aes=cfg.get("aes", None), + flow=cfg.get("flow", None), + camera_motion=cfg.get("camera_motion", None), + ) + + # 3. clean prompt with T5 + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list] + + # 4. merge to obtain the final prompt + batch_prompts = [] + for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list): + batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list)) + + # == Iter over loop generation == + for loop_i in range(loop): + # == get prompt for loop i == + batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i) + + # == run text encoder == + caption_embs, emb_masks = text_encoder.t5.get_text_embeddings(batch_prompts_loop) + caption_embs = caption_embs[:, None] + if is_main_process(): + torch.save(caption_embs, os.path.join(save_dir, date, saved_idx, f"{i}_{loop_i}_prompt.pt")) + torch.save(emb_masks, os.path.join(save_dir, date, saved_idx, f"{i}_{loop_i}_prompt_masks.pt")) + + start_idx += len(batch_prompts) + logger.info("Inference text_encoder finished.") + logger.info("Saved %s samples to %s/%s/%s", start_idx, save_dir, date, saved_idx) + + +if __name__ == "__main__": + main() diff --git a/scripts/separate_inference/inference_vae_decoder.py b/scripts/separate_inference/inference_vae_decoder.py new file mode 100644 index 00000000..0a630a1f --- /dev/null +++ b/scripts/separate_inference/inference_vae_decoder.py @@ -0,0 +1,143 @@ +import os +import time +from datetime import date, datetime, timedelta +from pprint import pformat + +import colossalai +import torch +import torch.distributed as dist +from colossalai.cluster import DistCoordinator +from mmengine.runner import set_random_seed +from tqdm import tqdm + +from opensora.acceleration.parallel_states import set_sequence_parallel_group +from opensora.datasets import save_sample +from opensora.datasets.aspect import get_num_frames +from opensora.registry import MODELS, build_module +from opensora.utils.config_utils import parse_configs +from opensora.utils.inference_utils import add_watermark, dframe_to_frame, load_prompts +from opensora.utils.misc import create_logger, is_distributed, is_main_process, to_torch_dtype + + +def main(): + torch.set_grad_enabled(False) + # ====================================================== + # configs & runtime variables + # ====================================================== + # == parse configs == + cfg = parse_configs(training=False) + + # == device and dtype == + device = "cuda" if torch.cuda.is_available() else "cpu" + cfg_dtype = cfg.get("dtype", "fp32") + assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" + dtype = to_torch_dtype(cfg.get("dtype", "bf16")) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # == init distributed env == + if is_distributed(): + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + enable_sequence_parallelism = coordinator.world_size > 1 + if enable_sequence_parallelism: + set_sequence_parallel_group(dist.group.WORLD) + else: + coordinator = None + enable_sequence_parallelism = False + set_random_seed(seed=cfg.get("seed", 1024)) + + # == init logger == + logger = create_logger() + logger.info("Inference 
configuration:\n %s", pformat(cfg.to_dict())) + verbose = cfg.get("verbose", 1) + progress_wrap = tqdm if verbose == 1 else (lambda x: x) + + # ====================================================== + # build model & load weights + # ====================================================== + logger.info("Building models...") + # == build text-encoder and vae == + vae = build_module(cfg.vae, MODELS).to(device, dtype).eval() + + # ====================================================== + # inference + # ====================================================== + # == load prompts == + prompts = cfg.get("prompt", None) + start_idx = cfg.get("start_index", 0) + if prompts is None: + if cfg.get("prompt_path", None) is not None: + prompts = load_prompts(cfg.prompt_path, start_idx, cfg.get("end_index", None)) + else: + prompts = [cfg.get("prompt_generator", "")] * 1_000_000 # endless loop + + # == prepare reference == + reference_path = cfg.get("reference_path", [""] * len(prompts)) + mask_strategy = cfg.get("mask_strategy", [""] * len(prompts)) + assert len(reference_path) == len(prompts), "Length of reference must be the same as prompts" + assert len(mask_strategy) == len(prompts), "Length of mask_strategy must be the same as prompts" + + # == prepare arguments == + num_frames = get_num_frames(cfg.num_frames) + fps = cfg.fps + save_fps = cfg.get("save_fps", fps // cfg.get("frame_interval", 1)) + cfg.get("multi_resolution", None) + batch_size = cfg.get("batch_size", 1) + cfg.get("num_sample", 1) + loop = cfg.get("loop", 1) + condition_frame_length = cfg.get("condition_frame_length", 5) + cfg.get("condition_frame_edit", 0.0) + cfg.get("align", None) + + save_dir = cfg.save_dir + os.makedirs(save_dir, exist_ok=True) + cfg.get("sample_name", None) + cfg.get("prompt_as_path", False) + + # == prepare saved dir == + cur_date = datetime.now().strftime("%Y-%m-%d") + if not os.path.exists(os.path.join(save_dir, cur_date)): + yesterday = date.today() - timedelta(days=1) + cur_date = yesterday.strftime("%Y-%m-%d") + latest_idx = sorted([int(x) for x in os.listdir(os.path.join(save_dir, cur_date))])[-1] + saved_idx = str(latest_idx).zfill(5) + + # == Iter over all samples == + for i in progress_wrap(range(0, len(prompts), batch_size)): + batch_prompts = prompts[i : i + batch_size] + # == Iter over loop generation == + video_clips = [] + for loop_i in range(loop): + # == get prompt for loop i == + samples = torch.load(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_latents.pt")) + samples = samples.to(device, dtype) + samples = vae.decode(samples.to(dtype), num_frames=num_frames) + video_clips.append(samples) + + # == save samples == + if is_main_process(): + for idx, batch_prompt in enumerate(batch_prompts): + if verbose >= 2: + logger.info("Prompt: %s", batch_prompt) + save_path = os.path.join(save_dir, cur_date, saved_idx, "video") + video = [video_clips[i][idx] for i in range(loop)] + for i in range(1, loop): + video[i] = video[i][:, dframe_to_frame(condition_frame_length) :] + video = torch.cat(video, dim=1) + save_path = save_sample( + video, + fps=save_fps, + save_path=save_path, + verbose=verbose >= 2, + ) + if save_path.endswith(".mp4") and cfg.get("watermark", False): + time.sleep(1) # prevent loading previous generated video + add_watermark(save_path) + start_idx += len(batch_prompts) + logger.info("Inference VAE decoder finished.") + logger.info("Saved %s samples to %s/%s/%s", start_idx, save_dir, cur_date, saved_idx) + + +if __name__ == "__main__": + main() diff --git 
a/scripts/separate_inference/inference_vae_encoder.py b/scripts/separate_inference/inference_vae_encoder.py new file mode 100644 index 00000000..aa9b4b90 --- /dev/null +++ b/scripts/separate_inference/inference_vae_encoder.py @@ -0,0 +1,254 @@ +import os +from datetime import date, datetime, timedelta +from pprint import pformat + +import colossalai +import torch +import torch.distributed as dist +from colossalai.cluster import DistCoordinator +from mmengine.runner import set_random_seed +from tqdm import tqdm + +from opensora.acceleration.parallel_states import set_sequence_parallel_group +from opensora.datasets.aspect import get_image_size, get_num_frames +from opensora.models.text_encoder.t5 import text_preprocessing +from opensora.registry import MODELS, build_module +from opensora.utils.config_utils import parse_configs +from opensora.utils.inference_utils import ( + append_generated, + append_score_to_prompts, + collect_references_batch, + extract_json_from_prompts, + extract_prompts_loop, + get_save_path_name, + load_prompts, + merge_prompt, + refine_prompts_by_openai, + split_prompt, +) +from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype + + +def main(): + torch.set_grad_enabled(False) + # ====================================================== + # configs & runtime variables + # ====================================================== + # == parse configs == + cfg = parse_configs(training=False) + + # == device and dtype == + device = "cuda" if torch.cuda.is_available() else "cpu" + cfg_dtype = cfg.get("dtype", "fp32") + assert cfg_dtype in ["fp16", "bf16", "fp32"], f"Unknown mixed precision {cfg_dtype}" + dtype = to_torch_dtype(cfg.get("dtype", "bf16")) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + # == init distributed env == + if is_distributed(): + colossalai.launch_from_torch({}) + coordinator = DistCoordinator() + enable_sequence_parallelism = coordinator.world_size > 1 + if enable_sequence_parallelism: + set_sequence_parallel_group(dist.group.WORLD) + else: + coordinator = None + enable_sequence_parallelism = False + set_random_seed(seed=cfg.get("seed", 1024)) + + # == init logger == + logger = create_logger() + logger.info("Inference configuration:\n %s", pformat(cfg.to_dict())) + verbose = cfg.get("verbose", 1) + progress_wrap = tqdm if verbose == 1 else (lambda x: x) + + # ====================================================== + # build model & load weights + # ====================================================== + logger.info("Building models...") + # == build text-encoder and vae == + vae = build_module(cfg.vae, MODELS).to(device, dtype).eval() + + # == prepare video size == + image_size = cfg.get("image_size", None) + if image_size is None: + resolution = cfg.get("resolution", None) + aspect_ratio = cfg.get("aspect_ratio", None) + assert ( + resolution is not None and aspect_ratio is not None + ), "resolution and aspect_ratio must be provided if image_size is not provided" + image_size = get_image_size(resolution, aspect_ratio) + num_frames = get_num_frames(cfg.num_frames) + + # == build diffusion model == + input_size = (num_frames, *image_size) + vae.get_latent_size(input_size) + + # ====================================================== + # inference + # ====================================================== + # == load prompts == + prompts = cfg.get("prompt", None) + start_idx = cfg.get("start_index", 0) + if prompts is None: + if cfg.get("prompt_path", None) is 
not None: + prompts = load_prompts(cfg.prompt_path, start_idx, cfg.get("end_index", None)) + else: + prompts = [cfg.get("prompt_generator", "")] * 1_000_000 # endless loop + + # == prepare reference == + reference_path = cfg.get("reference_path", [""] * len(prompts)) + mask_strategy = cfg.get("mask_strategy", [""] * len(prompts)) + assert len(reference_path) == len(prompts), "Length of reference must be the same as prompts" + assert len(mask_strategy) == len(prompts), "Length of mask_strategy must be the same as prompts" + + # == prepare arguments == + fps = cfg.fps + cfg.get("save_fps", fps // cfg.get("frame_interval", 1)) + cfg.get("multi_resolution", None) + batch_size = cfg.get("batch_size", 1) + num_sample = cfg.get("num_sample", 1) + loop = cfg.get("loop", 1) + condition_frame_length = cfg.get("condition_frame_length", 5) + condition_frame_edit = cfg.get("condition_frame_edit", 0.0) + cfg.get("align", None) + + save_dir = cfg.save_dir + os.makedirs(save_dir, exist_ok=True) + sample_name = cfg.get("sample_name", None) + prompt_as_path = cfg.get("prompt_as_path", False) + + # == prepare saved dir == + cur_date = datetime.now().strftime("%Y-%m-%d") + if not os.path.exists(os.path.join(save_dir, cur_date)): + yesterday = date.today() - timedelta(days=1) + cur_date = yesterday.strftime("%Y-%m-%d") + latest_idx = sorted([int(x) for x in os.listdir(os.path.join(save_dir, cur_date))])[-1] + saved_idx = str(latest_idx).zfill(5) + + # == Iter over all samples == + for i in progress_wrap(range(0, len(prompts), batch_size)): + # == prepare batch prompts == + batch_prompts = prompts[i : i + batch_size] + ms = mask_strategy[i : i + batch_size] + refs = reference_path[i : i + batch_size] + + # == get json from prompts == + batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms) + original_batch_prompts = batch_prompts + + # == get reference for condition == + refs = collect_references_batch(refs, vae, image_size) + + # == Iter over number of sampling for one prompt == + for k in range(num_sample): + # == prepare save paths == + save_paths = [ + get_save_path_name( + save_dir, + sample_name=sample_name, + sample_idx=start_idx + idx, + prompt=original_batch_prompts[idx], + prompt_as_path=prompt_as_path, + num_sample=num_sample, + k=k, + ) + for idx in range(len(batch_prompts)) + ] + + # NOTE: Skip if the sample already exists + # This is useful for resuming sampling VBench + if prompt_as_path and all_exists(save_paths): + continue + + # == process prompts step by step == + # 0. split prompt + # each element in the list is [prompt_segment_list, loop_idx_list] + batched_prompt_segment_list = [] + batched_loop_idx_list = [] + for prompt in batch_prompts: + prompt_segment_list, loop_idx_list = split_prompt(prompt) + batched_prompt_segment_list.append(prompt_segment_list) + batched_loop_idx_list.append(loop_idx_list) + + # 1. refine prompt by openai + if cfg.get("llm_refine", False): + # only call openai API when + # 1. seq parallel is not enabled + # 2. 
seq parallel is enabled and the process is rank 0 + if not enable_sequence_parallelism or (enable_sequence_parallelism and is_main_process()): + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = refine_prompts_by_openai(prompt_segment_list) + + # sync the prompt if using seq parallel + if enable_sequence_parallelism: + coordinator.block_all() + prompt_segment_length = [ + len(prompt_segment_list) for prompt_segment_list in batched_prompt_segment_list + ] + + # flatten the prompt segment list + batched_prompt_segment_list = [ + prompt_segment + for prompt_segment_list in batched_prompt_segment_list + for prompt_segment in prompt_segment_list + ] + + # create a list of size equal to world size + broadcast_obj_list = [batched_prompt_segment_list] * coordinator.world_size + dist.broadcast_object_list(broadcast_obj_list, 0) + + # recover the prompt list + batched_prompt_segment_list = [] + segment_start_idx = 0 + all_prompts = broadcast_obj_list[0] + for num_segment in prompt_segment_length: + batched_prompt_segment_list.append( + all_prompts[segment_start_idx : segment_start_idx + num_segment] + ) + segment_start_idx += num_segment + + # 2. append score + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = append_score_to_prompts( + prompt_segment_list, + aes=cfg.get("aes", None), + flow=cfg.get("flow", None), + camera_motion=cfg.get("camera_motion", None), + ) + + # 3. clean prompt with T5 + for idx, prompt_segment_list in enumerate(batched_prompt_segment_list): + batched_prompt_segment_list[idx] = [text_preprocessing(prompt) for prompt in prompt_segment_list] + + # 4. merge to obtain the final prompt + batch_prompts = [] + for prompt_segment_list, loop_idx_list in zip(batched_prompt_segment_list, batched_loop_idx_list): + batch_prompts.append(merge_prompt(prompt_segment_list, loop_idx_list)) + + # == Iter over loop generation == + video_clips = [] + for loop_i in range(loop): + # == get prompt for loop i == + extract_prompts_loop(batch_prompts, loop_i) + + # == add condition frames for loop == + if loop_i > 0: + refs, ms = append_generated( + vae, video_clips[-1], refs, ms, loop_i, condition_frame_length, condition_frame_edit + ) + if is_main_process(): + torch.save( + refs[i][loop_i].cpu(), os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_ref.pt") + ) + with open(os.path.join(save_dir, cur_date, saved_idx, f"{i}_{loop_i}_ms"), "w") as f: + f.write(ms[i]) + + start_idx += len(batch_prompts) + logger.info("Inference VAE encoder finished.") + logger.info("Saved %s samples to %s", start_idx, save_dir) + + +if __name__ == "__main__": + main()
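For reference, the four separate-inference stages above exchange data only through tensors written with `torch.save` under `save_dir/<date>/<index>/` (`{i}_{loop_i}_prompt.pt`, `{i}_{loop_i}_prompt_masks.pt`, `{i}_{loop_i}_ref.pt`, `{i}_{loop_i}_latents.pt`). The snippet below is a minimal, self-contained sketch of that handoff pattern for a single prompt and loop iteration; the run directory and tensor shapes are placeholders for illustration, not values produced by the scripts.

```python
import os

import torch

# Minimal sketch of the file-based handoff between the separate-inference
# stages (placeholder path and shapes; the real scripts derive these from
# the config, the current date, and an auto-incremented run index).
run_dir = "samples/2024-01-01/00001"  # save_dir/<date>/<index>
i, loop_i = 0, 0                      # batch index and loop index
os.makedirs(run_dir, exist_ok=True)

# Stage 1 (text encoder): save caption embeddings [B, 1, max_len, 4096] and masks.
caption_embs = torch.randn(1, 1, 300, 4096)
emb_masks = torch.ones(1, 300, dtype=torch.int64)
torch.save(caption_embs, os.path.join(run_dir, f"{i}_{loop_i}_prompt.pt"))
torch.save(emb_masks, os.path.join(run_dir, f"{i}_{loop_i}_prompt_masks.pt"))

# Stage 2 (optional VAE encoder) would similarly save the reference latents as
# f"{i}_{loop_i}_ref.pt" and the mask strategy as a small f"{i}_{loop_i}_ms" file.

# Stage 3 (STDiT): load the embeddings, pass them to scheduler.sample() with
# text_encoder=None via caption_embs / caption_emb_masks, then save the latents.
caption_embs = torch.load(os.path.join(run_dir, f"{i}_{loop_i}_prompt.pt"))
emb_masks = torch.load(os.path.join(run_dir, f"{i}_{loop_i}_prompt_masks.pt"))
latents = torch.randn(1, 4, 30, 90, 160)  # placeholder for the sampled latents
torch.save(latents, os.path.join(run_dir, f"{i}_{loop_i}_latents.pt"))

# Stage 4 (VAE decoder): load the latents and decode them into video frames.
latents = torch.load(os.path.join(run_dir, f"{i}_{loop_i}_latents.pt"))
```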