From 942346aa8cd192300ad9157b72c7de27dad5e3cd Mon Sep 17 00:00:00 2001 From: "zhexin.lzx" Date: Fri, 21 Jun 2024 15:29:29 +0800 Subject: [PATCH] [feat]: support navit --- .gitignore | 45 - README.md | 484 ++------ README_original.md | 410 +++++++ opensora/dataset/__init__.py | 10 +- opensora/dataset/transform.py | 31 + opensora/models/ae/videobase/modules/quant.py | 2 + opensora/models/diffusion/__init__.py | 2 + .../models/diffusion/diffusion/__init__.py | 41 +- .../models/diffusion/diffusion/respace.py | 65 +- .../diffusion/latte/modeling_latte_navit.py | 1053 +++++++++++++++++ opensora/models/diffusion/latte/modules.py | 15 +- opensora/train/train_t2v_navit.py | 784 ++++++++++++ opensora/utils/dataset_utils.py | 81 +- .../text_condition/train_videoae_65_navit.sh | 38 + .../train_videoae_65_navit_test.sh | 37 + scripts/train_data/image_data_debug.txt | 1 + scripts/train_data/video_data_debug.txt | 1 + tests/test_navit_consistency.py | 622 ++++++++++ 18 files changed, 3259 insertions(+), 463 deletions(-) create mode 100644 README_original.md create mode 100644 opensora/models/diffusion/latte/modeling_latte_navit.py create mode 100644 opensora/train/train_t2v_navit.py create mode 100644 scripts/text_condition/train_videoae_65_navit.sh create mode 100644 scripts/text_condition/train_videoae_65_navit_test.sh create mode 100644 scripts/train_data/image_data_debug.txt create mode 100644 scripts/train_data/video_data_debug.txt create mode 100644 tests/test_navit_consistency.py diff --git a/.gitignore b/.gitignore index 48d7e77d4..e69de29bb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,45 +0,0 @@ -ucf101_stride4x4x4 -__pycache__ -*.mp4 -.ipynb_checkpoints -*.pth -UCF-101/ -results/ -vae -build/ -opensora.egg-info/ -wandb/ -.idea -*.ipynb -*.jpg -*.mp3 -*.safetensors -*.mp4 -*.png -*.gif -*.pth -*.pt -cache_dir/ -wandb/ -test* -sample_video* -sample_image* -512* -720* -1024* -debug* -private* -caption* -*deepspeed* -revised* -129f* -all* -read* -YSH* -*pick* -*ysh* -hw* -257f* -513f* -taming* -221hw* \ No newline at end of file diff --git a/README.md b/README.md index e363bfff4..144b00f7a 100644 --- a/README.md +++ b/README.md @@ -1,410 +1,74 @@ -# Open-Sora Plan - - - -[![slack badge](https://img.shields.io/badge/Discord-join-blueviolet?logo=discord&)](https://discord.gg/YtsBNg7n) -[![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues/53#issuecomment-1987226516) -[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0) -[![Twitter](https://img.shields.io/badge/-Twitter@LinBin46984-black?logo=twitter&logoColor=1D9BF0)](https://x.com/LinBin46984/status/1795018003345510687)
-[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0) -[![License](https://img.shields.io/badge/License-MIT-yellow)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/LICENSE) -[![GitHub repo contributors](https://img.shields.io/github/contributors-anon/PKU-YuanGroup/Open-Sora-Plan?style=flat&label=Contributors)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/graphs/contributors) -[![GitHub Commit](https://img.shields.io/github/commit-activity/m/PKU-YuanGroup/Open-Sora-Plan?label=Commit)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/commits/main/) -[![Pr](https://img.shields.io/github/issues-pr-closed-raw/PKU-YuanGroup/Open-Sora-Plan.svg?label=Merged+PRs&color=green)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) -[![GitHub issues](https://img.shields.io/github/issues/PKU-YuanGroup/Open-Sora-Plan?color=critical&label=Issues)](https://github.com/PKU-YuanGroup/Video-LLaVA/issues?q=is%3Aopen+is%3Aissue) -[![GitHub closed issues](https://img.shields.io/github/issues-closed/PKU-YuanGroup/Open-Sora-Plan?color=success&label=Issues)](https://github.com/PKU-YuanGroup/Video-LLaVA/issues?q=is%3Aissue+is%3Aclosed)
-[![GitHub repo stars](https://img.shields.io/github/stars/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/stargazers)  -[![GitHub repo forks](https://img.shields.io/github/forks/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Forks)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/network)  -[![GitHub repo watchers](https://img.shields.io/github/watchers/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Watchers)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/watchers)  -[![GitHub repo size](https://img.shields.io/github/repo-size/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Repo%20Size)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/archive/refs/heads/main.zip) - -
-v1.0.0 badge -[![Twitter](https://img.shields.io/badge/-Twitter@LinBin46984-black?logo=twitter&logoColor=1D9BF0)](https://x.com/LinBin46984/status/1763476690385424554?s=20)
-[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0) -[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/fffiloni/Open-Sora-Plan-v1-0-0) -[![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb)
-
- -We are thrilled to present **Open-Sora-Plan v1.1.0**, which significantly enhances video generation quality and text control capabilities. See our [report](docs/Report-v1.1.0.md). We show compressed .gif on GitHub, which loses some quality. - -Thanks to **HUAWEI Ascend Team** for supporting us. In the second stage, we used Huawei Ascend computing power for training. This stage's training and inference were fully supported by Huawei. Models trained on Huawei Ascend can also be loaded into GPUs and generate videos of the same quality. - -目前已经支持使用国产AI芯片(华为昇腾,期待更多国产算力芯片)进行完整的训练和推理。在项目第二阶段,所有训练和推理任务完全由华为昇腾芯片支持。此外,基于华为昇腾的512卡集群训练出的模型,也可以无缝地在GPU上运行,并保持相同的视频质量。详细信息请参考我们的[hw branch](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/hw). - - -### 221×512×512 Text-to-Video Generation - - - -| 221×512×512 (9.2s) | 221×512×512 (9.2s) | 221×512×512 (9.2s) | -| --- | --- | --- | -| | | | -| 3D animation of a small, round, fluffy creature with big, expressive eyes explores ... | A single drop of liquid metal falls from a floating orb, landing on a mirror-like ... | The video presents an abstract composition centered around a hexagonal shape adorned ... | -| | | | -| A drone camera circles around a beautiful historic church built on a rocky outcropping ... | Aerial view of Santorini during the blue hour, showcasing the stunning architecture ... | An aerial shot of a lighthouse standing tall on a rocky cliff, its beacon cutting ... | -| | | | -| A snowy forest landscape with a dirt road running through it. The road is flanked by ... | Drone shot along the Hawaii jungle coastline, sunny day. Kayaks in the water. |The camera rotates around a large stack of vintage televisions all showing different ... | - - -### 65×512×512 Text-to-Video Generation - -| 65×512×512 (2.7s) | 65×512×512 (2.7s) | 65×512×512 (2.7s) | -| --- | --- | --- | -| | | | -| In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two ... | A Shiba Inu dog wearing a beret and black turtleneck. | A painting of a boat on water comes to life, with waves crashing and the boat becoming ... | -|| | | -| A person clad in a space suit with a helmet and equipped with a chest light and arm ... | 3D animation of a small, round, fluffy creature with big, expressive eyes explores a ... | In a studio, there is a painting depicting a ship sailing through the rough sea. | -| | | | -| A robot dog trots down a deserted alley at night, its metallic paws clinking softly ... | A lone surfer rides a massive wave, skillfully maneuvering through the surf. The water ... | A solitary cheetah sprints across the savannah, its powerful muscles propelling it ... | - -### 65×512×512 Video Editing - -| generated 65×512×512 (2.7s) | edited 65×512×512 (2.7s) | -| --- | --- | -| | | -| | | -| | | - -### 512×512 Text-to-Image Generation - - - - - - -## 📰 News - -**[2024.05.27]** 🚀🚀🚀 We are launching Open-Sora Plan v1.1.0, which significantly improves video quality and length, and is fully open source! Please check out our latest [report](docs/Report-v1.1.0.md). - -**[2024.04.09]** 🚀 Excited to share our latest exploration on metamorphic time-lapse video generation: [MagicTime](https://github.com/PKU-YuanGroup/MagicTime), which learns real-world physics knowledge from time-lapse videos. Here is the dataset for train (updating): [Open-Sora-Dataset](https://github.com/PKU-YuanGroup/Open-Sora-Dataset). - -**[2024.04.07]** 🔥🔥🔥 Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities. See our [report](docs/Report-v1.0.0.md). Thanks to HUAWEI NPU for supporting us. - -**[2024.03.27]** 🚀🚀🚀 We release the report of [VideoCausalVAE](docs/CausalVideoVAE.md), which supports both images and videos. We present our reconstructed video in this demonstration as follows. The text-to-video model is on the way. - -
-View more - -**[2024.03.10]** 🚀🚀🚀 This repo supports training a latent size of 225×90×90 (t×h×w), which means we are able to **train 1 minute of 1080P video with 30FPS** (2× interpolated frames and 2× super resolution) under class-condition. - -**[2024.03.08]** We support the training code of text condition with 16 frames of 512x512. The code is mainly borrowed from [Latte](https://github.com/Vchitect/Latte). - -**[2024.03.07]** We support training with 128 frames (when sample rate = 3, which is about 13 seconds) of 256x256, or 64 frames (which is about 6 seconds) of 512x512. - -**[2024.03.05]** See our latest [todo](https://github.com/PKU-YuanGroup/Open-Sora-Plan?tab=readme-ov-file#todo), pull requests are welcome. - -**[2024.03.04]** We re-organize and modulize our code to make it easy to [contribute](https://github.com/PKU-YuanGroup/Open-Sora-Plan?tab=readme-ov-file#how-to-contribute-to-the-open-sora-plan-community) to the project, to contribute please see the [Repo structure](https://github.com/PKU-YuanGroup/Open-Sora-Plan?tab=readme-ov-file#repo-structure). - -**[2024.03.03]** We open some [discussions](https://github.com/PKU-YuanGroup/Open-Sora-Plan/discussions) to clarify several issues. - -**[2024.03.01]** Training code is available now! Learn more on our [project page](https://pku-yuangroup.github.io/Open-Sora-Plan/). Please feel free to watch 👀 this repository for the latest updates. - -
- -## 💪 Goal -This project aims to create a simple and scalable repo, to reproduce [Sora](https://openai.com/sora) (OpenAI, but we prefer to call it "ClosedAI" ). We wish the open-source community can contribute to this project. Pull requests are welcome!!! - -本项目希望通过开源社区的力量复现Sora,由北大-兔展AIGC联合实验室共同发起,当前版本离目标差距仍然较大,仍需持续完善和快速迭代,欢迎Pull request!!! - -Project stages: -- Primary -1. Setup the codebase and train an un-conditional model on a landscape dataset. -2. Train models that boost resolution and duration. - -- Extensions -3. Conduct text2video experiments on landscape dataset. -4. Train the 1080p model on video2text dataset. -5. Control model with more conditions. - - -
- - -
- - -
-✊ Todo - -#### Setup the codebase and train an unconditional model on landscape dataset -- [x] Fix typos & Update readme. 🤝 Thanks to [@mio2333](https://github.com/mio2333), [@CreamyLong](https://github.com/CreamyLong), [@chg0901](https://github.com/chg0901), [@Nyx-177](https://github.com/Nyx-177), [@HowardLi1984](https://github.com/HowardLi1984), [@sennnnn](https://github.com/sennnnn), [@Jason-fan20](https://github.com/Jason-fan20) -- [x] Setup environment. 🤝 Thanks to [@nameless1117](https://github.com/nameless1117) -- [ ] Add docker file. ⌛ [WIP] 🤝 Thanks to [@Mon-ius](https://github.com/Mon-ius), [@SimonLeeGit](https://github.com/SimonLeeGit) -- [ ] Enable type hints for functions. 🤝 Thanks to [@RuslanPeresy](https://github.com/RuslanPeresy), 🙏 **[Need your contribution]** -- [x] Resume from checkpoint. -- [x] Add Video-VQVAE model, which is borrowed from [VideoGPT](https://github.com/wilson1yan/VideoGPT). -- [x] Support variable aspect ratios, resolutions, durations training on [DiT](https://github.com/facebookresearch/DiT). -- [x] Support Dynamic mask input inspired by [FiT](https://github.com/whlzy/FiT). -- [x] Add class-conditioning on embeddings. -- [x] Incorporating [Latte](https://github.com/Vchitect/Latte) as main codebase. -- [x] Add VAE model, which is borrowed from [Stable Diffusion](https://github.com/CompVis/latent-diffusion). -- [x] Joint dynamic mask input with VAE. -- [ ] Add VQVAE from [VQGAN](https://github.com/CompVis/taming-transformers). 🙏 **[Need your contribution]** -- [ ] Make the codebase ready for the cluster training. Add SLURM scripts. 🙏 **[Need your contribution]** -- [x] Refactor VideoGPT. 🤝 Thanks to [@qqingzheng](https://github.com/qqingzheng), [@luo3300612](https://github.com/luo3300612), [@sennnnn](https://github.com/sennnnn) -- [x] Add sampling script. -- [ ] Add DDP sampling script. ⌛ [WIP] -- [x] Use accelerate on multi-node. 🤝 Thanks to [@sysuyy](https://github.com/sysuyy) -- [x] Incorporate [SiT](https://github.com/willisma/SiT). 🤝 Thanks to [@khan-yin](https://github.com/khan-yin) -- [x] Add evaluation scripts (FVD, CLIP score). 🤝 Thanks to [@rain305f](https://github.com/rain305f) - -#### Train models that boost resolution and duration -- [x] Add [PI](https://arxiv.org/abs/2306.15595) to support out-of-domain size. 🤝 Thanks to [@jpthu17](https://github.com/jpthu17) -- [x] Add 2D RoPE to improve generalization ability as [FiT](https://github.com/whlzy/FiT). 🤝 Thanks to [@jpthu17](https://github.com/jpthu17) -- [x] Compress KV according to [PixArt-sigma](https://pixart-alpha.github.io/PixArt-sigma-project). -- [x] Support deepspeed for videogpt training. 🤝 Thanks to [@sennnnn](https://github.com/sennnnn) -- [x] Train a **low dimension** Video-AE, whether it is VAE or VQVAE. -- [x] Extract offline feature. -- [x] Train with offline feature. -- [x] Add frame interpolation model. 🤝 Thanks to [@yunyangge](https://github.com/yunyangge) -- [x] Add super resolution model. 🤝 Thanks to [@Linzy19](https://github.com/Linzy19) -- [x] Add accelerate to automatically manage training. -- [x] Joint training with images. -- [ ] Implement [MaskDiT](https://github.com/Anima-Lab/MaskDiT) technique for fast training. 🙏 **[Need your contribution]** -- [ ] Incorporate [NaViT](https://arxiv.org/abs/2307.06304). 🙏 **[Need your contribution]** -- [ ] Add [FreeNoise](https://github.com/arthur-qiu/FreeNoise-LaVie) support for training-free longer video generation. 🙏 **[Need your contribution]** - -#### Conduct text2video experiments on landscape dataset. -- [x] Load pretrained weights from [Latte](https://github.com/Vchitect/Latte). -- [ ] Implement [PeRFlow](https://github.com/magic-research/piecewise-rectified-flow) for improving the sampling process. 🙏 **[Need your contribution]** -- [x] Finish data loading, pre-processing utils. -- [x] Add T5 support. -- [x] Add CLIP support. 🤝 Thanks to [@Ytimed2020](https://github.com/Ytimed2020) -- [x] Add text2image training script. -- [ ] Add prompt captioner. - - [ ] Collect training data. - - [ ] Need video-text pairs with caption. 🙏 **[Need your contribution]** - - [ ] Extract multi-frame descriptions by large image-language models. 🤝 Thanks to [@HowardLi1984](https://github.com/HowardLi1984) - - [ ] Extract video description by large video-language models. 🙏 **[Need your contribution]** - - [ ] Integrate captions to get a dense caption by using a large language model, such as GPT-4. 🤝 Thanks to [@HowardLi1984](https://github.com/HowardLi1984) - - [ ] Train a captioner to refine captions. 🚀 **[Require more computation]** - -#### Train the 1080p model on video2text dataset -- [ ] Looking for a suitable dataset, welcome to discuss and recommend. 🙏 **[Need your contribution]** -- [ ] Add synthetic video created by game engines or 3D representations. 🙏 **[Need your contribution]** -- [x] Finish data loading, and pre-processing utils. -- [x] Support memory friendly training. - - [x] Add flash-attention2 from pytorch. - - [x] Add xformers. 🤝 Thanks to [@jialin-zhao](https://github.com/jialin-zhao) - - [x] Support mixed precision training. - - [x] Add gradient checkpoint. - - [x] Support for ReBased and Ring attention. 🤝 Thanks to [@kabachuha](https://github.com/kabachuha) - - [x] Train using the deepspeed engine. 🤝 Thanks to [@sennnnn](https://github.com/sennnnn) -- [ ] Train with a text condition. Here we could conduct different experiments: 🚀 **[Require more computation]** - - [x] Train with T5 conditioning. - - [ ] Train with CLIP conditioning. - - [ ] Train with CLIP + T5 conditioning (probably costly during training and experiments). -- [ ] Support Chinese. ⌛ [WIP] - -#### Control model with more condition -- [ ] Incorporating [ControlNet](https://github.com/lllyasviel/ControlNet). ⌛ [WIP] 🙏 **[Need your contribution]** -- [ ] Incorporating [ReVideo](https://github.com/MC-E/ReVideo). ⌛ [WIP] - -
- -## 📂 Repo structure (WIP) -``` -├── README.md -├── docs -│ ├── Data.md -> Datasets description. -│ ├── Contribution_Guidelines.md -> Contribution guidelines description. -├── scripts -> All scripts. -├── opensora -│   ├── dataset -│   ├── models -│   │   ├── ae -> Compress videos to latents -│   │   │   ├── imagebase -│   │   │   │   ├── vae -│   │   │   │   └── vqvae -│   │   │   └── videobase -│   │   │   ├── vae -│   │   │   └── vqvae -│   │   ├── captioner -│   │   ├── diffusion -> Denoise latents -│   │   │   ├── diffusion -│   │   │   ├── dit -│   │   │   ├── latte -│   │   │   └── unet -│   │   ├── frame_interpolation -│   │   ├── super_resolution -│   │   └── text_encoder -│   ├── sample -│   ├── train -> Training code -│   └── utils -``` - -## 🛠️ Requirements and Installation - -1. Clone this repository and navigate to Open-Sora-Plan folder -``` -git clone https://github.com/PKU-YuanGroup/Open-Sora-Plan -cd Open-Sora-Plan -``` -2. Install required packages -``` -conda create -n opensora python=3.8 -y -conda activate opensora -pip install -e . -``` -3. Install additional packages for training cases -``` -pip install -e ".[train]" -pip install flash-attn --no-build-isolation -``` -4. Install optional requirements such as static type checking: -``` -pip install -e '.[dev]' -``` - -## 🗝️ Usage - - -### 🤗 Demo - -#### Gradio Web UI - -Highly recommend trying out our web demo by the following command. We also provide [online demo](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0) [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0). - -
-v1.0.0 - -Highly recommend trying out our web demo by the following command. We also provide [online demo](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0) [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0) and [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/fffiloni/Open-Sora-Plan-v1-0-0) in Huggingface Spaces. - -🤝 Enjoying the [![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb), created by [@camenduru](https://github.com/camenduru), who generously supports our research! - -
- -```bash -python -m opensora.serve.gradio_web_server -``` - -#### CLI Inference - -```bash -sh scripts/text_condition/sample_video.sh -``` - -### Datasets -Refer to [Data.md](docs/Data.md) - -### Evaluation -Refer to the document [EVAL.md](docs/EVAL.md). - -### CausalVideoVAE - -#### Reconstructing - -Example: - -```Python -python examples/rec_imvi_vae.py --video_path test_video.mp4 --rec_path output_video.mp4 --fps 24 --resolution 512 --crop_size 512 --num_frames 128 --sample_rate 1 --ae CausalVAEModel_4x8x8 --model_path pretrained_488_release --enable_tiling --enable_time_chunk -``` - -Parameter explanation: - -- `--enable_tiling`: This parameter is a flag to enable a tiling conv. - -#### Training and Eval - -Please refer to the document [CausalVideoVAE](docs/Train_And_Eval_CausalVideoVAE.md). - -### VideoGPT VQVAE - -Please refer to the document [VQVAE](docs/VQVAE.md). - -### Video Diffusion Transformer - -#### Training -``` -sh scripts/text_condition/train_videoae_65x512x512.sh -``` -``` -sh scripts/text_condition/train_videoae_221x512x512.sh -``` -``` -sh scripts/text_condition/train_videoae_513x512x512.sh -``` - - - -## 💡 How to Contribute to the Open-Sora Plan Community -We greatly appreciate your contributions to the Open-Sora Plan open-source community and helping us make it even better than it is now! - -For more details, please refer to the [Contribution Guidelines](docs/Contribution_Guidelines.md) - - - - -## 👍 Acknowledgement -* [Latte](https://github.com/Vchitect/Latte): The **main codebase** we built upon and it is an wonderful video generated model. -* [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis. -* [ShareGPT4Video](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4Video): Improving Video Understanding and Generation with Better Captions. -* [VideoGPT](https://github.com/wilson1yan/VideoGPT): Video Generation using VQ-VAE and Transformers. -* [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers. -* [FiT](https://github.com/whlzy/FiT): Flexible Vision Transformer for Diffusion Model. -* [Positional Interpolation](https://arxiv.org/abs/2306.15595): Extending Context Window of Large Language Models via Positional Interpolation. - - -## 🔒 License -* See [LICENSE](LICENSE) for details. - - - - -## ✏️ Citing - -### BibTeX - -```bibtex -@software{pku_yuan_lab_and_tuzhan_ai_etc_2024_10948109, - author = {PKU-Yuan Lab and Tuzhan AI etc.}, - title = {Open-Sora-Plan}, - month = apr, - year = 2024, - publisher = {GitHub}, - doi = {10.5281/zenodo.10948109}, - url = {https://doi.org/10.5281/zenodo.10948109} -} -``` -### Latest DOI - -[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10948109.svg)](https://zenodo.org/records/10948109) - -## 🤝 Community contributors - - - - +# Variable-Resolution T2V Training with NaViT +## Introduction +In the [technical characteristics](https://openai.com/index/sora/) disclosed by SORA officials, SORA is trained using a Transformer-based model on visual data of varying durations, resolutions, and aspect ratios , which may be crucial for reproducing the effects of SORA. +> We represent videos and images as collections of smaller units of data called patches, each of which is akin to a token in GPT. By unifying how we represent data, we can train diffusion transformers on a wider range of visual data than was possible before, spanning different durations, resolutions and aspect ratios. + + +However, video data are usually of different resolutions, making it impossible to concatenate a number of training data into a batched sequence for parallel computation. The community has proposed the following solutions: + +1. Padding: Padding all videos in a batch to the length of the longest sequence, similar to the techniques commonly used in NLP. However, this introduces a lot of unnecessary computation, especially when there is a high variance in the resolution of training samples. +2. Bucketing: Dividing training samples into different "buckets" based on their resolution, then randomly drawing samples from the same bucket to compose a batch. This, however, disrupts the sampling mechanism of SGD and may affect model convergence. +3. NaViT: NaViT was proposed for image classification tasks to train with images of different resolutions. By packing short sequences of different lengths into one long sequence and using an attention mask to isolate the attention computation for different samples, NaViT significantly reduces the proportion of unnecessary padding while ensuring computational equivalence, thereby improving the efficiency of variable-resolution training. However, in text-to-video (t2v) tasks, due to the introduction of text conditions and temporal information, integrating NaViT becomes more complex. As of now, there is no precedent in the industry for implementing NaViT for video generation tasks. + +Therefore, we achieve a NaViT implementation specifically for Open-SoRA-Plan, hoping to make an effort to AIGC community. +## How to Use +1. Prepare datasets and install requirements for training following the original Open-Sora-Plan procedure. +2. Start your navit training +``` +sh scripts/text_condition/train_videoae_65_navit.sh +``` +3. You might want to test its numerical consistency with the original Open-Sora-Plan implementation. +``` +sh scripts/text_condition/train_videoae_65_navit_test.sh +``` +## Method +### Dataloader +In the preprocessing stage, we introduce RandomResize transform that randomly resizes inputs to a range between 64 and the original resolution, to simulate variable resolution training. This transform should be removed during training so that the videos are already in their original resolutions. + +Concurrently, we upscales the video to the nearest integer multiple of the product of the VAE's compression ratio and patch size. This adjustment guarantees that the resultant video latent representation is divisible by the patch size. For image-joint-training, we further resize the images to their corresponding video resolution. + +Furthermore, modifications to the DataLoader have been implemented to yield the list of Tensors rather than batching them together to support the samples with different resolutions. +### VAE encode +Due to the inherent limitations of the existing variational autoencoder (VAE) architecture, which does not accommodate serialized data formats, the VAE encoding step is necessarily positioned preceding the data packing process. Moreover, the VAE lacks the capability to concurrently handle inputs of different resolutions. Consequently, we loop over the individual video samples and use VAE to encode them sequentially. +### Video Packing +In Latte, patch embedding is achieved through a 2D convolutional layer (conv2d) followed by reshaping operations to yield a serialized output sequence. We also follow this procedure, except that we group and pack some videos with different resolutions into a single sequence. + +Currently, a simple grouping strategy is used where videos are sequentially grouped based on a maximum sequence length threshold. Specifically, when the combined length of grouped videos exceeds this threshold, a new group is started. Videos within the same group are concatenated into a single sequence. + +Each video within a group undergoes a shared-parameter conv2d operation followed by reshaping to a serialized format. The resulting sequences within the same group are then concatenated along the token dimension, effectively completing the video packing process. + +To facilitate distinction between tokens from different samples as well as the padding tokens during subsequent attention mechanisms, a token-wise labeling scheme is implemented. Here, unique identifiers (0, 1, 2, ...) denote tokens belonging to different videos, while padding tokens are consistently marked with -1. +![](https://intranetproxy.alipay.com/skylark/lark/0/2024/jpeg/124356420/1716884849686-cc58b714-86e3-4da8-a318-541c410e45d9.jpeg) +### Token drop +We incorporate token dropping proposed in NaViT, enabling us to control the sequence length after the packing process. This strategy involves selectively discarding certain tokens to manage the overall sequence length, which is particularly beneficial for managing computational complexity and potentially mitigating the effects of less informative tokens. +### Position Encoding packing +For videos of arbitrary resolutions and aspect ratios, we interpolate their positional encodings to a fixed resolution, which can be specified by the user. This standardization facilitates the utilization of pre-trained models that were trained at a specific resolution. + +Within the same group of videos, after interpolating their positional encodings to match the designated resolution, they are concatenated along the token dimension as well. +### Timestep embedding packing +Within a single batch, since videos have their individual timesteps during training, we pack these timesteps according to the videos' packing pattern. This approach ensures that timestep embeddings can effectively integrate into the attention operations. +### Text condition packing +Similarly to packing timesteps, text inputs are also packed in accordance with the video's established packing pattern. This method ensures the alignment between textual and visual information. + +Akin to the video token labeling, text tokens are also annotated. Each token is systematically tagged with identifiers (0, 1, 2, ...), distinguishing text originating from different videos. Consistently, padding tokens within the text sequences are labeled with -1. This token-level differentiation is crucial for maintaining the consistency of multi-sample processing in downstream attention-based mechanisms. + +Additionally, text data inherently includes a mask indicating positions where padding was introduced during tokenization. These tokens are also labeled with -1. For simplicity, we omit this in the illustration. +![](https://intranetproxy.alipay.com/skylark/lark/0/2024/jpeg/124356420/1718863469086-00c57918-7f8c-45ce-9c8c-ac3a56a040c2.jpeg) +### Attention mask +#### self-attention mask +The query, key, and value (Q, K, V) are all derived from the respective video in self-attention. Our attention mechanism is configured to compute exclusively the tokens belonging to the same video, while effectively filtering out padding tokens. We utilize pre-saved token labels to generate the mask. +![](https://intranetproxy.alipay.com/skylark/lark/0/2024/jpeg/124356420/1716894781927-fb25ea3f-93b6-42c0-b565-364e6ba5364c.jpeg) +#### cross-attention mask +In cross-attention, The query (Q) is generated from the video, while the key and value (K, V) are derived from the text. Our cross-attention mechanism is designed to compute attention only for the text and video tokens in the same video, excluding any padding tokens. To generate the cross-attention mask, we concurrently utilize the pre-saved labels for both video tokens and text tokens. +![](https://intranetproxy.alipay.com/skylark/lark/0/2024/jpeg/124356420/1716894782010-82000b1d-9079-4e36-ac11-9302934524ca.jpeg) +### Loss +The predicted noise from the model needs to be compared with the actual noise to compute the loss. One approach involves unpacking the noise sequence the model outputs back into the original video format according to the saved packing pattern, and then computing the loss with respect to the ground truth noise. However, due to the variable resolution of the original videos, it is not feasible to batch process them for loss computation in a parallel manner. + +Therefore, we alternatively pack the ground truth noise using the same packing pattern and compute the loss with the model's output sequence. Since the packed sequences have a fixed length, this allows for parallelized acceleration of the loss computation. +## Numerical Consistency +We tested the consistency between the outputs of the NaViT model and the original model on the raw dataset provided by Open-Sora-Plan. Both models takes the same input data (with a fixed resolution of 512x512) and identical Position Encoding parameters. The output of the original model was packed according to the same pattern and compared with the output sequences of NaViT. The test script is attached in the code for reproduction. +## Limitation +1. Currently, we use a naive sequential grouping strategy rather than an optimized greedy grouping strategy. +2. Due to the significant workload involved in adapting the KL loss for NaViT, only MSE loss is currently supported. +3. Packing has only been achieved for the spatial dimension, and does not support the input with different durations. +4. Since NaViT requires attention masks to achieve computational isolation between different videos in a sequence, flash attention cannot be used for acceleration. diff --git a/README_original.md b/README_original.md new file mode 100644 index 000000000..e363bfff4 --- /dev/null +++ b/README_original.md @@ -0,0 +1,410 @@ +# Open-Sora Plan + + + +[![slack badge](https://img.shields.io/badge/Discord-join-blueviolet?logo=discord&)](https://discord.gg/YtsBNg7n) +[![WeChat badge](https://img.shields.io/badge/微信-加入-green?logo=wechat&)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/issues/53#issuecomment-1987226516) +[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0) +[![Twitter](https://img.shields.io/badge/-Twitter@LinBin46984-black?logo=twitter&logoColor=1D9BF0)](https://x.com/LinBin46984/status/1795018003345510687)
+[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0) +[![License](https://img.shields.io/badge/License-MIT-yellow)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/LICENSE) +[![GitHub repo contributors](https://img.shields.io/github/contributors-anon/PKU-YuanGroup/Open-Sora-Plan?style=flat&label=Contributors)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/graphs/contributors) +[![GitHub Commit](https://img.shields.io/github/commit-activity/m/PKU-YuanGroup/Open-Sora-Plan?label=Commit)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/commits/main/) +[![Pr](https://img.shields.io/github/issues-pr-closed-raw/PKU-YuanGroup/Open-Sora-Plan.svg?label=Merged+PRs&color=green)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) +[![GitHub issues](https://img.shields.io/github/issues/PKU-YuanGroup/Open-Sora-Plan?color=critical&label=Issues)](https://github.com/PKU-YuanGroup/Video-LLaVA/issues?q=is%3Aopen+is%3Aissue) +[![GitHub closed issues](https://img.shields.io/github/issues-closed/PKU-YuanGroup/Open-Sora-Plan?color=success&label=Issues)](https://github.com/PKU-YuanGroup/Video-LLaVA/issues?q=is%3Aissue+is%3Aclosed)
+[![GitHub repo stars](https://img.shields.io/github/stars/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Stars)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/stargazers)  +[![GitHub repo forks](https://img.shields.io/github/forks/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Forks)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/network)  +[![GitHub repo watchers](https://img.shields.io/github/watchers/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Watchers)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/watchers)  +[![GitHub repo size](https://img.shields.io/github/repo-size/PKU-YuanGroup/Open-Sora-Plan?style=flat&logo=github&logoColor=whitesmoke&label=Repo%20Size)](https://github.com/PKU-YuanGroup/Open-Sora-Plan/archive/refs/heads/main.zip) + +
+v1.0.0 badge +[![Twitter](https://img.shields.io/badge/-Twitter@LinBin46984-black?logo=twitter&logoColor=1D9BF0)](https://x.com/LinBin46984/status/1763476690385424554?s=20)
+[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0) +[![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/fffiloni/Open-Sora-Plan-v1-0-0) +[![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb)
+
+ +We are thrilled to present **Open-Sora-Plan v1.1.0**, which significantly enhances video generation quality and text control capabilities. See our [report](docs/Report-v1.1.0.md). We show compressed .gif on GitHub, which loses some quality. + +Thanks to **HUAWEI Ascend Team** for supporting us. In the second stage, we used Huawei Ascend computing power for training. This stage's training and inference were fully supported by Huawei. Models trained on Huawei Ascend can also be loaded into GPUs and generate videos of the same quality. + +目前已经支持使用国产AI芯片(华为昇腾,期待更多国产算力芯片)进行完整的训练和推理。在项目第二阶段,所有训练和推理任务完全由华为昇腾芯片支持。此外,基于华为昇腾的512卡集群训练出的模型,也可以无缝地在GPU上运行,并保持相同的视频质量。详细信息请参考我们的[hw branch](https://github.com/PKU-YuanGroup/Open-Sora-Plan/tree/hw). + + +### 221×512×512 Text-to-Video Generation + + + +| 221×512×512 (9.2s) | 221×512×512 (9.2s) | 221×512×512 (9.2s) | +| --- | --- | --- | +| | | | +| 3D animation of a small, round, fluffy creature with big, expressive eyes explores ... | A single drop of liquid metal falls from a floating orb, landing on a mirror-like ... | The video presents an abstract composition centered around a hexagonal shape adorned ... | +| | | | +| A drone camera circles around a beautiful historic church built on a rocky outcropping ... | Aerial view of Santorini during the blue hour, showcasing the stunning architecture ... | An aerial shot of a lighthouse standing tall on a rocky cliff, its beacon cutting ... | +| | | | +| A snowy forest landscape with a dirt road running through it. The road is flanked by ... | Drone shot along the Hawaii jungle coastline, sunny day. Kayaks in the water. |The camera rotates around a large stack of vintage televisions all showing different ... | + + +### 65×512×512 Text-to-Video Generation + +| 65×512×512 (2.7s) | 65×512×512 (2.7s) | 65×512×512 (2.7s) | +| --- | --- | --- | +| | | | +| In an ornate, historical hall, a massive tidal wave peaks and begins to crash. Two ... | A Shiba Inu dog wearing a beret and black turtleneck. | A painting of a boat on water comes to life, with waves crashing and the boat becoming ... | +|| | | +| A person clad in a space suit with a helmet and equipped with a chest light and arm ... | 3D animation of a small, round, fluffy creature with big, expressive eyes explores a ... | In a studio, there is a painting depicting a ship sailing through the rough sea. | +| | | | +| A robot dog trots down a deserted alley at night, its metallic paws clinking softly ... | A lone surfer rides a massive wave, skillfully maneuvering through the surf. The water ... | A solitary cheetah sprints across the savannah, its powerful muscles propelling it ... | + +### 65×512×512 Video Editing + +| generated 65×512×512 (2.7s) | edited 65×512×512 (2.7s) | +| --- | --- | +| | | +| | | +| | | + +### 512×512 Text-to-Image Generation + + + + + + +## 📰 News + +**[2024.05.27]** 🚀🚀🚀 We are launching Open-Sora Plan v1.1.0, which significantly improves video quality and length, and is fully open source! Please check out our latest [report](docs/Report-v1.1.0.md). + +**[2024.04.09]** 🚀 Excited to share our latest exploration on metamorphic time-lapse video generation: [MagicTime](https://github.com/PKU-YuanGroup/MagicTime), which learns real-world physics knowledge from time-lapse videos. Here is the dataset for train (updating): [Open-Sora-Dataset](https://github.com/PKU-YuanGroup/Open-Sora-Dataset). + +**[2024.04.07]** 🔥🔥🔥 Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities. See our [report](docs/Report-v1.0.0.md). Thanks to HUAWEI NPU for supporting us. + +**[2024.03.27]** 🚀🚀🚀 We release the report of [VideoCausalVAE](docs/CausalVideoVAE.md), which supports both images and videos. We present our reconstructed video in this demonstration as follows. The text-to-video model is on the way. + +
+View more + +**[2024.03.10]** 🚀🚀🚀 This repo supports training a latent size of 225×90×90 (t×h×w), which means we are able to **train 1 minute of 1080P video with 30FPS** (2× interpolated frames and 2× super resolution) under class-condition. + +**[2024.03.08]** We support the training code of text condition with 16 frames of 512x512. The code is mainly borrowed from [Latte](https://github.com/Vchitect/Latte). + +**[2024.03.07]** We support training with 128 frames (when sample rate = 3, which is about 13 seconds) of 256x256, or 64 frames (which is about 6 seconds) of 512x512. + +**[2024.03.05]** See our latest [todo](https://github.com/PKU-YuanGroup/Open-Sora-Plan?tab=readme-ov-file#todo), pull requests are welcome. + +**[2024.03.04]** We re-organize and modulize our code to make it easy to [contribute](https://github.com/PKU-YuanGroup/Open-Sora-Plan?tab=readme-ov-file#how-to-contribute-to-the-open-sora-plan-community) to the project, to contribute please see the [Repo structure](https://github.com/PKU-YuanGroup/Open-Sora-Plan?tab=readme-ov-file#repo-structure). + +**[2024.03.03]** We open some [discussions](https://github.com/PKU-YuanGroup/Open-Sora-Plan/discussions) to clarify several issues. + +**[2024.03.01]** Training code is available now! Learn more on our [project page](https://pku-yuangroup.github.io/Open-Sora-Plan/). Please feel free to watch 👀 this repository for the latest updates. + +
+ +## 💪 Goal +This project aims to create a simple and scalable repo, to reproduce [Sora](https://openai.com/sora) (OpenAI, but we prefer to call it "ClosedAI" ). We wish the open-source community can contribute to this project. Pull requests are welcome!!! + +本项目希望通过开源社区的力量复现Sora,由北大-兔展AIGC联合实验室共同发起,当前版本离目标差距仍然较大,仍需持续完善和快速迭代,欢迎Pull request!!! + +Project stages: +- Primary +1. Setup the codebase and train an un-conditional model on a landscape dataset. +2. Train models that boost resolution and duration. + +- Extensions +3. Conduct text2video experiments on landscape dataset. +4. Train the 1080p model on video2text dataset. +5. Control model with more conditions. + + +
+ + +
+ + +
+✊ Todo + +#### Setup the codebase and train an unconditional model on landscape dataset +- [x] Fix typos & Update readme. 🤝 Thanks to [@mio2333](https://github.com/mio2333), [@CreamyLong](https://github.com/CreamyLong), [@chg0901](https://github.com/chg0901), [@Nyx-177](https://github.com/Nyx-177), [@HowardLi1984](https://github.com/HowardLi1984), [@sennnnn](https://github.com/sennnnn), [@Jason-fan20](https://github.com/Jason-fan20) +- [x] Setup environment. 🤝 Thanks to [@nameless1117](https://github.com/nameless1117) +- [ ] Add docker file. ⌛ [WIP] 🤝 Thanks to [@Mon-ius](https://github.com/Mon-ius), [@SimonLeeGit](https://github.com/SimonLeeGit) +- [ ] Enable type hints for functions. 🤝 Thanks to [@RuslanPeresy](https://github.com/RuslanPeresy), 🙏 **[Need your contribution]** +- [x] Resume from checkpoint. +- [x] Add Video-VQVAE model, which is borrowed from [VideoGPT](https://github.com/wilson1yan/VideoGPT). +- [x] Support variable aspect ratios, resolutions, durations training on [DiT](https://github.com/facebookresearch/DiT). +- [x] Support Dynamic mask input inspired by [FiT](https://github.com/whlzy/FiT). +- [x] Add class-conditioning on embeddings. +- [x] Incorporating [Latte](https://github.com/Vchitect/Latte) as main codebase. +- [x] Add VAE model, which is borrowed from [Stable Diffusion](https://github.com/CompVis/latent-diffusion). +- [x] Joint dynamic mask input with VAE. +- [ ] Add VQVAE from [VQGAN](https://github.com/CompVis/taming-transformers). 🙏 **[Need your contribution]** +- [ ] Make the codebase ready for the cluster training. Add SLURM scripts. 🙏 **[Need your contribution]** +- [x] Refactor VideoGPT. 🤝 Thanks to [@qqingzheng](https://github.com/qqingzheng), [@luo3300612](https://github.com/luo3300612), [@sennnnn](https://github.com/sennnnn) +- [x] Add sampling script. +- [ ] Add DDP sampling script. ⌛ [WIP] +- [x] Use accelerate on multi-node. 🤝 Thanks to [@sysuyy](https://github.com/sysuyy) +- [x] Incorporate [SiT](https://github.com/willisma/SiT). 🤝 Thanks to [@khan-yin](https://github.com/khan-yin) +- [x] Add evaluation scripts (FVD, CLIP score). 🤝 Thanks to [@rain305f](https://github.com/rain305f) + +#### Train models that boost resolution and duration +- [x] Add [PI](https://arxiv.org/abs/2306.15595) to support out-of-domain size. 🤝 Thanks to [@jpthu17](https://github.com/jpthu17) +- [x] Add 2D RoPE to improve generalization ability as [FiT](https://github.com/whlzy/FiT). 🤝 Thanks to [@jpthu17](https://github.com/jpthu17) +- [x] Compress KV according to [PixArt-sigma](https://pixart-alpha.github.io/PixArt-sigma-project). +- [x] Support deepspeed for videogpt training. 🤝 Thanks to [@sennnnn](https://github.com/sennnnn) +- [x] Train a **low dimension** Video-AE, whether it is VAE or VQVAE. +- [x] Extract offline feature. +- [x] Train with offline feature. +- [x] Add frame interpolation model. 🤝 Thanks to [@yunyangge](https://github.com/yunyangge) +- [x] Add super resolution model. 🤝 Thanks to [@Linzy19](https://github.com/Linzy19) +- [x] Add accelerate to automatically manage training. +- [x] Joint training with images. +- [ ] Implement [MaskDiT](https://github.com/Anima-Lab/MaskDiT) technique for fast training. 🙏 **[Need your contribution]** +- [ ] Incorporate [NaViT](https://arxiv.org/abs/2307.06304). 🙏 **[Need your contribution]** +- [ ] Add [FreeNoise](https://github.com/arthur-qiu/FreeNoise-LaVie) support for training-free longer video generation. 🙏 **[Need your contribution]** + +#### Conduct text2video experiments on landscape dataset. +- [x] Load pretrained weights from [Latte](https://github.com/Vchitect/Latte). +- [ ] Implement [PeRFlow](https://github.com/magic-research/piecewise-rectified-flow) for improving the sampling process. 🙏 **[Need your contribution]** +- [x] Finish data loading, pre-processing utils. +- [x] Add T5 support. +- [x] Add CLIP support. 🤝 Thanks to [@Ytimed2020](https://github.com/Ytimed2020) +- [x] Add text2image training script. +- [ ] Add prompt captioner. + - [ ] Collect training data. + - [ ] Need video-text pairs with caption. 🙏 **[Need your contribution]** + - [ ] Extract multi-frame descriptions by large image-language models. 🤝 Thanks to [@HowardLi1984](https://github.com/HowardLi1984) + - [ ] Extract video description by large video-language models. 🙏 **[Need your contribution]** + - [ ] Integrate captions to get a dense caption by using a large language model, such as GPT-4. 🤝 Thanks to [@HowardLi1984](https://github.com/HowardLi1984) + - [ ] Train a captioner to refine captions. 🚀 **[Require more computation]** + +#### Train the 1080p model on video2text dataset +- [ ] Looking for a suitable dataset, welcome to discuss and recommend. 🙏 **[Need your contribution]** +- [ ] Add synthetic video created by game engines or 3D representations. 🙏 **[Need your contribution]** +- [x] Finish data loading, and pre-processing utils. +- [x] Support memory friendly training. + - [x] Add flash-attention2 from pytorch. + - [x] Add xformers. 🤝 Thanks to [@jialin-zhao](https://github.com/jialin-zhao) + - [x] Support mixed precision training. + - [x] Add gradient checkpoint. + - [x] Support for ReBased and Ring attention. 🤝 Thanks to [@kabachuha](https://github.com/kabachuha) + - [x] Train using the deepspeed engine. 🤝 Thanks to [@sennnnn](https://github.com/sennnnn) +- [ ] Train with a text condition. Here we could conduct different experiments: 🚀 **[Require more computation]** + - [x] Train with T5 conditioning. + - [ ] Train with CLIP conditioning. + - [ ] Train with CLIP + T5 conditioning (probably costly during training and experiments). +- [ ] Support Chinese. ⌛ [WIP] + +#### Control model with more condition +- [ ] Incorporating [ControlNet](https://github.com/lllyasviel/ControlNet). ⌛ [WIP] 🙏 **[Need your contribution]** +- [ ] Incorporating [ReVideo](https://github.com/MC-E/ReVideo). ⌛ [WIP] + +
+ +## 📂 Repo structure (WIP) +``` +├── README.md +├── docs +│ ├── Data.md -> Datasets description. +│ ├── Contribution_Guidelines.md -> Contribution guidelines description. +├── scripts -> All scripts. +├── opensora +│   ├── dataset +│   ├── models +│   │   ├── ae -> Compress videos to latents +│   │   │   ├── imagebase +│   │   │   │   ├── vae +│   │   │   │   └── vqvae +│   │   │   └── videobase +│   │   │   ├── vae +│   │   │   └── vqvae +│   │   ├── captioner +│   │   ├── diffusion -> Denoise latents +│   │   │   ├── diffusion +│   │   │   ├── dit +│   │   │   ├── latte +│   │   │   └── unet +│   │   ├── frame_interpolation +│   │   ├── super_resolution +│   │   └── text_encoder +│   ├── sample +│   ├── train -> Training code +│   └── utils +``` + +## 🛠️ Requirements and Installation + +1. Clone this repository and navigate to Open-Sora-Plan folder +``` +git clone https://github.com/PKU-YuanGroup/Open-Sora-Plan +cd Open-Sora-Plan +``` +2. Install required packages +``` +conda create -n opensora python=3.8 -y +conda activate opensora +pip install -e . +``` +3. Install additional packages for training cases +``` +pip install -e ".[train]" +pip install flash-attn --no-build-isolation +``` +4. Install optional requirements such as static type checking: +``` +pip install -e '.[dev]' +``` + +## 🗝️ Usage + + +### 🤗 Demo + +#### Gradio Web UI + +Highly recommend trying out our web demo by the following command. We also provide [online demo](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0) [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.1.0). + +
+v1.0.0 + +Highly recommend trying out our web demo by the following command. We also provide [online demo](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0) [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0) and [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/fffiloni/Open-Sora-Plan-v1-0-0) in Huggingface Spaces. + +🤝 Enjoying the [![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb), created by [@camenduru](https://github.com/camenduru), who generously supports our research! + +
+ +```bash +python -m opensora.serve.gradio_web_server +``` + +#### CLI Inference + +```bash +sh scripts/text_condition/sample_video.sh +``` + +### Datasets +Refer to [Data.md](docs/Data.md) + +### Evaluation +Refer to the document [EVAL.md](docs/EVAL.md). + +### CausalVideoVAE + +#### Reconstructing + +Example: + +```Python +python examples/rec_imvi_vae.py --video_path test_video.mp4 --rec_path output_video.mp4 --fps 24 --resolution 512 --crop_size 512 --num_frames 128 --sample_rate 1 --ae CausalVAEModel_4x8x8 --model_path pretrained_488_release --enable_tiling --enable_time_chunk +``` + +Parameter explanation: + +- `--enable_tiling`: This parameter is a flag to enable a tiling conv. + +#### Training and Eval + +Please refer to the document [CausalVideoVAE](docs/Train_And_Eval_CausalVideoVAE.md). + +### VideoGPT VQVAE + +Please refer to the document [VQVAE](docs/VQVAE.md). + +### Video Diffusion Transformer + +#### Training +``` +sh scripts/text_condition/train_videoae_65x512x512.sh +``` +``` +sh scripts/text_condition/train_videoae_221x512x512.sh +``` +``` +sh scripts/text_condition/train_videoae_513x512x512.sh +``` + + + +## 💡 How to Contribute to the Open-Sora Plan Community +We greatly appreciate your contributions to the Open-Sora Plan open-source community and helping us make it even better than it is now! + +For more details, please refer to the [Contribution Guidelines](docs/Contribution_Guidelines.md) + + + + +## 👍 Acknowledgement +* [Latte](https://github.com/Vchitect/Latte): The **main codebase** we built upon and it is an wonderful video generated model. +* [PixArt-alpha](https://github.com/PixArt-alpha/PixArt-alpha): Fast Training of Diffusion Transformer for Photorealistic Text-to-Image Synthesis. +* [ShareGPT4Video](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4Video): Improving Video Understanding and Generation with Better Captions. +* [VideoGPT](https://github.com/wilson1yan/VideoGPT): Video Generation using VQ-VAE and Transformers. +* [DiT](https://github.com/facebookresearch/DiT): Scalable Diffusion Models with Transformers. +* [FiT](https://github.com/whlzy/FiT): Flexible Vision Transformer for Diffusion Model. +* [Positional Interpolation](https://arxiv.org/abs/2306.15595): Extending Context Window of Large Language Models via Positional Interpolation. + + +## 🔒 License +* See [LICENSE](LICENSE) for details. + + + + +## ✏️ Citing + +### BibTeX + +```bibtex +@software{pku_yuan_lab_and_tuzhan_ai_etc_2024_10948109, + author = {PKU-Yuan Lab and Tuzhan AI etc.}, + title = {Open-Sora-Plan}, + month = apr, + year = 2024, + publisher = {GitHub}, + doi = {10.5281/zenodo.10948109}, + url = {https://doi.org/10.5281/zenodo.10948109} +} +``` +### Latest DOI + +[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.10948109.svg)](https://zenodo.org/records/10948109) + +## 🤝 Community contributors + + + + diff --git a/opensora/dataset/__init__.py b/opensora/dataset/__init__.py index cb0d0f041..169f9d294 100644 --- a/opensora/dataset/__init__.py +++ b/opensora/dataset/__init__.py @@ -6,7 +6,7 @@ from torchvision.transforms import Lambda from .t2v_datasets import T2V_dataset -from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo +from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo, RandomResize ae_norm = { @@ -62,4 +62,12 @@ def getdataset(args): ]) tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir) return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer) + elif args.dataset == 't2v_navit': + transform = transforms.Compose([ + ToTensorVideo(), + RandomResize(), + norm_fun + ]) + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir) + return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer) raise NotImplementedError(args.dataset) diff --git a/opensora/dataset/transform.py b/opensora/dataset/transform.py index bb89c2c85..e72f9a9f9 100644 --- a/opensora/dataset/transform.py +++ b/opensora/dataset/transform.py @@ -477,6 +477,37 @@ def __call__(self, clip): def __repr__(self) -> str: return f"{self.__class__.__name__}(p={self.p})" + +class RandomResize: + """ + Resize the video randomly in (64, orig_size) + """ + def __init__( + self, + interpolation_mode="bilinear", + ): + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be resized. Size is (T, C, H, W) + Returns: + torch.tensor: Resized video clip. + size is (T, C, new_h, img_w) + """ + img_h,img_w = clip.shape[-2:] + if img_h > 64: + new_h = random.randint(64, img_h) + if img_w > 64: + new_w = random.randint(64, img_w) + + clip_resized = resize(clip, target_size=(new_h,new_w), + interpolation_mode=self.interpolation_mode) + return clip_resized + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(patch_size={self.patch_size}, interpolation_mode={self.interpolation_mode}" # ------------------------------------------------------------ diff --git a/opensora/models/ae/videobase/modules/quant.py b/opensora/models/ae/videobase/modules/quant.py index bb702cee5..dd561371a 100644 --- a/opensora/models/ae/videobase/modules/quant.py +++ b/opensora/models/ae/videobase/modules/quant.py @@ -1,3 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + import torch import torch.nn as nn import torch.distributed as dist diff --git a/opensora/models/diffusion/__init__.py b/opensora/models/diffusion/__init__.py index bf3abffa9..40520589e 100644 --- a/opensora/models/diffusion/__init__.py +++ b/opensora/models/diffusion/__init__.py @@ -1,7 +1,9 @@ from .latte.modeling_latte import Latte_models +from .latte.modeling_latte_navit import Latte_navit_models Diffusion_models = {} Diffusion_models.update(Latte_models) +Diffusion_models.update(Latte_navit_models) \ No newline at end of file diff --git a/opensora/models/diffusion/diffusion/__init__.py b/opensora/models/diffusion/diffusion/__init__.py index 04b2bd3d8..afc0a07b1 100644 --- a/opensora/models/diffusion/diffusion/__init__.py +++ b/opensora/models/diffusion/diffusion/__init__.py @@ -3,7 +3,7 @@ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusion_T +from .respace import SpacedDiffusion, space_timesteps, SpacedDiffusion_T, NaViTSpacedDiffusion_T def create_diffusion( @@ -85,3 +85,42 @@ def create_diffusion_T( loss_type=loss_type # rescale_timesteps=rescale_timesteps, ) + +def create_diffusion_navit( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + learn_sigma=False, # NaViT only supports learn_sigma=False + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + from . import gaussian_diffusion_t2v as gd + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return NaViTSpacedDiffusion_T( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) \ No newline at end of file diff --git a/opensora/models/diffusion/diffusion/respace.py b/opensora/models/diffusion/diffusion/respace.py index aed6ed77f..f490f44de 100644 --- a/opensora/models/diffusion/diffusion/respace.py +++ b/opensora/models/diffusion/diffusion/respace.py @@ -7,7 +7,17 @@ import torch as th from .gaussian_diffusion import GaussianDiffusion -from .gaussian_diffusion_t2v import GaussianDiffusion_T + +from opensora.models.diffusion.latte.modeling_latte_navit import pack_target_as +from .gaussian_diffusion_t2v import ( + GaussianDiffusion_T, + LossType, + ModelMeanType, + ModelVarType, + discretized_gaussian_log_likelihood, + mean_flat, + normal_kl, +) def space_timesteps(num_timesteps, section_counts): @@ -195,4 +205,55 @@ def __call__(self, x, ts, **kwargs): new_ts = map_tensor[ts] # if self.rescale_timesteps: # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) - return self.model(x, new_ts, **kwargs) \ No newline at end of file + return self.model(x, new_ts, **kwargs) + + +class NaViTSpacedDiffusion_T(SpacedDiffusion_T): + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: List[C x ...] for videos of different resolution. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + model = self._wrap_model(model) + if model_kwargs is None: + model_kwargs = {} + # For video of different resolution case like NaViT training. + assert isinstance(x_start, list) + noise = list(map(th.randn_like, x_start)) + x_t = list(map(self.q_sample, *(x_start, t, noise))) + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + raise NotImplementedError("NaViT only supports `loss_type` == MSE") + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + # [b, F, T, p*p*c] + model_output, video_ids, token_kept_ids = model(x_t, t, **model_kwargs) + _, _, _, out_dim = model_output.shape + model_output, model_var_values = th.split(model_output, out_dim // 2, dim=-1) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + # TODO: support vb loss + raise NotImplementedError("NaViT only supports fixed `model_var_type`") + + + assert self.model_mean_type == ModelMeanType.EPSILON, "NaViT only supports `ModelMeanType.EPSILON`" + # [b, F, T, p*p*c] + target = pack_target_as(noise, video_ids, model.model.patch_size, token_kept_ids).to(model_output.dtype) + + assert model_output.shape == target.shape, f"{model_output.shape}, {target.shape}" + + terms["loss"] = torch.nn.functional.mse_loss(target, model_output) + else: + raise NotImplementedError(self.loss_type) + + return terms diff --git a/opensora/models/diffusion/latte/modeling_latte_navit.py b/opensora/models/diffusion/latte/modeling_latte_navit.py new file mode 100644 index 000000000..13ed309cd --- /dev/null +++ b/opensora/models/diffusion/latte/modeling_latte_navit.py @@ -0,0 +1,1053 @@ +import torch + +import os +import json + +from dataclasses import dataclass +from einops import rearrange, repeat +from diffusers.utils import USE_PEFT_BACKEND, deprecate +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from opensora.models.diffusion.utils.pos_embed import get_1d_sincos_pos_embed, get_2d_sincos_pos_embed, PositionGetter1D, PositionGetter2D +from opensora.models.diffusion.latte.modules import BasicTransformerBlock, BasicTransformerBlock_, AdaLayerNormSingle, CaptionProjection + +def video_grouping( + videos, patch_size, max_token_lim=4096, token_dropout_rate=0.0 +): + """ + Use greedy algorithms to group videos into groups with num_token less than max_token_lim + + Args: + videos: List of tensor with shape (F, C, H, W), F is the number of frames, + C is the number of channels, and H and W are the height and width respectively. + + Returns: + groups: A list of lists containing torch.Tensors, each having shape (F, C, H, W). + """ + + groups, video_ids = [[]], [[]] + + # greedy algorithm is a bit complex to implement so we use the naive + # sequential grouping for now. + # TODO: implement greedy algorithm + seq_len = 0 + for idx, video in enumerate(videos): + assert isinstance(video, torch.Tensor), "video must be torch.Tensor" + assert video.ndim == 4, "video must be 4d tensor" + + video_h, video_w = video.shape[-2:] + assert (video_w % patch_size) == 0 and ( + video_h % patch_size + ) == 0, f"video width and height must be divisible by patch size {patch_size}" + patch_w, patch_h = video_w // patch_size, video_h // patch_size + + token_len = int(patch_w * patch_h * (1 - token_dropout_rate)) + assert ( + token_len <= max_token_lim + ), f"token length {token_len} exceeds max token length {max_token_lim}" + if seq_len + token_len <= max_token_lim: + groups[-1].append(video) + video_ids[-1].append(idx) + seq_len += token_len + else: + groups.append([video]) + video_ids.append([idx]) + seq_len = token_len + + return groups, video_ids + + +def pack_timestep_as(timestep, video_ids, num_patches): + # timestep: (B, D) + batched_output = [] + for group, num_patch in zip(video_ids, num_patches): + output = torch.empty( + (0,), + device=timestep.device, + dtype=timestep.dtype, + ) + for sample_idx, t in zip(group, num_patch): + # sample: (D,) + sample = timestep[sample_idx] + # (T, D) + sample = repeat(sample.unsqueeze(0), "1 n -> (1 t) n", t=t) + output = torch.cat([output, sample], dim=0) + batched_output.append(output) + + # (b, T, D) + batched_output = nn.utils.rnn.pad_sequence( + batched_output, batch_first=True + ) + return batched_output + + +def pack_text_as(encoder_hidden_states, encoder_attention_mask, video_ids): + # encoder_hidden_states: (B, L, D) + # encoder_attention_mask: (B, L) + encoder_attention_mask = encoder_attention_mask.bool() + batched_output = [] + batched_idx = [] + for group in video_ids: + output = torch.empty( + (0,), + device=encoder_hidden_states.device, + dtype=encoder_hidden_states.dtype, + ) + text_idx = torch.empty( + (0,), + device=encoder_hidden_states.device, + dtype=torch.long, + ) + for idx, sample_idx in enumerate(group): + # (L, D) + sample = encoder_hidden_states[sample_idx] + # (L,) + padding_mask = encoder_attention_mask[sample_idx] + # (l, D) l <= L + # discard padded tokens before packing + non_padding_sample = sample[padding_mask] + output = torch.cat([output, non_padding_sample], dim=0) + text_idx = torch.cat( + ( + text_idx, + torch.full( + (non_padding_sample.shape[0],), + idx, + device=encoder_hidden_states.device, + dtype=torch.long, + ), # (l,) + ) + ) + batched_output.append(output) + batched_idx.append(text_idx) + + # (b, L', D) L' = max(sum(l)) + batched_output = nn.utils.rnn.pad_sequence( + batched_output, batch_first=True + ) + # (b, L') + batched_idx = nn.utils.rnn.pad_sequence( + batched_idx, batch_first=True, padding_value=-2 + ) + return batched_output, batched_idx + +def pack_image_joint_text_as(encoder_hidden_states, encoder_attention_mask, video_ids): + # encoder_hidden_states: (B, F, L, D) + # encoder_attention_mask: (B, F, L) + # 0: valid, -1: padding + encoder_attention_mask = encoder_attention_mask.bool() + batched_output = [] + batched_idx = [] + frame = encoder_hidden_states.shape[1] + for group in video_ids: + output = torch.empty( + (0,), + device=encoder_hidden_states.device, + dtype=encoder_hidden_states.dtype, + ) + text_idx = torch.empty( + (0,), + device=encoder_hidden_states.device, + dtype=torch.long, + ) + for idx, sample_idx in enumerate(group): + # (F, L, D) + sample = encoder_hidden_states[sample_idx] + # (F, L) + padding_mask = ~encoder_attention_mask[sample_idx] + # (F, L,) + cur_text_idx = torch.full( + tuple(padding_mask.shape), + idx, + device=encoder_hidden_states.device, + dtype=torch.long, + ).masked_fill(padding_mask, -2) + output = torch.cat([output, sample], dim=1) + text_idx = torch.cat( + ( + text_idx, + cur_text_idx, + ), + dim=1, + ) + batched_output.append(rearrange(output, "f l d -> l (f d)")) + batched_idx.append(text_idx.transpose(-1,-2)) + + # # (b, F, L', D) + # batched_output = torch.stack(batched_output) + # # (b, F, L') + # batched_idx = torch.stack(batched_idx) + + # (b, F, L', D) + batched_output = nn.utils.rnn.pad_sequence( + batched_output, batch_first=True + ) + batched_output = rearrange(batched_output, "b l (f d) -> b f l d", f=frame) + # (b, F, L') + batched_idx = nn.utils.rnn.pad_sequence( + batched_idx, batch_first=True, padding_value=-2 + ) + batched_idx = batched_idx.transpose(-1,-2) + return batched_output, batched_idx + + +def pack_target_as(targets, video_ids, patch_size, token_kept_ids=None): + # targets: List of (C, F, H, W) + device, dtype = targets[0].device, targets[0].dtype + num_frame = targets[0].shape[1] + batched_output = [] + for group_id, group in enumerate(video_ids): + output = torch.empty( + (0,), + device=device, + dtype=dtype, + ) + for idx, sample_idx in enumerate(group): + # (C, F, H, W) + target = targets[sample_idx] + height, width = target.shape[-2:] + num_patch_h, num_patch_w = ( + height // patch_size, + width // patch_size, + ) + # (T, F*p*p*c) + target = rearrange( + target, + "c f (h p) (w q)-> (h w) (f p q c)", + h=num_patch_h, + w=num_patch_w, + p=patch_size, + q=patch_size, + ) + if token_kept_ids is not None: + target = target[token_kept_ids[group_id][idx]] + output = torch.cat([output, target], dim=0) + batched_output.append(output) + + # (b, T, F*p*p*c) + batched_output = nn.utils.rnn.pad_sequence( + batched_output, batch_first=True + ) + # (b, F, T, D) + batched_output = rearrange( + batched_output, "b t (f d) -> b f t d", f=num_frame + ) + return batched_output + + +def mask_to_bias(attention_mask, dtype): + """ + convert attention_mask to a bias + expects mask of shape: + [batch, query_tokens, key_tokens] + this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + assume that mask is expressed as: + (True = keep, False = discard) + convert mask into a bias that can be added to attention scores: + (keep = +0, discard = -10000.0) + """ + return (1 - attention_mask.to(dtype)) * -10000.0 + + +class ExamplePacking(nn.Module): + """3D video to Patch Embedding""" + + def __init__( + self, + base_height=512, # 4096 -> vae x 8 down -> 512 + base_width=512, + patch_size=2, + in_channels=3, + embed_dim=768, + layer_norm=False, + bias=True, + interpolation_scale=1, + max_token_lim=1024, + token_dropout_rate=0.0, + ): + super().__init__() + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, + embed_dim, + kernel_size=(patch_size, patch_size), + stride=patch_size, + bias=bias, + ) + if layer_norm: + self.norm = nn.LayerNorm( + embed_dim, elementwise_affine=False, eps=1e-6 + ) + else: + self.norm = None + + self.patch_size = patch_size + self.embed_dim = embed_dim + # See: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 + self.base_height, self.base_width = (base_height, base_width) + self.base_num_patches = (base_height // patch_size) * ( + base_width // patch_size + ) + self.base_size = base_height // patch_size + self.interpolation_scale = interpolation_scale + self.max_token_lim = max_token_lim + self.token_dropout_rate = token_dropout_rate + pos_embed = get_2d_sincos_pos_embed( + embed_dim, + int(self.base_num_patches**0.5), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + self.register_buffer( + "pos_embed", + torch.from_numpy(pos_embed), + persistent=False, + ) + + def forward(self, latent, dtype): + """ + Pack latent from different batches to one sequence and proj to patch embed. + Args: + latent: List of list of tensor with shape (F, C, H, W), where F is the number of frames, + C is the number of channels, and H and W are the height and width respectively. + + Returns: + output: Output tensor with shape (b, F, T, D) + """ + + num_frame = latent[0][0].shape[0] + device = latent[0][0].device + video_groups, video_ids = video_grouping( + latent, + self.patch_size, + max_token_lim=self.max_token_lim, + token_dropout_rate=self.token_dropout_rate, + ) + + batched_video = [] + batched_pos = [] + batched_idx = [] + + batched_len = [] + max_len = 0 + # group_size = [] # number of videos of each group + num_patches = [] # number of patches of each video in each seq + token_kept_ids = [] # token left after random dropping + + for group in video_groups: + # group_size.append(len(group)) + num_patches.append([]) + token_kept_ids.append([]) + video_seq = torch.empty( + (num_frame, 0, self.embed_dim), + device=device, + dtype=dtype, + ) + video_pos = torch.empty( + (0, self.embed_dim), device=device, dtype=dtype + ) + video_idx = torch.empty((0,), device=device, dtype=torch.long) + + for idx, video in enumerate(group): + # (F, C, H, W) + video = video.to(dtype) + video = rearrange(video, "c f h w -> f c h w").contiguous() + + height, width = video.shape[-2:] + num_patch_h, num_patch_w = ( + height // self.patch_size, + width // self.patch_size, + ) + + # (F, C, H, W) -> (F, D, H//P, W//P) + seq = self.proj(video) + # (F, D, H//P, W//P) -> (F, T, D) + seq = seq.flatten(2).transpose(1, 2) + + # Interpolate positional embeddings if needed. + # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) + if height != self.base_height or width != self.base_width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.embed_dim, + grid_size=(num_patch_h, num_patch_w), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) # (T, D) + pos_embed = torch.from_numpy(pos_embed).to(device) + + else: + pos_embed = self.pos_embed + + assert ( + pos_embed.shape[0] == seq.shape[1] + ), "pos_embed and sequence token length mismatch" + + if self.token_dropout_rate > 0: + selected_len = int( + seq.shape[1] * (1 - self.token_dropout_rate) + ) + select_indices = torch.randperm( + seq.shape[1], device=device + )[:selected_len] + seq = seq[:, select_indices] + pos_embed = pos_embed[select_indices] + token_kept_ids[-1].append(select_indices) + + num_patches[-1].append(seq.shape[1]) + + video_seq = torch.cat([video_seq, seq], dim=1) + video_pos = torch.cat([video_pos, pos_embed], dim=0) + video_idx = torch.cat( + ( + video_idx, + torch.full( + (seq.shape[1],), + idx, + device=device, + dtype=torch.long, + ), # (T,) + ) + ) + batched_video.append(rearrange(video_seq, "f t d -> t (f d)")) + batched_pos.append(video_pos) + batched_idx.append(video_idx) + # [t1, t2, t3, ...] + batched_len.append(video_seq.shape[1]) + if video_seq.shape[1] > max_len: + max_len = video_seq.shape[1] + + # (b, T, (F * D)) + batched_video = nn.utils.rnn.pad_sequence( + batched_video, batch_first=True + ) + # (b, F, T, D) + batched_video = rearrange( + batched_video, "b t (f d) -> b f t d", f=num_frame + ) + + # (b, T, D) + batched_pos = nn.utils.rnn.pad_sequence(batched_pos, batch_first=True) + # (b, 1, T, D) + batched_pos = batched_pos.unsqueeze(1) + + # (b, T) + batched_idx = nn.utils.rnn.pad_sequence( + batched_idx, batch_first=True, padding_value=-1 + ) + + if self.layer_norm: + batched_video = self.norm(batched_video) + + if not self.token_dropout_rate > 0: + token_kept_ids = None + + return ( + (batched_video + batched_pos).to(dtype), + batched_idx, + video_ids, + num_patches, + token_kept_ids, + ) + +class NaViTLatteT2V(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + patch_size_t: int = 1, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + video_length: int = 16, + attention_mode: str = 'flash', + use_rope: bool = False, + rope_scaling_type: str = 'linear', + compress_kv_factor: int = 1, + interpolation_scale_1d: float = None, + max_token_lim: int = 1024, + token_dropout_rate: float = 0.0, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.video_length = video_length + self.use_rope = use_rope + self.compress_kv_factor = compress_kv_factor + self.num_layers = num_layers + + assert self.compress_kv_factor == 1 and not use_rope, "NaViT currently does not support compressing kv or using rope" + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + # self.is_input_patches = in_channels is not None and patch_size is not None + self.is_input_patches = True + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + # 2. Define input layers + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size[0] + self.width = sample_size[1] + + self.patch_size = patch_size + interpolation_scale_2d = self.config.sample_size[0] // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale_2d = max(interpolation_scale_2d, 1) + self.pos_embed = ExamplePacking( + # position encoding will interpolate to sample_size which is determined by args.max_image_size + base_height=self.height, + base_width=self.width, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale_2d, + max_token_lim=max_token_lim, + token_dropout_rate=token_dropout_rate, + ) + + + # define temporal positional embedding + if interpolation_scale_1d is None: + if self.config.video_length % 2 == 1: + interpolation_scale_1d = (self.config.video_length - 1) // 16 # => 16 (= 16 Latte) has interpolation scale 1 + else: + interpolation_scale_1d = self.config.video_length // 16 # => 16 (= 16 Latte) has interpolation scale 1 + # interpolation_scale_1d = self.config.video_length // 5 # + interpolation_scale_1d = max(interpolation_scale_1d, 1) + temp_pos_embed = get_1d_sincos_pos_embed(inner_dim, video_length, interpolation_scale=interpolation_scale_1d) # 1152 hidden size + self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + + rope_scaling = None + if self.use_rope: + self.position_getter_2d = PositionGetter2D() + self.position_getter_1d = PositionGetter1D() + rope_scaling = dict(type=rope_scaling_type, factor_2d=interpolation_scale_2d, factor_1d=interpolation_scale_1d) + + # 3. Define transformers blocks, spatial attention + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=(compress_kv_factor, compress_kv_factor) if d >= num_layers // 2 and compress_kv_factor != 1 else None, # follow pixart-sigma, apply in second-half layers + ) + for d in range(num_layers) + ] + ) + + # Define temporal transformers blocks + self.temporal_transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock_( # one attention + inner_dim, + num_attention_heads, # num_attention_heads + attention_head_dim, # attention_head_dim 72 + dropout=dropout, + cross_attention_dim=None, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=False, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + attention_mode=attention_mode, + use_rope=use_rope, + rope_scaling=rope_scaling, + compress_kv_factor=(compress_kv_factor, ) if d >= num_layers // 2 and compress_kv_factor != 1 else None, # follow pixart-sigma, apply in second-half layers + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim ** 0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + # self.use_additional_conditions = self.config.sample_size[0] == 128 # False, 128 -> 1024 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def make_position(self, b, t, use_image_num, h, w, device): + pos_hw = self.position_getter_2d(b*(t+use_image_num), h, w, device) # fake_b = b*(t+use_image_num) + pos_t = self.position_getter_1d(b*h*w, t, device) # fake_b = b*h*w + return pos_hw, pos_t + + def make_attn_mask(self, attention_mask, frame, dtype): + attention_mask = rearrange(attention_mask, 'b t h w -> (b t) 1 (h w)') + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(dtype)) * -10000.0 + attention_mask = attention_mask.to(self.dtype) + return attention_mask + + def vae_to_diff_mask(self, attention_mask, use_image_num): + dtype = attention_mask.dtype + # b, t+use_image_num, h, w, assume t as channel + # this version do not use 3d patch embedding + attention_mask = F.max_pool2d(attention_mask, kernel_size=(self.patch_size, self.patch_size), stride=(self.patch_size, self.patch_size)) + attention_mask = attention_mask.bool().to(dtype) + return attention_mask + + def forward( + self, + hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + enable_temporal_attentions: bool = True, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states: List of tensors. B(C, 1+F+num_img, H, W). + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + input_batch_size = len(hidden_states) + # (b, F, T, D) + (hidden_states, batched_idx_video, video_ids, num_patches, token_kept_ids) = ( + self.pos_embed(hidden_states, dtype=self.dtype) + ) # alrady add positional embeddings + + packed_batch_size, frame, seq_len, _ = hidden_states.shape + frame = frame - use_image_num # 20-4=16s + + # (b*F, T, D) + hidden_states = rearrange(hidden_states, "b f t d -> (b f) t d") + + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + + # (B, D * 6), (B, D) + timestep, embedded_timestep = self.adaln_single( + timestep, + added_cond_kwargs, + batch_size=input_batch_size, + hidden_dtype=hidden_states.dtype, + ) + + # (b, T, D * 6) + timestep = pack_timestep_as(timestep, video_ids, num_patches) + embedded_timestep = pack_timestep_as( + embedded_timestep, video_ids, num_patches + ) + assert timestep.shape[1] == hidden_states.shape[1] + assert embedded_timestep.shape[1] == hidden_states.shape[1] + + + # 1 + 4, 1 -> video condition, 4 -> image condition + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint + encoder_attention_mask = repeat(encoder_attention_mask, 'b l -> b f l', f=frame).contiguous() + elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint + encoder_attention_mask_video = encoder_attention_mask[:, :1, ...] + encoder_attention_mask_video = repeat(encoder_attention_mask_video, 'b 1 l -> b (1 f) l', + f=frame).contiguous() + encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...] + encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1) + + # 2. Blocks + if self.caption_projection is not None: + # (B, L, D) or (B, 1+num_image, L, D) + encoder_hidden_states = self.caption_projection(encoder_hidden_states.to(self.dtype)) # 3 120 1152 + + # (B, F, L, D) + if use_image_num != 0 and self.training: + encoder_hidden_states_video = encoder_hidden_states[:, :1, ...] + encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b 1 t d -> b (1 f) t d', f=frame).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...] + encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + else: + encoder_hidden_states_spatial = repeat(encoder_hidden_states, 'b t d -> b f t d', f=frame).contiguous() + + # (b, F, L', D), (b, F, L') + encoder_hidden_states, batched_idx_text = pack_image_joint_text_as( + encoder_hidden_states, encoder_attention_mask, video_ids + ) + encoder_hidden_states_spatial = rearrange(encoder_hidden_states, 'b f t d -> (b f) t d').contiguous() + + # compute self-attn mask + assert attention_mask is None, "NaViT does not support attention_mask!" + # (b, T, T) True for keep and False for discard + attention_mask = batched_idx_video.unsqueeze( + -1 + ) == batched_idx_video.unsqueeze(1) + # (b, T) True for keep and False for discard + padding_mask_1d = batched_idx_video >= 0 + # (b, T, T) + padding_mask = padding_mask_1d.unsqueeze( + -1 + ) & padding_mask_1d.unsqueeze(1) + # (b, T, T) + attention_mask = attention_mask & padding_mask + + # compute cross-attn mask with text condition + # (b, F, T, L) True for keep and False for discard + encoder_attention_mask = batched_idx_video.unsqueeze(1).unsqueeze( + -1 + ) == batched_idx_text.unsqueeze(2) + + assert ( + encoder_attention_mask.shape[-1] + == encoder_hidden_states_spatial.shape[-2] + ) + + # convert bool mask to bias + attention_mask = mask_to_bias(attention_mask, hidden_states.dtype) + encoder_attention_mask = mask_to_bias( + encoder_attention_mask, hidden_states.dtype + ) + + attention_mask = repeat( + attention_mask, "b t l -> (b f) t l", f=frame + use_image_num + ) + encoder_attention_mask = rearrange( + encoder_attention_mask, "b f t l -> (b f) t l" + ) + + # prepare timesteps for spatial and temporal block + timestep_spatial = repeat( + timestep, "b t d -> (b f) t d", f=frame + use_image_num + ).contiguous() + timestep_temp = rearrange( + timestep, "b t d -> (b t) d", t=seq_len + ).contiguous() + + + for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)): + + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + spatial_block, + hidden_states, + attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + if enable_temporal_attentions: + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=packed_batch_size).contiguous() + + if use_image_num != 0: # image-video joitn training + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + # if i == 0 and not self.use_rope: + if i == 0: + hidden_states_video = hidden_states_video + self.temp_pos_embed + + hidden_states_video = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=packed_batch_size).contiguous() + + else: + # if i == 0 and not self.use_rope: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = torch.utils.checkpoint.checkpoint( + temp_block, + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + use_reentrant=False, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=packed_batch_size).contiguous() + else: + hidden_states = spatial_block( + hidden_states, + attention_mask, + encoder_hidden_states_spatial, + encoder_attention_mask, + timestep_spatial, + cross_attention_kwargs, + class_labels, + ) + + if enable_temporal_attentions: + # b c f h w, f = 16 + 4 + hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=packed_batch_size).contiguous() + + if use_image_num != 0 and self.training: + hidden_states_video = hidden_states[:, :frame, ...] + hidden_states_image = hidden_states[:, frame:, ...] + + # if i == 0 and not self.use_rope: + # hidden_states_video = hidden_states_video + self.temp_pos_embed + + hidden_states_video = temp_block( + hidden_states_video, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=packed_batch_size).contiguous() + + else: + # if i == 0 and not self.use_rope: + if i == 0: + hidden_states = hidden_states + self.temp_pos_embed + + hidden_states = temp_block( + hidden_states, + None, # attention_mask + None, # encoder_hidden_states + None, # encoder_attention_mask + timestep_temp, + cross_attention_kwargs, + class_labels, + ) + + hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', + b=packed_batch_size).contiguous() + + + + embedded_timestep = repeat( + embedded_timestep, + "b t d -> (b f) t d", + f=frame + use_image_num, + ).contiguous() + params = ( + self.scale_shift_table[None, None] # [1, 1, 2, D] + + embedded_timestep[:, :, None] # [b*F, T, 1, D] + ).chunk(2, dim=-2) + shift, scale = [param.squeeze(-2) for param in params] + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + + # [b*F, T, p*p*c] + hidden_states = self.proj_out(hidden_states) + # [b, F, T, p*p*c] + hidden_states = rearrange( + hidden_states, "(b f) t d -> b f t d", b=packed_batch_size + ) + # padding_mask_1d: (b, T) True for keep and False for discard + # make sure padded token filled with 0 + # (b, 1, T, 1) + padding_mask = padding_mask_1d[:, None, :, None] + + hidden_states = hidden_states.masked_fill(~padding_mask, 0) + + return hidden_states, video_ids, token_kept_ids + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + + model = cls.from_config(config, **kwargs) + return model + +# depth = num_layers * 2 +def NaViTLatteT2V_XL_122(**kwargs): + return NaViTLatteT2V(num_layers=28, attention_head_dim=72, num_attention_heads=16, patch_size_t=1, patch_size=2, + norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1152, **kwargs) +def NaViTLatteT2V_D64_XL_122(**kwargs): + return NaViTLatteT2V(num_layers=28, attention_head_dim=64, num_attention_heads=18, patch_size_t=1, patch_size=2, + norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1152, **kwargs) + +Latte_navit_models = { + "NaViTLatteT2V-XL/122": NaViTLatteT2V_XL_122, + "NaViTLatteT2V-D64-XL/122": NaViTLatteT2V_D64_XL_122, +} diff --git a/opensora/models/diffusion/latte/modules.py b/opensora/models/diffusion/latte/modules.py index 40ff32e27..1820daf28 100644 --- a/opensora/models/diffusion/latte/modules.py +++ b/opensora/models/diffusion/latte/modules.py @@ -1583,9 +1583,18 @@ def forward( elif self.use_layer_norm: norm_hidden_states = self.norm1(hidden_states) elif self.use_ada_layer_norm_single: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( - self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) - ).chunk(6, dim=1) + # (B, T, 6 * D) when NaViT spatial attention + if timestep.ndim == 3: + # (B, T, 1, D) + params = ( + self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1) + ).chunk(6, dim=-2) + # (B, T, D) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [param.squeeze(-2) for param in params] + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) norm_hidden_states = self.norm1(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa norm_hidden_states = norm_hidden_states.squeeze(1) diff --git a/opensora/train/train_t2v_navit.py b/opensora/train/train_t2v_navit.py new file mode 100644 index 000000000..cfd2bf92e --- /dev/null +++ b/opensora/train/train_t2v_navit.py @@ -0,0 +1,784 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +A minimal training script for DiT using PyTorch DDP. +""" +import argparse +import logging +import math +import os +import shutil +from pathlib import Path +from typing import Optional +import gc +import numpy as np +from einops import rearrange +from tqdm import tqdm +from dataclasses import field, dataclass +from torch.utils.data import DataLoader +from copy import deepcopy +import accelerate +import torch +from torch.nn import functional as F +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import create_repo +from packaging import version +from tqdm.auto import tqdm +from transformers import HfArgumentParser, TrainingArguments, AutoTokenizer + +import diffusers +from diffusers import DDPMScheduler, PNDMScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel, compute_snr +from diffusers.utils import check_min_version, is_wandb_available + +from opensora.dataset import getdataset, ae_denorm +from opensora.models.ae import getae, getae_wrapper +from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper +from opensora.models.diffusion.diffusion import create_diffusion_navit as create_diffusion +from opensora.models.diffusion.latte.modeling_latte import LatteT2V +from opensora.models.text_encoder import get_text_enc, get_text_warpper +from opensora.utils.dataset_utils import NaViTCollate +from opensora.models.ae import ae_stride_config, ae_channel_config +from opensora.models.diffusion import Diffusion_models +from opensora.sample.pipeline_videogen import VideoGenPipeline +from opensora.utils.utils import print_grad_norm + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.24.0") +logger = get_logger(__name__) + + +@torch.inference_mode() +def log_validation(args, model, vae, text_encoder, tokenizer, accelerator, weight_dtype, global_step): + validation_prompt = [ + "A quiet beach at dawn, the waves gently lapping at the shore and the sky painted in pastel hues.", + "The majestic beauty of a waterfall cascading down a cliff into a serene lake." + ] + logger.info(f"Running validation....\n") + model = accelerator.unwrap_model(model) + scheduler = PNDMScheduler() + videogen_pipeline = VideoGenPipeline(vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=model).to(device=accelerator.device) + videos = [] + for prompt in validation_prompt: + logger.info('Processing the ({}) prompt'.format(prompt)) + video = videogen_pipeline(prompt, + num_frames=args.num_frames, + height=args.max_image_size, + width=args.max_image_size, + num_inference_steps=args.num_sampling_steps, + guidance_scale=args.guidance_scale, + enable_temporal_attentions=True, + num_images_per_prompt=1, + mask_feature=True, + ).video + videos.append(video[0]) + # import ipdb;ipdb.set_trace() + gc.collect() + torch.cuda.empty_cache() + videos = torch.stack(videos).numpy() + videos = rearrange(videos, 'b t h w c -> b t c h w') + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_videos = np.stack([np.asarray(vid) for vid in videos]) + tracker.writer.add_video("validation", np_videos, global_step, fps=24) + if tracker.name == "wandb": + import wandb + tracker.log( + { + "validation": [ + wandb.Video(video, caption=f"{i}: {prompt}", fps=24) + for i, (video, prompt) in enumerate(zip(videos, validation_prompt)) + ] + } + ) + + del videogen_pipeline + gc.collect() + torch.cuda.empty_cache() +################################################################################# +# Training Loop # +################################################################################# + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # if args.push_to_hub: + # repo_id = create_repo( + # repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + # ).repo_id + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Create model: + diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule + kwargs = {} + ae = getae_wrapper(args.ae)(args.ae_path, cache_dir=args.cache_dir, **kwargs).eval() + if args.enable_tiling: + ae.vae.enable_tiling() + ae.vae.tile_overlap_factor = args.tile_overlap_factor + + kwargs = {'load_in_8bit': args.enable_8bit_t5, 'torch_dtype': weight_dtype, 'low_cpu_mem_usage': True} + text_enc = get_text_warpper(args.text_encoder_name)(args, **kwargs).eval() + + ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae] + ae.vae_scale_factor = ae_stride_config[args.ae] + assert ae_stride_h == ae_stride_w, f"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})" + args.ae_stride_t, args.ae_stride_h, args.ae_stride_w = ae_stride_t, ae_stride_h, ae_stride_w + args.ae_stride = args.ae_stride_h + patch_size = args.model[-3:] + patch_size_t, patch_size_h, patch_size_w = int(patch_size[0]), int(patch_size[1]), int(patch_size[2]) + args.patch_size = patch_size_h + args.patch_size_t, args.patch_size_h, args.patch_size_w = patch_size_t, patch_size_h, patch_size_w + assert patch_size_h == patch_size_w, f"Support only patch_size_h == patch_size_w now, but found patch_size_h ({patch_size_h}), patch_size_w ({patch_size_w})" + # assert args.num_frames % ae_stride_t == 0, f"Num_frames must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})." + assert args.max_image_size % ae_stride_h == 0, f"Image size must be divisible by ae_stride_h, but found max_image_size ({args.max_image_size}), ae_stride_h ({ae_stride_h})." + + args.stride_t = ae_stride_t * patch_size_t + args.stride = ae_stride_h * patch_size_h + latent_size = (args.max_image_size // ae_stride_h, args.max_image_size // ae_stride_w) + ae.latent_size = latent_size + + if getae_wrapper(args.ae) == CausalVQVAEModelWrapper or getae_wrapper(args.ae) == CausalVAEModelWrapper: + args.video_length = video_length = args.num_frames // ae_stride_t + 1 + else: + video_length = args.num_frames // ae_stride_t + model = Diffusion_models[args.model]( + in_channels=ae_channel_config[args.ae], + out_channels=ae_channel_config[args.ae] * 2, + # caption_channels=4096, + # cross_attention_dim=1152, + attention_bias=True, + sample_size=latent_size, + num_vector_embeds=None, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + use_linear_projection=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + # norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + attention_type='default', + video_length=video_length, + attention_mode=args.attention_mode, + compress_kv_factor=args.compress_kv_factor, + use_rope=args.use_rope, + max_token_lim=args.max_token_lim, + token_dropout_rate=args.token_dropout_rate, + ) + model.gradient_checkpointing = args.gradient_checkpointing + + # # use pretrained model? + if args.pretrained: + if 'safetensors' in args.pretrained: + from safetensors.torch import load_file as safe_load + checkpoint = safe_load(args.pretrained, device="cpu") + else: + checkpoint = torch.load(args.pretrained, map_location='cpu')['model'] + model_state_dict = model.state_dict() + missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + logger.info(f'missing_keys {len(missing_keys)} {missing_keys}, unexpected_keys {len(unexpected_keys)}') + logger.info(f'Successfully load {len(model.state_dict()) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!') + + # Freeze vae and text encoders. + ae.requires_grad_(False) + text_enc.requires_grad_(False) + # Set model as trainable. + model.train() + + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + # ae.to(accelerator.device, dtype=torch.float32) + ae.to(accelerator.device, dtype=weight_dtype) + # ae.to(accelerator.device) + text_enc.to(accelerator.device, dtype=weight_dtype) + # text_enc.to(accelerator.device) + + # Create EMA for the unet. + if args.use_ema: + ema_model = deepcopy(model) + ema_model = EMAModel(ema_model.parameters(), model_cls=LatteT2V, model_config=ema_model.config) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "model_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "model")) + if weights: # Don't pop if empty + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "model_ema"), LatteT2V) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = LatteT2V.from_pretrained(input_dir, subfolder="model") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = model.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Setup data: + train_dataset = getdataset(args) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=NaViTCollate(args), + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers(args.output_dir, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + for step, (x_list, input_ids, cond_mask) in enumerate(train_dataloader): + with accelerator.accumulate(model): + # Sample noise that we'll add to the latents + + # B [C 1+T+num_images H W] + x_list = list(map(lambda x: x.to(accelerator.device, dtype=weight_dtype), x_list)) + attn_mask = None + # assert torch.all(attn_mask != 0), 'attn_mask must all 1' + input_ids = input_ids.to(accelerator.device) # B L or B 1+num_images L + cond_mask = cond_mask.to(accelerator.device) # B L or B 1+num_images L + # print('x.shape, attn_mask.shape, input_ids.shape, cond_mask.shape', x.shape, attn_mask.shape, input_ids.shape, cond_mask.shape) + + with torch.no_grad(): + # use for loop to avoid OOM, because T5 is too huge... + B, _, _ = input_ids.shape # B T+num_images L b 1+4, L + cond = torch.stack([text_enc(input_ids[i], cond_mask[i]) for i in range(B)]) # B 1+num_images L D + + latent_list = [] + # Map input images to latent space + normalize latents + if args.use_image_num == 0: + for x in x_list: + x = ae.encode(x.unsqueeze(0)) # 1 C T H W + latent_list.append(x.squeeze(0)) + else: + for x in x_list: + x = x.unsqueeze(0) # 1 C T H W + videos, images = x[:, :, :-args.use_image_num], x[:, :, -args.use_image_num:] + videos = ae.encode(videos) # 1 C T H W + + images = rearrange(images, 'b c t h w -> (b t) c 1 h w') + images = ae.encode(images) + + images = rearrange(images, '(b t) c 1 h w -> b c t h w', t=args.use_image_num) + x = torch.cat([videos, images], dim=2) # b c 17+4, h, w + latent_list.append(x.squeeze(0)) + + + + # print('(x.shape, attn_mask.shape, cond.shape, cond_mask.shape', x.shape, attn_mask.shape, cond.shape, cond_mask.shape) + model_kwargs = dict(encoder_hidden_states=cond, attention_mask=attn_mask, + encoder_attention_mask=cond_mask, use_image_num=args.use_image_num) + t = torch.randint(0, diffusion.num_timesteps, (len(x_list),), device=accelerator.device) + loss_dict = diffusion.training_losses(model, latent_list, t, model_kwargs) + loss = loss_dict["loss"] + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + + + # accelerator.deepspeed_engine_wrapped.engine.backward(loss) + # print_grad_norm(model) + # accelerator.deepspeed_engine_wrapped.engine.step() + + if accelerator.sync_gradients: + params_to_clip = model.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if args.use_deepspeed or accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_model.store(model.parameters()) + ema_model.copy_to(model.parameters()) + + if args.enable_tracker: + log_validation(args, model, ae, text_enc.text_enc, train_dataset.tokenizer, accelerator, weight_dtype, global_step) + + accelerator.wait_for_everyone() + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--video_data", type=str, required='') + parser.add_argument("--image_data", type=str, default='') + parser.add_argument("--sample_rate", type=int, default=1) + parser.add_argument("--num_frames", type=int, default=17) + parser.add_argument("--max_image_size", type=int, default=512) + parser.add_argument("--use_img_from_vid", action="store_true") + parser.add_argument("--use_image_num", type=int, default=0) + parser.add_argument("--model_max_length", type=int, default=300) + + parser.add_argument('--enable_8bit_t5', action='store_true') + parser.add_argument('--tile_overlap_factor', type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + parser.add_argument("--compress_kv", action="store_true") + parser.add_argument("--attention_mode", type=str, choices=['xformers', 'math', 'flash'], default="xformers") + parser.add_argument('--use_rope', action='store_true') + parser.add_argument('--compress_kv_factor', type=int, default=1) + + parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="Latte-XL/122") + parser.add_argument("--pretrained", type=str, default=None) + parser.add_argument("--ae", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--ae_path", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl') + parser.add_argument("--cache_dir", type=str, default='./cache_dir') + + parser.add_argument("--num_sampling_steps", type=int, default=50) + parser.add_argument('--guidance_scale', type=float, default=5.5) + parser.add_argument("--multi_scale", action="store_true") + parser.add_argument("--enable_tracker", action="store_true") + parser.add_argument("--use_deepspeed", action="store_true") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--timestep_bias_strategy", + type=str, + default="none", + choices=["earlier", "later", "range", "none"], + help=( + "The timestep bias strategy, which may help direct the model toward learning low or high frequency details." + " Choices: ['earlier', 'later', 'range', 'none']." + " The default is 'none', which means no bias is applied, and training proceeds normally." + " The value of 'later' will increase the frequency of the model's final training timesteps." + ), + ) + parser.add_argument( + "--timestep_bias_multiplier", + type=float, + default=1.0, + help=( + "The multiplier for the bias. Defaults to 1.0, which means no bias is applied." + " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it." + ), + ) + parser.add_argument( + "--timestep_bias_begin", + type=int, + default=0, + help=( + "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias." + " Defaults to zero, which equates to having no specific bias." + ), + ) + parser.add_argument( + "--timestep_bias_end", + type=int, + default=1000, + help=( + "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias." + " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on." + ), + ) + parser.add_argument( + "--timestep_bias_portion", + type=float, + default=0.25, + help=( + "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased." + " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines" + " whether the biased portions are in the earlier or later timesteps." + ), + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=10, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--max_token_lim", + type=int, + default=1024, + help="The max token limit of NaViT training.", + ) + parser.add_argument( + "--token_dropout_rate", + type=float, + default=0, + help="The drop rate of token in NaViT training.", + ) + + args = parser.parse_args() + main(args) diff --git a/opensora/utils/dataset_utils.py b/opensora/utils/dataset_utils.py index 1309989b1..67adfed84 100644 --- a/opensora/utils/dataset_utils.py +++ b/opensora/utils/dataset_utils.py @@ -3,7 +3,7 @@ import decord from torch.nn import functional as F import torch - +from opensora.dataset.transform import resize IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG'] @@ -158,3 +158,82 @@ def process(self, batch_tubes, t_ds_stride, ds_stride, max_thw, ae_stride_thw, p # attention_mask = torch.stack(attention_mask) # b t h w return pad_batch_tubes, attention_mask + + + +class NaViTCollate(Collate): + def __init__(self, args): + self.max_image_size = args.max_image_size + self.ae_stride = args.ae_stride + self.ae_stride_t = args.ae_stride_t + self.ae_stride_thw = (self.ae_stride_t, self.ae_stride, self.ae_stride) + self.ae_stride_1hw = (1, self.ae_stride, self.ae_stride) + + self.patch_size = args.patch_size + self.patch_size_t = args.patch_size_t + self.patch_size_thw = (self.patch_size_t, self.patch_size, self.patch_size) + self.patch_size_1hw = (1, self.patch_size, self.patch_size) + + self.num_frames = args.num_frames + self.use_image_num = args.use_image_num + self.max_thw = (self.num_frames, self.max_image_size, self.max_image_size) + self.max_1hw = (1, self.max_image_size, self.max_image_size) + + def process(self, ds_stride, batch_tubes_vid, batch_tubes_img=None): + """ + Resize video to minimum multiple of ds_stride + Args: + batch_tubes_img List[torch.tensor]: Video to be resized. Size of B(C, T, H, W) + batch_tubes_img List[torch.tensor]: Image to be resized. Size is B*num_img(C, 1, H, W) + Returns: + torch.tensor: Resized video batches. + size is B(C, 1+T+num_img, new_h, new_w) + """ + batch_tubes_resized = [] + for i, vid in enumerate(batch_tubes_vid): + vid_h,vid_w = vid.shape[-2:] + + if vid_h%ds_stride==0: + new_h=vid_h + else: + new_h=(vid_h//ds_stride+1)*ds_stride + if vid_w%ds_stride==0: + new_w=vid_w + else: + new_w=(vid_w//ds_stride+1)*ds_stride + + vid_resized = resize(vid, target_size=(new_h,new_w), + interpolation_mode="bilinear") + + vid_resized = F.pad(vid_resized, + (0, 0, + 0, 0, + 0, 1), value=0) + if batch_tubes_img is not None: + imgs = batch_tubes_img[i*self.use_image_num: (i+1)*self.use_image_num] + imgs_resized = [resize(img, target_size=(new_h,new_w), + interpolation_mode="bilinear") for img in imgs] + vid_resized = torch.cat([vid_resized,] + imgs_resized, dim=1) + + batch_tubes_resized.append(vid_resized) + + return batch_tubes_resized + + def __call__(self, batch): + batch_tubes_vid, input_ids_vid, cond_mask_vid, batch_tubes_img, input_ids_img, cond_mask_img = self.package(batch) + + # import ipdb;ipdb.set_trace() + ds_stride = self.ae_stride * self.patch_size + t_ds_stride = self.ae_stride_t * self.patch_size_t + if self.use_image_num == 0: + batch_tubes = self.process(ds_stride, batch_tubes_vid) + # attention_mask: b t h w + input_ids, cond_mask = input_ids_vid.squeeze(1), cond_mask_vid.squeeze(1) # b 1 l -> b l + else: + # B(C, 1+T+num_img, new_h, img_w) + batch_tubes = self.process(ds_stride, batch_tubes_vid, batch_tubes_img) + input_ids = torch.cat([input_ids_vid, input_ids_img], dim=1) # b 1+num_img hw + cond_mask = torch.cat([cond_mask_vid, cond_mask_img], dim=1) # b 1+num_img hw + return batch_tubes, input_ids, cond_mask + + diff --git a/scripts/text_condition/train_videoae_65_navit.sh b/scripts/text_condition/train_videoae_65_navit.sh new file mode 100644 index 000000000..cfa8d0c87 --- /dev/null +++ b/scripts/text_condition/train_videoae_65_navit.sh @@ -0,0 +1,38 @@ +export WANDB_KEY="" +export ENTITY="linbin" +export PROJECT="65x512x512_10node_bs2_lr2e-5_4img" +accelerate launch \ + --config_file scripts/accelerate_configs/deepspeed_zero2_config.yaml \ + opensora/train/train_t2v_navit.py \ + --model NaViTLatteT2V-XL/122 \ + --text_encoder_name DeepFloyd/t5-v1_1-xxl \ + --cache_dir "./cache_dir" \ + --dataset t2v_navit \ + --ae CausalVAEModel_4x8x8 \ + --ae_path "/mnt_zhexin/zhexin.lzx/models/Open-Sora-Plan-v1.1.0/vae/" \ + --video_data "scripts/train_data/video_data_debug.txt" \ + --image_data "scripts/train_data/image_data_debug.txt" \ + --sample_rate 1 \ + --num_frames 65 \ + --max_image_size 512 \ + --gradient_checkpointing \ + --attention_mode xformers \ + --train_batch_size=4 \ + --dataloader_num_workers 8 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="65x512x512_10node_bs2_lr2e-5_4img" \ + --allow_tf32 \ + --use_deepspeed \ + --max_token_lim 4096 \ + --token_dropout_rate 0.75 \ + --use_image_num 4 \ + --enable_tiling \ + --enable_tracker \ + --resume_from_checkpoint "latest" diff --git a/scripts/text_condition/train_videoae_65_navit_test.sh b/scripts/text_condition/train_videoae_65_navit_test.sh new file mode 100644 index 000000000..d16ab3c25 --- /dev/null +++ b/scripts/text_condition/train_videoae_65_navit_test.sh @@ -0,0 +1,37 @@ +export WANDB_KEY="" +export ENTITY="linbin" +export PROJECT="65x512x512_10node_bs2_lr2e-5_4img" +export CUDA_VISIBLE_DEVICES=1 +python tests/test_navit_consistency.py \ + --model NaViTLatteT2V-XL/122 \ + --text_encoder_name DeepFloyd/t5-v1_1-xxl \ + --cache_dir "./cache_dir" \ + --dataset t2v \ + --ae CausalVAEModel_4x8x8 \ + --ae_path "/mnt_zhexin/zhexin.lzx/models/Open-Sora-Plan-v1.1.0/vae/" \ + --video_data "scripts/train_data/video_data_debug.txt" \ + --image_data "scripts/train_data/image_data_debug.txt" \ + --sample_rate 1 \ + --num_frames 65 \ + --max_image_size 512 \ + --gradient_checkpointing \ + --attention_mode xformers \ + --train_batch_size=4 \ + --dataloader_num_workers 8 \ + --gradient_accumulation_steps=1 \ + --max_train_steps=1000000 \ + --learning_rate=2e-05 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --mixed_precision="bf16" \ + --report_to="wandb" \ + --checkpointing_steps=500 \ + --output_dir="65x512x512_10node_bs2_lr2e-5_4img" \ + --allow_tf32 \ + --use_deepspeed \ + --max_token_lim 2048 \ + --token_dropout_rate 0.0 \ + --use_image_num 4 \ + --enable_tiling \ + --enable_tracker \ + --resume_from_checkpoint "latest" diff --git a/scripts/train_data/image_data_debug.txt b/scripts/train_data/image_data_debug.txt new file mode 100644 index 000000000..b49aebcee --- /dev/null +++ b/scripts/train_data/image_data_debug.txt @@ -0,0 +1 @@ +/mnt_zhexin/zhexin.lzx/dataset/Open-Sora-Plan-v1.1.0/human_image/images,/mnt_zhexin/zhexin.lzx/dataset/Open-Sora-Plan-v1.1.0/anno_jsons/human_images_162094.json \ No newline at end of file diff --git a/scripts/train_data/video_data_debug.txt b/scripts/train_data/video_data_debug.txt new file mode 100644 index 000000000..40c09881c --- /dev/null +++ b/scripts/train_data/video_data_debug.txt @@ -0,0 +1 @@ +/mnt_zhexin/zhexin.lzx/dataset/Open-Sora-Plan-v1.1.0/all_mixkit,/mnt_zhexin/zhexin.lzx/dataset/Open-Sora-Plan-v1.1.0/anno_jsons/video_mixkit_65f_54735.json \ No newline at end of file diff --git a/tests/test_navit_consistency.py b/tests/test_navit_consistency.py new file mode 100644 index 000000000..e248eea1b --- /dev/null +++ b/tests/test_navit_consistency.py @@ -0,0 +1,622 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test NaViT output consistency with the original implementation. +""" +import argparse +import logging +import math +import os +import shutil +from copy import deepcopy +from logging import getLogger +from pathlib import Path + +import accelerate +import diffusers +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import check_min_version, is_wandb_available +from einops import rearrange +from huggingface_hub import create_repo +from packaging import version +from tqdm import tqdm +from tqdm.auto import tqdm +from transformers import AutoTokenizer + +from opensora.dataset import ae_denorm, getdataset +from opensora.models.ae import ae_channel_config, ae_stride_config, getae, getae_wrapper +from opensora.models.ae.videobase import CausalVAEModelWrapper, CausalVQVAEModelWrapper +from opensora.models.diffusion import Diffusion_models +from opensora.utils.dataset_utils import NaViTCollate +from opensora.models.diffusion.diffusion import ( + create_diffusion_navit as create_diffusion, +) +from opensora.models.diffusion.latte.modeling_latte import LatteT2V +from opensora.models.diffusion.latte.modeling_latte_navit import ( + NaViTLatteT2V, + pack_target_as, +) +from opensora.models.text_encoder import get_text_enc + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.24.0") +logger = getLogger(__name__) + + +def generate_timestep_weights(args, num_timesteps): + weights = torch.ones(num_timesteps) + + # Determine the indices to bias + num_to_bias = int(args.timestep_bias_portion * num_timesteps) + + if args.timestep_bias_strategy == "later": + bias_indices = slice(-num_to_bias, None) + elif args.timestep_bias_strategy == "earlier": + bias_indices = slice(0, num_to_bias) + elif args.timestep_bias_strategy == "range": + # Out of the possible 1000 timesteps, we might want to focus on eg. 200-500. + range_begin = args.timestep_bias_begin + range_end = args.timestep_bias_end + if range_begin < 0: + raise ValueError( + "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero." + ) + if range_end > num_timesteps: + raise ValueError( + "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps." + ) + bias_indices = slice(range_begin, range_end) + else: # 'none' or any other string + return weights + if args.timestep_bias_multiplier <= 0: + return ValueError( + "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps." + " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead." + " A timestep bias multiplier less than or equal to 0 is not allowed." + ) + + # Apply the bias + weights[bias_indices] *= args.timestep_bias_multiplier + + # Normalize + weights /= weights.sum() + + return weights + + +################################################################################# +# Training Loop # +################################################################################# + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError( + "Make sure to install wandb if you want to use it for logging during training." + ) + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Create model: + + diffusion = create_diffusion( + timestep_respacing="", + ) # default: 1000 steps, linear noise schedule + ae = getae_wrapper(args.ae)( + args.ae_path, subfolder="vae", cache_dir="cache_dir" + ).eval() + if args.enable_tiling: + ae.vae.enable_tiling() + ae.vae.tile_overlap_factor = args.tile_overlap_factor + text_enc = get_text_enc(args).eval() + + ae_stride_t, ae_stride_h, ae_stride_w = ae_stride_config[args.ae] + args.ae_stride_t, args.ae_stride_h, args.ae_stride_w = ( + ae_stride_t, + ae_stride_h, + ae_stride_w, + ) + args.ae_stride = args.ae_stride_h + patch_size = args.model[-3:] + patch_size_t, patch_size_h, patch_size_w = ( + int(patch_size[0]), + int(patch_size[1]), + int(patch_size[2]), + ) + args.patch_size = patch_size_h + args.patch_size_t, args.patch_size_h, args.patch_size_w = ( + patch_size_t, + patch_size_h, + patch_size_w, + ) + assert ( + ae_stride_h == ae_stride_w + ), f"Support only ae_stride_h == ae_stride_w now, but found ae_stride_h ({ae_stride_h}), ae_stride_w ({ae_stride_w})" + assert ( + patch_size_h == patch_size_w + ), f"Support only patch_size_h == patch_size_w now, but found patch_size_h ({patch_size_h}), patch_size_w ({patch_size_w})" + # assert args.num_frames % ae_stride_t == 0, f"Num_frames must be divisible by ae_stride_t, but found num_frames ({args.num_frames}), ae_stride_t ({ae_stride_t})." + assert ( + args.max_image_size % ae_stride_h == 0 + ), f"Image size must be divisible by ae_stride_h, but found max_image_size ({args.max_image_size}), ae_stride_h ({ae_stride_h})." + + latent_size = ( + args.max_image_size // ae_stride_h, + args.max_image_size // ae_stride_w, + ) + + if ( + getae_wrapper(args.ae) == CausalVQVAEModelWrapper + or getae_wrapper(args.ae) == CausalVAEModelWrapper + ): + args.video_length = video_length = args.num_frames // ae_stride_t + 1 + else: + video_length = args.num_frames // ae_stride_t + + model = Diffusion_models["LatteT2V-XL/122"]( + in_channels=ae_channel_config[args.ae], + out_channels=ae_channel_config[args.ae] * 2, + # caption_channels=4096, + # cross_attention_dim=1152, + attention_bias=True, + sample_size=latent_size, + num_vector_embeds=None, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + use_linear_projection=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + # norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + attention_type="default", + video_length=video_length, + attention_mode=args.attention_mode, + ) + new_model: torch.nn.Module = Diffusion_models["NaViTLatteT2V-XL/122"]( + in_channels=ae_channel_config[args.ae], + out_channels=ae_channel_config[args.ae] * 2, + # caption_channels=4096, + # cross_attention_dim=1152, + attention_bias=True, + sample_size=latent_size, + num_vector_embeds=None, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + use_linear_projection=False, + only_cross_attention=False, + double_self_attention=False, + upcast_attention=False, + # norm_type="ada_norm_single", + norm_elementwise_affine=False, + norm_eps=1e-6, + attention_type="default", + video_length=video_length, + attention_mode=args.attention_mode, + max_token_lim=args.max_token_lim, + token_dropout_rate=args.token_dropout_rate, + ) + + # # use pretrained model? + if args.pretrained: + if "safetensors" in args.pretrained: + from safetensors.torch import load_file as safe_load + + checkpoint = safe_load(args.pretrained, device="cpu") + else: + checkpoint = torch.load(args.pretrained, map_location="cpu")[ + "model" + ] + model_state_dict = model.state_dict() + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint, strict=False + ) + logger.info( + f"missing_keys {len(missing_keys)} {missing_keys}, unexpected_keys {len(unexpected_keys)}" + ) + logger.info( + f"Successfully load {len(model.state_dict()) - len(missing_keys)}/{len(model_state_dict)} keys from {args.pretrained}!" + ) + # load from pixart-alpha + # pixelart_alpha = torch.load(args.pretrained, map_location='cpu')['state_dict'] + # checkpoint = {} + # for k, v in pixelart_alpha.items(): + # if 'x_embedder' in k or 't_embedder' in k or 'y_embedder' in k: + # checkpoint[k] = v + # if k.startswith('blocks'): + # k_spilt = k.split('.') + # blk_id = str(int(k_spilt[1]) * 2) + # k_spilt[1] = blk_id + # new_k = '.'.join(k_spilt) + # checkpoint[new_k] = v + # missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False) + # logger.info(f'Successfully load {len(model.state_dict()) - len(missing_keys)} keys from {args.pretrained}!') + + # copy weight from model to new_model + logger.info("Copying weight from model to new_model...") + model_state_dict = model.state_dict() + missing_keys, unexpected_keys = new_model.load_state_dict( + model_state_dict, strict=True + ) + logger.info( + f"missing_keys {len(missing_keys)} {missing_keys}, unexpected_keys {len(unexpected_keys)}" + ) + logger.info( + f"Successfully load {len(model.state_dict()) - len(missing_keys)}/{len(model_state_dict)} keys!" + ) + + # Freeze vae and text encoders. + ae.requires_grad_(False) + text_enc.requires_grad_(False) + # Set model as eval. + model.train() + new_model.train() + + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + + # weight_dtype = torch.bfloat16 + weight_dtype = torch.float32 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + device = torch.device("cuda:0") + ae.to(device, dtype=torch.float32) + text_enc.to(device, dtype=weight_dtype) + model.to(device, dtype=weight_dtype) + new_model.to(device, dtype=weight_dtype) + + + # Setup data: + train_dataset = getdataset(args) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=NaViTCollate(args), + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + + for step, (x_list, input_ids, cond_mask) in enumerate(train_dataloader): + x = torch.stack(x_list) + x = x.to(device) + + + attn_mask = None + input_ids = input_ids.to(device) # B L + cond_mask = cond_mask.to(device) # B L + + with torch.no_grad(): + # Map input images to latent space + normalize latents + B, _, _ = input_ids.shape # B T+num_images L b 1+4, L + cond = torch.stack([text_enc(input_ids[i], cond_mask[i]) for i in range(B)]) # B 1+num_images L D + if args.use_image_num == 0: + x = ae.encode(x) # B C T H W + else: + videos, images = x[:, :, :-args.use_image_num], x[:, :, -args.use_image_num:] + videos = ae.encode(videos) # B C T H W + + images = rearrange(images, 'b c t h w -> (b t) c 1 h w') + images = ae.encode(images) + + images = rearrange(images, '(b t) c 1 h w -> b c t h w', t=args.use_image_num) + x = torch.cat([videos, images], dim=2) # b c 17+4, h, w + + model_kwargs = dict( + encoder_hidden_states=cond, + attention_mask=attn_mask, + encoder_attention_mask=cond_mask, + use_image_num=args.use_image_num, + ) + t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device) + # [B, C, F, H, W] + output = model(x, t, **model_kwargs)[0] + + # delete model to save memory + # del model + + + # new model branch + x_list = list(x.unbind(dim=0)) + # delete x to save memory + del x + + torch.cuda.empty_cache() + + + # [b, F, T, p*p*c] + new_output, video_ids, token_kept_ids = new_model(x_list, t, **model_kwargs) + output = list(output.unbind(dim=0)) + # [b, F, T, p*p*c] + output = pack_target_as(output, video_ids, new_model.patch_size, token_kept_ids) + + diff_max = torch.abs(output - new_output) > 1e-5 + + print("output: ", output[diff_max], "\nnew_output: ", new_output[diff_max]) + + torch.testing.assert_close(output, new_output) + print("all matched!") + + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset", type=str, required=True) + parser.add_argument("--video_data", type=str, required='') + parser.add_argument("--image_data", type=str, default='') + parser.add_argument("--sample_rate", type=int, default=1) + parser.add_argument("--num_frames", type=int, default=17) + parser.add_argument("--max_image_size", type=int, default=512) + parser.add_argument("--use_img_from_vid", action="store_true") + parser.add_argument("--use_image_num", type=int, default=0) + parser.add_argument("--model_max_length", type=int, default=300) + + parser.add_argument('--enable_8bit_t5', action='store_true') + parser.add_argument('--tile_overlap_factor', type=float, default=0.25) + parser.add_argument('--enable_tiling', action='store_true') + parser.add_argument("--compress_kv", action="store_true") + parser.add_argument("--attention_mode", type=str, choices=['xformers', 'math', 'flash'], default="xformers") + parser.add_argument('--use_rope', action='store_true') + parser.add_argument('--compress_kv_factor', type=int, default=1) + + parser.add_argument("--model", type=str, choices=list(Diffusion_models.keys()), default="Latte-XL/122") + parser.add_argument("--pretrained", type=str, default=None) + parser.add_argument("--ae", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--ae_path", type=str, default="stabilityai/sd-vae-ft-mse") + parser.add_argument("--text_encoder_name", type=str, default='DeepFloyd/t5-v1_1-xxl') + parser.add_argument("--cache_dir", type=str, default='./cache_dir') + + parser.add_argument("--num_sampling_steps", type=int, default=50) + parser.add_argument('--guidance_scale', type=float, default=5.5) + parser.add_argument("--multi_scale", action="store_true") + parser.add_argument("--enable_tracker", action="store_true") + parser.add_argument("--use_deepspeed", action="store_true") + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--timestep_bias_strategy", + type=str, + default="none", + choices=["earlier", "later", "range", "none"], + help=( + "The timestep bias strategy, which may help direct the model toward learning low or high frequency details." + " Choices: ['earlier', 'later', 'range', 'none']." + " The default is 'none', which means no bias is applied, and training proceeds normally." + " The value of 'later' will increase the frequency of the model's final training timesteps." + ), + ) + parser.add_argument( + "--timestep_bias_multiplier", + type=float, + default=1.0, + help=( + "The multiplier for the bias. Defaults to 1.0, which means no bias is applied." + " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it." + ), + ) + parser.add_argument( + "--timestep_bias_begin", + type=int, + default=0, + help=( + "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias." + " Defaults to zero, which equates to having no specific bias." + ), + ) + parser.add_argument( + "--timestep_bias_end", + type=int, + default=1000, + help=( + "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias." + " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on." + ), + ) + parser.add_argument( + "--timestep_bias_portion", + type=float, + default=0.25, + help=( + "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased." + " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines" + " whether the biased portions are in the earlier or later timesteps." + ), + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=10, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--max_token_lim", + type=int, + default=1024, + help="The max token limit of NaViT training.", + ) + parser.add_argument( + "--token_dropout_rate", + type=float, + default=0, + help="The drop rate of token in NaViT training.", + ) + + args = parser.parse_args() + main(args) \ No newline at end of file