Merge pull request #278 from LLaVA-VL/yhzhang/llava_video_dev
merge for branch yhzhang/llava_video_dev
ChunyuanLI authored Oct 4, 2024
2 parents b3a46be + c4d9ca1 commit a4c9bce
Showing 9 changed files with 580 additions and 7 deletions.
19 changes: 17 additions & 2 deletions README.md
@@ -3,21 +3,36 @@
</p>

# LLaVA-NeXT: Open Large Multimodal Models
[![Static Badge](https://img.shields.io/badge/llava_video-paper-green)](https://arxiv.org/abs/2410.02713)
[![Static Badge](https://img.shields.io/badge/llava_onevision-paper-green)](https://arxiv.org/abs/2408.03326)
[![llava_next-blog](https://img.shields.io/badge/llava_next-blog-green)](https://llava-vl.github.io/blog/)

[![llava_onevision-demo](https://img.shields.io/badge/llava_onevision-demo-red)](https://llava-onevision.lmms-lab.com/)
[![llava_next-video_demo](https://img.shields.io/badge/llava_video-demo-red)](https://huggingface.co/spaces/WildVision/vision-arena)
[![llava_next-interleave_demo](https://img.shields.io/badge/llava_next-interleave_demo-red)](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo)
[![llava_next-video_demo](https://img.shields.io/badge/llava_next-video_demo-red)](https://huggingface.co/spaces/WildVision/vision-arena)
[![Openbayes Demo](https://img.shields.io/static/v1?label=Demo&message=OpenBayes%E8%B4%9D%E5%BC%8F%E8%AE%A1%E7%AE%97&color=green)](https://openbayes.com/console/public/tutorials/gW0ng9jKXfO)

[![llava_video-checkpoints](https://img.shields.io/badge/llava_video-checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944)
[![llava_onevision-checkpoints](https://img.shields.io/badge/llava_onevision-checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-onevision-66a259c3526e15166d6bba37)
[![llava_next-interleave_checkpoints](https://img.shields.io/badge/llava_next-interleave_checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1)
[![llava_next-video_checkpoints](https://img.shields.io/badge/llava_next-video_checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944)
[![llava_next-image_checkpoints](https://img.shields.io/badge/llava_next-image_checkpoints-blue)](https://huggingface.co/lmms-lab)

## Release Notes

- **[2024/10/04] 🔥 LLaVA-Video** (formerly LLaVA-NeXT-Video) has undergone a major upgrade! We are excited to release **LLaVA-Video-178K**, a high-quality synthetic dataset for video instruction tuning. This dataset includes:

- 178,510 caption entries
- 960,792 open-ended Q&A pairs
- 196,198 multiple-choice Q&A items

Along with this, we’re also releasing the **LLaVA-Video 7B/72B models**, which deliver competitive performance on the latest video benchmarks, including [Video-MME](https://video-mme.github.io/home_page.html#leaderboard), [LongVideoBench](https://longvideobench.github.io/), and [Dream-1K](https://tarsier-vlm.github.io/).

📄 **Explore more**:
- [LLaVA-Video-178K Dataset](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K): Download the dataset.
- [LLaVA-Video Models](https://huggingface.co/collections/lmms-lab/llava-video-661e86f5e8dabc3ff793c944): Access model checkpoints.
- [Paper](https://arxiv.org/abs/2410.02713): Detailed information about LLaVA-Video.
- [LLaVA-Video Documentation](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_Video_1003.md): Guidance on training, inference and evaluation.

- [2024/09/13] 🔥 **🚀 [LLaVA-OneVision-Chat](docs/LLaVA_OneVision_Chat.md)**. The new LLaVA-OV-Chat (7B/72B) significantly improves the chat experience of LLaVA-OV. 📄

![](docs/ov_chat_images/chat_results.png)
122 changes: 122 additions & 0 deletions docs/LLaVA_Video_1003.md
@@ -0,0 +1,122 @@
# LLaVA Video

## Table of Contents

1. [Model Summary](#model-summary)
2. [Inference](#inference)
3. [Training](#training)
4. [Evaluation](#evaluation-guidance)
5. [Citation](#citation)

## Model Summary

The LLaVA-Video models are 7B/72B-parameter models trained on [LLaVA-Video-178K](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K) and the [LLaVA-OneVision Dataset](https://huggingface.co/datasets/lmms-lab/LLaVA-OneVision-Data). They are based on the Qwen2 language model and support a context window of 32K tokens.


## Inference

We provide a simple generation example below. For more details, please refer to our [GitHub repository](https://github.com/LLaVA-VL/LLaVA-NeXT).

```python
# pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from PIL import Image
import requests
import copy
import torch
import sys
import warnings
from decord import VideoReader, cpu
import numpy as np
warnings.filterwarnings("ignore")
def load_video(video_path, max_frames_num, fps=1, force_sample=False):
    """Sample frames from a video and return them together with a string of
    their timestamps and the total video duration (in seconds)."""
    if max_frames_num == 0:
        return np.zeros((1, 336, 336, 3))
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    total_frame_num = len(vr)
    video_time = total_frame_num / vr.get_avg_fps()
    # Convert the requested sampling rate into a frame stride
    stride = round(vr.get_avg_fps() / fps)
    frame_idx = list(range(0, total_frame_num, stride))
    if len(frame_idx) > max_frames_num or force_sample:
        # Fall back to uniformly sampling exactly max_frames_num frames
        frame_idx = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int).tolist()
    frame_time = ",".join([f"{i / vr.get_avg_fps():.2f}s" for i in frame_idx])
    spare_frames = vr.get_batch(frame_idx).asnumpy()
    return spare_frames, frame_time, video_time

pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
model_name = "llava_qwen"
device = "cuda"
device_map = "auto"
# Pass any additional llava_model_args to load_pretrained_model if needed
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)
model.eval()

video_path = "XXXX"  # path to your local video file
max_frames_num = 64
video, frame_time, video_time = load_video(video_path, max_frames_num, 1, force_sample=True)
video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].cuda().bfloat16()
video = [video]

conv_template = "qwen_1_5"  # Make sure you use the correct chat template for different models
time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\nPlease describe this video in detail."
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()
input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
cont = model.generate(
input_ids,
images=video,
modalities= ["video"],
do_sample=False,
temperature=0,
max_new_tokens=4096,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip()
print(text_outputs)
```


## Training

[[Scripts]](https://github.com/LLaVA-VL/LLaVA-NeXT/tree/main/scripts/video/train): Start training models on your single-image/multi-image/video data.
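
For example, a minimal sketch of launching the 7B video finetuning script added in this commit. The `IMAGE_FOLDER`, `VIDEO_FOLDER`, and `DATA_YAML` placeholders at the top of the script must first be pointed at your own data, and the script assumes Arnold-style cluster environment variables (e.g. `ARNOLD_WORKER_GPU`) for `torchrun`:

```bash
# Launch OneVision-to-video finetuning for the 7B model using the provided script
bash scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh
```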


## Evaluation Guidance

We use the [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) toolkit to evaluate our models. Ensure you have installed the LLaVA-NeXT model files as per the instructions in the main README.md.

Install lmms-eval:

> pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git

### Reproducing Evaluation Results

Our models' evaluation results can be fully reproduced using the lmms-eval toolkit. After installing lmms-eval and llava, you can run the evaluation using the following commands.

Note: These commands require flash-attn. If you prefer not to install it, disable flash-attn by adding `attn_implementation=None` to the `--model_args` parameter.
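
For example, a minimal sketch of the evaluation command further below with flash-attn disabled; the only change is appending `attn_implementation=None` to `--model_args`, and a single task is used here purely for illustration:

```bash
# Same evaluation setup as the full command below, but with flash-attn disabled
accelerate launch --num_processes=8 \
-m lmms_eval \
--model llava_vid \
--model_args pretrained=lmms-lab/LLaVA-Video-7B-Qwen2,conv_template=qwen_1_5,max_frames_num=64,mm_spatial_pool_mode=average,attn_implementation=None \
--tasks videomme \
--batch_size 1 \
--log_samples \
--log_samples_suffix llava_vid \
--output_path ./logs/
```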

Important: Different torch versions may cause slight variations in results. By default, `lmms-eval` requires the latest torch release, while the `llava` repo pins torch to `2.1.2`. Torch `2.1.2` is stable for both `llava` and `lmms-eval`.
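
If you want to match that pinned setup, a minimal sketch (assuming a pip-managed environment; the torchvision pin mirrors this repo's `pyproject.toml`):

```bash
# Pin torch to the version used by the llava repo; torchvision 0.16.2 is the matching release
pip install torch==2.1.2 torchvision==0.16.2
```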

### Evaluating LLaVA-Video on multiple datasets

We recommend that developers and researchers evaluate the models on a broad range of datasets to get a comprehensive picture of their performance in different scenarios. We therefore provide a list of evaluation tasks below and welcome contributions of additional ones. Please refer to [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) for more details.

```bash
# video tasks
accelerate launch --num_processes=8 \
-m lmms_eval \
--model llava_vid \
--model_args pretrained=lmms-lab/LLaVA-Video-7B-Qwen2,conv_template=qwen_1_5,max_frames_num=64,mm_spatial_pool_mode=average \
--tasks activitynetqa,videochatgpt,nextqa_mc_test,egoschema,video_dc499,videomme,videomme_w_subtitle,perceptiontest_val_mc \
--batch_size 1 \
--log_samples \
--log_samples_suffix llava_vid \
--output_path ./logs/
```

8 changes: 6 additions & 2 deletions llava/model/builder.py
@@ -24,16 +24,20 @@
from llava.utils import rank0_print


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16",attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs):
kwargs["device_map"] = device_map

if load_8bit:
kwargs["load_in_8bit"] = True
elif load_4bit:
kwargs["load_in_4bit"] = True
kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
else:
elif torch_dtype == "float16":
kwargs["torch_dtype"] = torch.float16
elif torch_dtype == "bfloat16":
kwargs["torch_dtype"] = torch.bfloat16
else:
import pdb;pdb.set_trace()

if customized_config is not None:
kwargs["config"] = customized_config
5 changes: 3 additions & 2 deletions llava/model/llava_arch.py
@@ -93,6 +93,7 @@ def initialize_vision_modules(self, model_args, fsdp=None):
self.config.mm_vision_select_feature = mm_vision_select_feature
self.config.mm_patch_merge_type = mm_patch_merge_type


if not hasattr(self.config, 'add_faster_video'):
if model_args.add_faster_video:
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
@@ -227,7 +228,7 @@ def add_token_per_grid(self, image_feature):
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
if self.config.add_faster_video:
if getattr(self.config, "add_faster_video", False):
# import pdb; pdb.set_trace()
# (3584, 832, 14) -> (3584, 64, 13, 14)
image_feature = image_feature.view(feature_dim, num_frames,resize_h, -1)
@@ -311,7 +312,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio
if mm_newline_position == "grid":
# Grid-wise
image_feature = self.add_token_per_grid(image_feature)
if self.config.add_faster_video:
if getattr(self.config, "add_faster_video", False):
faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx])
# Add a token for each frame
concat_slow_fater_token = []
2 changes: 2 additions & 0 deletions playground/demo/video_demo.py
@@ -153,6 +153,8 @@ def run_inference(args):
else:
args.force_sample = False

# import pdb;pdb.set_trace()

if getattr(model.config, "add_time_instruction", None) is not None:
args.add_time_instruction = model.config.add_time_instruction
else:
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -38,7 +38,7 @@ train = [
"torchvision==0.16.2",
"uvicorn",
"wandb",
"deepspeed==0.14.2",
"deepspeed==0.14.4",
"peft==0.4.0",
"accelerate>=0.29.1",
"tokenizers~=0.15.2",
83 changes: 83 additions & 0 deletions scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh
@@ -0,0 +1,83 @@
#!/bin/bash

# Set up the data folder
IMAGE_FOLDER="XXX"
VIDEO_FOLDER="XXX"
DATA_YAML="XXX" # e.g exp.yaml

############### Prepare Envs #################
python3 -m pip install flash-attn --no-build-isolation
alias python=python3
############### Show Envs ####################

nvidia-smi

################ Arnold Jobs ################

LLM_VERSION="Qwen/Qwen2-72B-Instruct"
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"


BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-72B-Instruct-mlp2x_gelu-pretrain_blip558k_plain"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"

# Stage 2
PROMPT_VERSION="qwen_1_5"
MID_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video_am9"
PREV_STAGE_CHECKPOINT="lmms-lab/llava-onevision-qwen2-72b-ov"
echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}"
echo "MID_RUN_NAME: ${MID_RUN_NAME}"


ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \
llava/train/train_mem.py \
--deepspeed scripts/zero3.json \
--model_name_or_path $PREV_STAGE_CHECKPOINT \
--version $PROMPT_VERSION \
--data_path $DATA_YAML \
--image_folder $IMAGE_FOLDER \
--video_folder $VIDEO_FOLDER \
--mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
--mm_vision_tower_lr=2e-6 \
--vision_tower ${VISION_MODEL_VERSION} \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--group_by_modality_length True \
--image_aspect_ratio anyres_max_9 \
--image_grid_pinpoints "(1x1),...,(6x6)" \
--mm_patch_merge_type spatial_unpad \
--bf16 True \
--run_name $MID_RUN_NAME \
--output_dir ./work_dirs/$MID_RUN_NAME \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 2 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 1 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 32768 \
--gradient_checkpointing True \
--dataloader_num_workers 2 \
--lazy_preprocess True \
--report_to wandb \
--torch_compile True \
--torch_compile_backend "inductor" \
--dataloader_drop_last True \
--frames_upbound 32 \
--mm_newline_position grid \
--add_time_instruction True \
--force_sample True \
--mm_spatial_pool_stride 2
exit 0;
83 changes: 83 additions & 0 deletions scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh
@@ -0,0 +1,83 @@
#!/bin/bash

# Set up the data folder
IMAGE_FOLDER="XXX"
VIDEO_FOLDER="XXX"
DATA_YAML="XXX" # e.g exp.yaml

############### Prepare Envs #################
python3 -m pip install flash-attn --no-build-isolation
alias python=python3
############### Show Envs ####################

nvidia-smi

################ Arnold Jobs ################

LLM_VERSION="Qwen/Qwen2-7B-Instruct"
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"
#

BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mlp2x_gelu-pretrain_blip558k_plain"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"

# Stage 2
PROMPT_VERSION="qwen_1_5"
MID_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video_am9"
PREV_STAGE_CHECKPOINT="lmms-lab/llava-onevision-qwen2-7b-ov"
echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}"
echo "MID_RUN_NAME: ${MID_RUN_NAME}"


ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \
llava/train/train_mem.py \
--deepspeed scripts/zero3.json \
--model_name_or_path $PREV_STAGE_CHECKPOINT \
--version $PROMPT_VERSION \
--data_path $DATA_YAML \
--image_folder $IMAGE_FOLDER \
--video_folder $VIDEO_FOLDER \
--mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
--mm_vision_tower_lr=2e-6 \
--vision_tower ${VISION_MODEL_VERSION} \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--group_by_modality_length True \
--image_aspect_ratio anyres_max_9 \
--image_grid_pinpoints "(1x1),...,(6x6)" \
--mm_patch_merge_type spatial_unpad \
--bf16 True \
--run_name $MID_RUN_NAME \
--output_dir ./work_dirs/$MID_RUN_NAME \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 2 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 1 \
--learning_rate 1e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 32768 \
--gradient_checkpointing True \
--dataloader_num_workers 2 \
--lazy_preprocess True \
--report_to wandb \
--torch_compile True \
--torch_compile_backend "inductor" \
--dataloader_drop_last True \
--frames_upbound 110 \
--mm_newline_position grid \
--add_time_instruction True \
--force_sample True \
--mm_spatial_pool_stride 2
exit 0;