From 819a61214d6be3cc1542ea73c82a1c4d65ba31e1 Mon Sep 17 00:00:00 2001 From: ZhangYuanhan-AI Date: Sun, 22 Sep 2024 12:22:30 +0000 Subject: [PATCH 1/5] add stream inference code --- llava/model/builder.py | 8 +- llava/model/llava_arch.py | 5 +- playground/demo/stream_video_demo.py | 166 ++++++++++++++++++ playground/demo/video_demo.py | 2 + .../train/SO400M_Qwen2_72B_ov_to_video_am9.sh | 113 ++++++++++++ .../train/SO400M_Qwen2_7B_ov_to_video_am9.sh | 113 ++++++++++++ 6 files changed, 403 insertions(+), 4 deletions(-) create mode 100644 playground/demo/stream_video_demo.py create mode 100755 scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh create mode 100755 scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh diff --git a/llava/model/builder.py b/llava/model/builder.py index 828a8e168..704b960a6 100755 --- a/llava/model/builder.py +++ b/llava/model/builder.py @@ -24,7 +24,7 @@ from llava.utils import rank0_print -def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs): +def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", torch_dtype="float16",attn_implementation="flash_attention_2", customized_config=None, overwrite_config=None, **kwargs): kwargs["device_map"] = device_map if load_8bit: @@ -32,8 +32,12 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l elif load_4bit: kwargs["load_in_4bit"] = True kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4") - else: + elif torch_dtype == "float16": kwargs["torch_dtype"] = torch.float16 + elif torch_dtype == "bfloat16": + kwargs["torch_dtype"] = torch.bfloat16 + else: + import pdb;pdb.set_trace() if customized_config is not None: kwargs["config"] = customized_config diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py index 4549593bb..db6c931d7 100755 --- a/llava/model/llava_arch.py +++ b/llava/model/llava_arch.py @@ -93,6 +93,7 @@ def initialize_vision_modules(self, model_args, fsdp=None): self.config.mm_vision_select_feature = mm_vision_select_feature self.config.mm_patch_merge_type = mm_patch_merge_type + if not hasattr(self.config, 'add_faster_video'): if model_args.add_faster_video: embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) @@ -227,7 +228,7 @@ def add_token_per_grid(self, image_feature): image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() image_feature = image_feature.flatten(1, 2).flatten(2, 3) image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1) - if self.config.add_faster_video: + if getattr(self.config, "add_faster_video", False): # import pdb; pdb.set_trace() # (3584, 832, 14) -> (3584, 64, 13, 14) image_feature = image_feature.view(feature_dim, num_frames,resize_h, -1) @@ -311,7 +312,7 @@ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attentio if mm_newline_position == "grid": # Grid-wise image_feature = self.add_token_per_grid(image_feature) - if self.config.add_faster_video: + if getattr(self.config, "add_faster_video", False): faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx]) # Add a token for each frame concat_slow_fater_token = [] diff 
--git a/playground/demo/stream_video_demo.py b/playground/demo/stream_video_demo.py new file mode 100644 index 000000000..a59836128 --- /dev/null +++ b/playground/demo/stream_video_demo.py @@ -0,0 +1,166 @@ +import numpy as np +import cv2 +import warnings +import select +import sys +import openai +import base64 + +warnings.filterwarnings("ignore") + +# Global variables for storing video frames and their respective times +video_frames = [] +frame_times = [] +history_time = 0 + + + +client = openai.Client(api_key="EMPTY", base_url="xxx") + +def encode_image(frames): + base64_frames = [] + for frame in frames: + # frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) # Convert BGR to RGB + _, buffer = cv2.imencode(".jpg", frame) + buffer = base64.b64encode(buffer).decode("utf-8") + base64_frames.append(buffer) + return base64_frames + +# Function to send frames to the server and get a response +def request_server(question, base64_frames): + messages = [{"role": "user", "content": []}] + for base64_frame in base64_frames: + frame_format = { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_frame}"}, + "modalities": "video", + } + messages[0]["content"].append(frame_format) + + prompt = {"type": "text", "text": question} + messages[0]["content"].append(prompt) + + video_request = client.chat.completions.create( + model="llava-onevision-72b-ov", + messages=messages, + temperature=0, + max_tokens=1024, + ) + + return video_request.choices[0].message.content + + +class Args: + """ + Class to store configuration arguments. + """ + def __init__(self, frame_limit=30, force_sample=False): + self.frame_limit = frame_limit # Max number of frames to retrieve + self.force_sample = force_sample # Whether to force uniform sampling + + +# Function to capture frames from the camera until the user presses Enter +def load_camera_frames_until_enter(args): + global history_time # To maintain across multiple captures + + cap = cv2.VideoCapture(0) # 0 is the ID for the default camera + if not cap.isOpened(): + print("Error: Could not access the camera.") + return None, None, None + + fps = cap.get(cv2.CAP_PROP_FPS) or 30 # Default to 30 FPS if unable to retrieve FPS + frame_count = 0 + + print("Video capturing started. 
Press 'Enter' in the console to stop capturing.") + + while True: + ret, frame = cap.read() + if not ret: + print("Error: Could not read frame from camera.") + break + + frame_count += 1 + cur_frame_time = frame_count / fps + + video_frames.append(frame) + frame_times.append(cur_frame_time + history_time) + + # Display the frame + cv2.imshow('Camera Feed', frame) + + # Add cv2.waitKey to ensure the window remains visible + if cv2.waitKey(1) & 0xFF == ord('q'): + break + + # Check if user pressed 'Enter' in the console + if sys.stdin in select.select([sys.stdin], [], [], 0)[0]: + input() # Consume the "Enter" key press + print("Video capture stopped.") + break + + cap.release() + cv2.destroyAllWindows() # Close the camera feed window + + history_time = frame_times[-1] if frame_times else history_time + + # Sample frames + total_frames = len(video_frames) + print(f"Total Frames Captured: {total_frames}") + + if total_frames > args.frame_limit: + sample_indices = np.linspace(0, total_frames - 1, args.frame_limit, dtype=int) + sampled_frames = [video_frames[i] for i in sample_indices] + sampled_times = [frame_times[i] for i in sample_indices] + else: + sampled_frames = video_frames + sampled_times = frame_times + + # import pdb; pdb.set_trace() + frame_times_str = ",".join([f"{t:.2f}s" for t in sampled_times]) + return np.array(sampled_frames), frame_times_str, history_time + + +# Function to stream video, process it, and answer a user question +def stream_camera_and_ask_question(args): + video_frames, frame_times, video_time = load_camera_frames_until_enter(args) + + if video_frames is None: + print("Error capturing video frames.") + return + + question = input("Press the query for current video: ").strip().lower() + + print("question: ", question) + image_base64 = encode_image(video_frames) + # import pdb; pdb.set_trace() + response = request_server(question, image_base64) + + print(f"Model's Answer: {response}") + print(f"Video Duration: 0 to {video_time:.2f} seconds") + print(f"Frame Times: {frame_times}") + + return response + + +# Main loop to keep the system running and waiting for user input +def main_loop(): + question = "Please describe this video." 
+ args = Args(frame_limit=64, force_sample=True) + + while True: + answer = stream_camera_and_ask_question(args) + if answer is None: + print("Exiting the loop.") + break + + user_input = input("Press 'Enter' to capture again, or 'q' to quit: ").strip().lower() + if user_input == "q": + print("Quitting the demo.") + break + + # Close all OpenCV windows after the user quits + cv2.destroyAllWindows() + + +if __name__ == "__main__": + main_loop() diff --git a/playground/demo/video_demo.py b/playground/demo/video_demo.py index 59e7b68d9..d93f65826 100644 --- a/playground/demo/video_demo.py +++ b/playground/demo/video_demo.py @@ -153,6 +153,8 @@ def run_inference(args): else: args.force_sample = False + # import pdb;pdb.set_trace() + if getattr(model.config, "add_time_instruction", None) is not None: args.add_time_instruction = model.config.add_time_instruction else: diff --git a/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh b/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh new file mode 100755 index 000000000..d6d57d676 --- /dev/null +++ b/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh @@ -0,0 +1,113 @@ +#!/bin/bash + + +# You should complete the path of the following attributes: +PROJECT_ROOT="XXXX" +## This could a yaml file for multiple files or a json file for a single file +DATA_PATH="XXXX" +IMAGE_FOLDER="XXXX" +VIDEO_FOLDER="XXXX" + + +export PYTHONWARNINGS="ignore" + + +############### Prepare Envs ################# +cd $PROJECT_ROOT +python3 -m pip install --upgrade pip +python3 -m pip install -e ".[train]" + +python3 -m pip install ninja +python3 -m pip install flash-attn --no-build-isolation +alias python=python3 +############### Show Envs #################### + +nvidia-smi +# 取 worker0 第一个 port +ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' ')) +port=${ports[0]} +port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2222}" | awk -F',' '{print $1}')" + +echo "total workers: ${ARNOLD_WORKER_NUM}" +echo "cur worker id: ${ARNOLD_ID}" +echo "gpus per worker: ${ARNOLD_WORKER_GPU}" +echo "master ip: ${METIS_WORKER_0_HOST}" +echo "master port: ${port}" +echo "master port in cmd: ${port_in_cmd}" + +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +# export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE} +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=WARN + +PORT=26000 +GPUS="0,1,2,3,4,5,6,7" + +################ Arnold Jobs ################ + +LLM_VERSION="Qwen/Qwen2-72B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + + +# Stage For video +PROMPT_VERSION="qwen_1_5" +MID_RUN_NAME="llava_next_video-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video" +PREV_STAGE_CHECKPOINT="" +echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" +echo "MID_RUN_NAME: ${MID_RUN_NAME}" + + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path $PREV_STAGE_CHECKPOINT \ + --version $PROMPT_VERSION \ + --data_path ${DATA_PATH} \ + --image_folder ${IMAGE_FOLDER} \ + --video_folder ${VIDEO_FOLDER} \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + 
--mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres_max_9 \ + --image_grid_pinpoints "(1x1),...,(6x6)" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $MID_RUN_NAME \ + --output_dir ./work_dirs/$MID_RUN_NAME \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 500 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 12768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 2 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --frames_upbound 32 \ + --mm_newline_position grid \ + --add_time_instruction True \ + --force_sample True \ + --mm_spatial_pool_stride 2 +exit 0; \ No newline at end of file diff --git a/scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh b/scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh new file mode 100755 index 000000000..8819743fe --- /dev/null +++ b/scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh @@ -0,0 +1,113 @@ +#!/bin/bash + + +# You should complete the path of the following attributes: +PROJECT_ROOT="XXXX" +## This could a yaml file for multiple files or a json file for a single file +DATA_PATH="XXXX" +IMAGE_FOLDER="XXXX" +VIDEO_FOLDER="XXXX" + + +export PYTHONWARNINGS="ignore" + + +############### Prepare Envs ################# +cd $PROJECT_ROOT +python3 -m pip install --upgrade pip +python3 -m pip install -e ".[train]" + +python3 -m pip install ninja +python3 -m pip install flash-attn --no-build-isolation +alias python=python3 +############### Show Envs #################### + +nvidia-smi +# 取 worker0 第一个 port +ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' ')) +port=${ports[0]} +port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2222}" | awk -F',' '{print $1}')" + +echo "total workers: ${ARNOLD_WORKER_NUM}" +echo "cur worker id: ${ARNOLD_ID}" +echo "gpus per worker: ${ARNOLD_WORKER_GPU}" +echo "master ip: ${METIS_WORKER_0_HOST}" +echo "master port: ${port}" +echo "master port in cmd: ${port_in_cmd}" + +export OMP_NUM_THREADS=8 +export NCCL_IB_DISABLE=0 +export NCCL_IB_GID_INDEX=3 +# export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE} +export NCCL_SOCKET_IFNAME=eth0 +export NCCL_DEBUG=WARN + +PORT=26000 +GPUS="0,1,2,3,4,5,6,7" + +################ Arnold Jobs ################ + +LLM_VERSION="Qwen/Qwen2-7B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + + +# Stage For video +PROMPT_VERSION="qwen_1_5" +MID_RUN_NAME="llava_next_video-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video" +PREV_STAGE_CHECKPOINT="" +echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" +echo "MID_RUN_NAME: ${MID_RUN_NAME}" + + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path $PREV_STAGE_CHECKPOINT \ + --version $PROMPT_VERSION \ + --data_path ${DATA_PATH} \ + --image_folder ${IMAGE_FOLDER} \ + --video_folder ${VIDEO_FOLDER} \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" 
\ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres_max_9 \ + --image_grid_pinpoints "(1x1),...,(6x6)" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $MID_RUN_NAME \ + --output_dir ./work_dirs/$MID_RUN_NAME \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 500 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 22768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 2 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --frames_upbound 110 \ + --mm_newline_position grid \ + --add_time_instruction True \ + --force_sample True \ + --mm_spatial_pool_stride 2 +exit 0; \ No newline at end of file From fbd6a0745bc385bb9fd94dd96ec9e69e5ee22fc3 Mon Sep 17 00:00:00 2001 From: ZhangYuanhan-AI Date: Thu, 26 Sep 2024 03:40:40 +0000 Subject: [PATCH 2/5] update --- playground/demo/stream_video_demo.py | 166 --------------------------- pyproject.toml | 2 +- 2 files changed, 1 insertion(+), 167 deletions(-) delete mode 100644 playground/demo/stream_video_demo.py diff --git a/playground/demo/stream_video_demo.py b/playground/demo/stream_video_demo.py deleted file mode 100644 index a59836128..000000000 --- a/playground/demo/stream_video_demo.py +++ /dev/null @@ -1,166 +0,0 @@ -import numpy as np -import cv2 -import warnings -import select -import sys -import openai -import base64 - -warnings.filterwarnings("ignore") - -# Global variables for storing video frames and their respective times -video_frames = [] -frame_times = [] -history_time = 0 - - - -client = openai.Client(api_key="EMPTY", base_url="xxx") - -def encode_image(frames): - base64_frames = [] - for frame in frames: - # frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) # Convert BGR to RGB - _, buffer = cv2.imencode(".jpg", frame) - buffer = base64.b64encode(buffer).decode("utf-8") - base64_frames.append(buffer) - return base64_frames - -# Function to send frames to the server and get a response -def request_server(question, base64_frames): - messages = [{"role": "user", "content": []}] - for base64_frame in base64_frames: - frame_format = { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{base64_frame}"}, - "modalities": "video", - } - messages[0]["content"].append(frame_format) - - prompt = {"type": "text", "text": question} - messages[0]["content"].append(prompt) - - video_request = client.chat.completions.create( - model="llava-onevision-72b-ov", - messages=messages, - temperature=0, - max_tokens=1024, - ) - - return video_request.choices[0].message.content - - -class Args: - """ - Class to store configuration arguments. 
- """ - def __init__(self, frame_limit=30, force_sample=False): - self.frame_limit = frame_limit # Max number of frames to retrieve - self.force_sample = force_sample # Whether to force uniform sampling - - -# Function to capture frames from the camera until the user presses Enter -def load_camera_frames_until_enter(args): - global history_time # To maintain across multiple captures - - cap = cv2.VideoCapture(0) # 0 is the ID for the default camera - if not cap.isOpened(): - print("Error: Could not access the camera.") - return None, None, None - - fps = cap.get(cv2.CAP_PROP_FPS) or 30 # Default to 30 FPS if unable to retrieve FPS - frame_count = 0 - - print("Video capturing started. Press 'Enter' in the console to stop capturing.") - - while True: - ret, frame = cap.read() - if not ret: - print("Error: Could not read frame from camera.") - break - - frame_count += 1 - cur_frame_time = frame_count / fps - - video_frames.append(frame) - frame_times.append(cur_frame_time + history_time) - - # Display the frame - cv2.imshow('Camera Feed', frame) - - # Add cv2.waitKey to ensure the window remains visible - if cv2.waitKey(1) & 0xFF == ord('q'): - break - - # Check if user pressed 'Enter' in the console - if sys.stdin in select.select([sys.stdin], [], [], 0)[0]: - input() # Consume the "Enter" key press - print("Video capture stopped.") - break - - cap.release() - cv2.destroyAllWindows() # Close the camera feed window - - history_time = frame_times[-1] if frame_times else history_time - - # Sample frames - total_frames = len(video_frames) - print(f"Total Frames Captured: {total_frames}") - - if total_frames > args.frame_limit: - sample_indices = np.linspace(0, total_frames - 1, args.frame_limit, dtype=int) - sampled_frames = [video_frames[i] for i in sample_indices] - sampled_times = [frame_times[i] for i in sample_indices] - else: - sampled_frames = video_frames - sampled_times = frame_times - - # import pdb; pdb.set_trace() - frame_times_str = ",".join([f"{t:.2f}s" for t in sampled_times]) - return np.array(sampled_frames), frame_times_str, history_time - - -# Function to stream video, process it, and answer a user question -def stream_camera_and_ask_question(args): - video_frames, frame_times, video_time = load_camera_frames_until_enter(args) - - if video_frames is None: - print("Error capturing video frames.") - return - - question = input("Press the query for current video: ").strip().lower() - - print("question: ", question) - image_base64 = encode_image(video_frames) - # import pdb; pdb.set_trace() - response = request_server(question, image_base64) - - print(f"Model's Answer: {response}") - print(f"Video Duration: 0 to {video_time:.2f} seconds") - print(f"Frame Times: {frame_times}") - - return response - - -# Main loop to keep the system running and waiting for user input -def main_loop(): - question = "Please describe this video." 
- args = Args(frame_limit=64, force_sample=True) - - while True: - answer = stream_camera_and_ask_question(args) - if answer is None: - print("Exiting the loop.") - break - - user_input = input("Press 'Enter' to capture again, or 'q' to quit: ").strip().lower() - if user_input == "q": - print("Quitting the demo.") - break - - # Close all OpenCV windows after the user quits - cv2.destroyAllWindows() - - -if __name__ == "__main__": - main_loop() diff --git a/pyproject.toml b/pyproject.toml index 348e3d9dc..23044a9df 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ train = [ "torchvision==0.16.2", "uvicorn", "wandb", - "deepspeed==0.14.2", + "deepspeed==0.14.4", "peft==0.4.0", "accelerate>=0.29.1", "tokenizers~=0.15.2", From ec87e87841172b9aae48de56148bba47be92732c Mon Sep 17 00:00:00 2001 From: Davidzhangyuanhan <704464079@qq.com> Date: Thu, 3 Oct 2024 14:14:45 +0800 Subject: [PATCH 3/5] update llave-video --- README.md | 19 +++++- docs/LLaVA_Video_1003.md | 122 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 2 deletions(-) create mode 100644 docs/LLaVA_Video_1003.md diff --git a/README.md b/README.md index 8111151a7..6d9bff4b6 100755 --- a/README.md +++ b/README.md @@ -3,21 +3,36 @@

# LLaVA-NeXT: Open Large Multimodal Models +[![Static Badge](https://img.shields.io/badge/llava_video-paper-green)](https://github.com/LLaVA-VL/LLaVA-NeXT) [![Static Badge](https://img.shields.io/badge/llava_onevision-paper-green)](https://arxiv.org/abs/2408.03326) [![llava_next-blog](https://img.shields.io/badge/llava_next-blog-green)](https://llava-vl.github.io/blog/) [![llava_onevision-demo](https://img.shields.io/badge/llava_onevision-demo-red)](https://llava-onevision.lmms-lab.com/) +[![llava_next-video_demo](https://img.shields.io/badge/llava_video-demo-red)](https://huggingface.co/spaces/WildVision/vision-arena) [![llava_next-interleave_demo](https://img.shields.io/badge/llava_next-interleave_demo-red)](https://huggingface.co/spaces/lmms-lab/LLaVA-NeXT-Interleave-Demo) -[![llava_next-video_demo](https://img.shields.io/badge/llava_next-video_demo-red)](https://huggingface.co/spaces/WildVision/vision-arena) [![Openbayes Demo](https://img.shields.io/static/v1?label=Demo&message=OpenBayes%E8%B4%9D%E5%BC%8F%E8%AE%A1%E7%AE%97&color=green)](https://openbayes.com/console/public/tutorials/gW0ng9jKXfO) +[![llava_video-checkpoints](https://img.shields.io/badge/llava_video-checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944) [![llava_onevision-checkpoints](https://img.shields.io/badge/llava_onevision-checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-onevision-66a259c3526e15166d6bba37) [![llava_next-interleave_checkpoints](https://img.shields.io/badge/llava_next-interleave_checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-interleave-66763c55c411b340b35873d1) -[![llava_next-video_checkpoints](https://img.shields.io/badge/llava_next-video_checkpoints-blue)](https://huggingface.co/collections/lmms-lab/llava-next-video-661e86f5e8dabc3ff793c944) [![llava_next-image_checkpoints](https://img.shields.io/badge/llava_next-image_checkpoints-blue)](https://huggingface.co/lmms-lab) ## Release Notes +- **[2024/10/04] 🔥 LLaVA-Video** (formerly LLaVA-NeXT-Video) has undergone a major upgrade! We are excited to release **LLaVA-Video-178K**, a high-quality synthetic dataset for video instruction tuning. This dataset includes: + + - 178,510 caption entries + - 960,792 open-ended Q&A pairs + - 196,198 multiple-choice Q&A items + + Along with this, we’re also releasing the **LLaVA-Video 7B/72B models**, which deliver competitive performance on the latest video benchmarks, including [Video-MME](https://video-mme.github.io/home_page.html#leaderboard), [LongVideoBench](https://longvideobench.github.io/), and [Dream-1K](https://tarsier-vlm.github.io/). + + 📄 **Explore more**: + - [LLaVA-Video-178K Dataset](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K): Download the dataset. + - [LLaVA-Video Models](https://huggingface.co/collections/lmms-lab/llava-video-661e86f5e8dabc3ff793c944): Access model checkpoints. + - [Paper](https://github.com/LLaVA-VL/LLaVA-NeXT): Detailed information about LLaVA-Video. + - [LLaVA-Video Documentation](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_Video_1003.md): Guidance on training, inference and evaluation. + - [2024/09/13] 🔥 **🚀 [LLaVA-OneVision-Chat](docs/LLaVA_OneVision_Chat.md)**. The new LLaVA-OV-Chat (7B/72B) significantly improves the chat experience of LLaVA-OV. 
📄 ![](docs/ov_chat_images/chat_results.png)
diff --git a/docs/LLaVA_Video_1003.md b/docs/LLaVA_Video_1003.md
new file mode 100644
index 000000000..754f9fd79
--- /dev/null
+++ b/docs/LLaVA_Video_1003.md
@@ -0,0 +1,122 @@
+# LLaVA Video
+
+## Table of Contents
+
+1. [Model Summary](#model-summary)
+2. [Inference](#inference)
+3. [Training](#training)
+4. [Evaluation](#evaluation-guidance)
+5. [Citation](#citation)
+
+## Model Summary
+
+The LLaVA-Video models are 7B/72B-parameter models trained on [LLaVA-Video-178K](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K) and the [LLaVA-OneVision Dataset](https://huggingface.co/datasets/lmms-lab/LLaVA-OneVision-Data), based on the Qwen2 language model with a context window of 32K tokens.
+
+
+## Inference
+
+We provide a simple generation example below. For more details, please refer to [GitHub](https://github.com/LLaVA-VL/LLaVA-NeXT).
+
+```python
+# pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
+from llava.conversation import conv_templates, SeparatorStyle
+from PIL import Image
+import requests
+import copy
+import torch
+import sys
+import warnings
+from decord import VideoReader, cpu
+import numpy as np
+warnings.filterwarnings("ignore")
+def load_video(video_path, max_frames_num, fps=1, force_sample=False):
+    if max_frames_num == 0:
+        return np.zeros((1, 336, 336, 3))
+    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+    total_frame_num = len(vr)
+    video_time = total_frame_num / vr.get_avg_fps()
+    fps = round(vr.get_avg_fps() / fps)
+    frame_idx = [i for i in range(0, len(vr), fps)]
+    frame_time = [i / fps for i in frame_idx]
+    if len(frame_idx) > max_frames_num or force_sample:
+        sample_fps = max_frames_num
+        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
+        frame_idx = uniform_sampled_frames.tolist()
+        frame_time = [i / vr.get_avg_fps() for i in frame_idx]
+    frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
+    spare_frames = vr.get_batch(frame_idx).asnumpy()
+    return spare_frames, frame_time, video_time
+pretrained = "lmms-lab/LLaVA-Video-7B-Qwen2"
+model_name = "llava_qwen"
+device = "cuda"
+device_map = "auto"
+tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, torch_dtype="bfloat16", device_map=device_map)  # Add any other llava_model_args you need here
+model.eval()
+video_path = "XXXX"
+max_frames_num = 64
+video, frame_time, video_time = load_video(video_path, max_frames_num, 1, force_sample=True)
+video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].cuda().bfloat16()
+video = [video]
+conv_template = "qwen_1_5"  # Make sure you use the correct chat template for different models
+time_instruction = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
+question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\nPlease describe this video in detail."
+conv = copy.deepcopy(conv_templates[conv_template])
+conv.append_message(conv.roles[0], question)
+conv.append_message(conv.roles[1], None)
+prompt_question = conv.get_prompt()
+input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
+cont = model.generate(
+    input_ids,
+    images=video,
+    modalities=["video"],
+    do_sample=False,
+    temperature=0,
+    max_new_tokens=4096,
+)
+text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip()
+print(text_outputs)
+```
+
+
+## Training
+
+[[Scripts]](../scripts/video/train): Start training models on your single-image/multi-image/video data.
+
+
+## Evaluation Guidance
+
+We use the [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) toolkit to evaluate our models. Ensure you have installed the LLaVA-NeXT model files as per the instructions in the main README.md.
+
+Install lmms-eval:
+
+> pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
+
+### Reproducing Evaluation Results
+
+Our models' evaluation results can be fully reproduced using the lmms-eval toolkit. After installing lmms-eval and llava, you can run the evaluation using the following commands.
+
+Note: These commands require flash-attn. If you prefer not to install it, disable flash-attn by adding `attn_implementation=None` to the `--model_args` parameter.
+
+Important: Different torch versions may cause slight variations in results. By default, `lmms-eval` requires the latest torch version, while the `llava` repo pins torch to `2.1.2`. Torch `2.1.2` is stable for both `llava` and `lmms-eval`.
+
+### Evaluating LLaVA-Video on multiple datasets
+
+We recommend that developers and researchers evaluate the models on as many datasets as possible to get a comprehensive picture of their performance in different scenarios, so we provide a comprehensive list of evaluation datasets and welcome contributions that incorporate more evaluation tasks. Please refer to [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) for more details.
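+
+If you prefer to drive the evaluation from Python (for example, to sweep over several frame budgets), a thin `subprocess` wrapper around the same CLI shown in the next block is sufficient. The sketch below is only a convenience illustration: the helper name is ours, and the task list and frame count are placeholders you should adjust to your setup.
+
+```python
+import subprocess
+
+def run_lmms_eval(pretrained, tasks, max_frames_num=64, num_processes=8):
+    """Launch lmms-eval for a LLaVA-Video checkpoint through the accelerate CLI."""
+    model_args = (
+        f"pretrained={pretrained},conv_template=qwen_1_5,"
+        f"max_frames_num={max_frames_num},mm_spatial_pool_mode=average"
+    )
+    cmd = [
+        "accelerate", "launch", f"--num_processes={num_processes}",
+        "-m", "lmms_eval",
+        "--model", "llava_vid",
+        "--model_args", model_args,
+        "--tasks", tasks,
+        "--batch_size", "1",
+        "--log_samples",
+        "--log_samples_suffix", "llava_vid",
+        "--output_path", "./logs/",
+    ]
+    subprocess.run(cmd, check=True)  # Raises CalledProcessError if the evaluation fails
+
+# Example: evaluate the 7B checkpoint on two of the video tasks listed below.
+run_lmms_eval("lmms-lab/LLaVA-Video-7B-Qwen2", "activitynetqa,videochatgpt")
+```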
+ +```bash +# video tasks +accelerate launch --num_processes=8 \ +-m lmms_eval \ +--model llava_vid \ +--model_args pretrained=lmms-lab/LLaVA-Video-7B-Qwen2,conv_template=qwen_1_5,max_frames_num=64,mm_spatial_pool_mode=average \ +--tasks activitynetqa,videochatgpt,nextqa_mc_test,egoschema,video_dc499,videmme,videomme_w_subtitle,perceptiontest_val_mc \ +--batch_size 1 \ +--log_samples \ +--log_samples_suffix llava_vid \ +--output_path ./logs/ +``` + From 853ac293eaf9f0f32ee8a904ac2a4e6d4f9d5c14 Mon Sep 17 00:00:00 2001 From: Davidzhangyuanhan <704464079@qq.com> Date: Thu, 3 Oct 2024 14:15:17 +0800 Subject: [PATCH 4/5] update llava-video --- .../train/SO400M_Qwen2_72B_ov_to_video_am9.sh | 86 ++++++ .../train/SO400M_Qwen2_7B_ov_to_video_am9.sh | 87 ++++++ scripts/video/train/exp.yaml | 263 ++++++++++++++++++ 3 files changed, 436 insertions(+) create mode 100644 scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh create mode 100644 scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh create mode 100644 scripts/video/train/exp.yaml diff --git a/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh b/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh new file mode 100644 index 000000000..9cff5f5ec --- /dev/null +++ b/scripts/video/train/SO400M_Qwen2_72B_ov_to_video_am9.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +# Set up the data folder +IMAGE_FOLDER="XXX" +VIDEO_FOLDER="XXX" +DATA_YAML="XXX" # e.g exp.yaml + +############### Prepare Envs ################# +python3 -m pip install flash-attn --no-build-isolation +alias python=python3 +############### Show Envs #################### + +nvidia-smi + +################ Arnold Jobs ################ + +LLM_VERSION="Qwen/Qwen2-72B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +PROMPT_VERSION=plain +PRETRAIN_DATA_VERSION="blip558k" +############### Pretrain ################ + +BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-72B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +# Stage 2 +PROMPT_VERSION="qwen_1_5" +MID_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video_am9" +PREV_STAGE_CHECKPOINT="lmms-lab/llava-onevision-qwen2-72b-ov" +echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" +echo "MID_RUN_NAME: ${MID_RUN_NAME}" + + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path $PREV_STAGE_CHECKPOINT \ + --version $PROMPT_VERSION \ + --data_path $DATA_YAML \ + --image_folder $IMAGE_FOLDER \ + --video_folder $VIDEO_FOLDER \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres_max_9 \ + --image_grid_pinpoints "(1x1),...,(6x6)" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $MID_RUN_NAME \ + --output_dir ./work_dirs/$MID_RUN_NAME \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" 
\ + --save_strategy "steps" \ + --save_steps 500 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. \ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 2 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --frames_upbound 32 \ + --mm_newline_position grid \ + --add_time_instruction True \ + --force_sample True \ + --mm_spatial_pool_stride 2 +exit 0; \ No newline at end of file diff --git a/scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh b/scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh new file mode 100644 index 000000000..80d82d008 --- /dev/null +++ b/scripts/video/train/SO400M_Qwen2_7B_ov_to_video_am9.sh @@ -0,0 +1,87 @@ +#!/bin/bash + +# Set up the data folder +IMAGE_FOLDER="XXX" +VIDEO_FOLDER="XXX" +DATA_YAML="XXX" # e.g exp.yaml + +############### Prepare Envs ################# +python3 -m pip install flash-attn --no-build-isolation +alias python=python3 +############### Show Envs #################### + +nvidia-smi + +################ Arnold Jobs ################ + +LLM_VERSION="Qwen/Qwen2-7B-Instruct" +LLM_VERSION_CLEAN="${LLM_VERSION//\//_}" +VISION_MODEL_VERSION="google/siglip-so400m-patch14-384" +VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}" + +PROMPT_VERSION=plain +PRETRAIN_DATA_VERSION="blip558k" +############### Pretrain ################ + +BASE_RUN_NAME="llavanext-google_siglip-so400m-patch14-384-Qwen_Qwen2-7B-Instruct-mlp2x_gelu-pretrain_blip558k_plain" +echo "BASE_RUN_NAME: ${BASE_RUN_NAME}" + +# Stage 2 +PROMPT_VERSION="qwen_1_5" +MID_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-ov_to_video_am9" +PREV_STAGE_CHECKPOINT="lmms-lab/llava-onevision-qwen2-7b-ov" +echo "PREV_STAGE_CHECKPOINT: ${PREV_STAGE_CHECKPOINT}" +echo "MID_RUN_NAME: ${MID_RUN_NAME}" + + +ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" --nnodes="${ARNOLD_WORKER_NUM}" --node_rank="${ARNOLD_ID}" --master_addr="${METIS_WORKER_0_HOST}" --master_port="${port_in_cmd}" \ + llava/train/train_mem.py \ + --deepspeed scripts/zero3.json \ + --model_name_or_path $PREV_STAGE_CHECKPOINT \ + --version $PROMPT_VERSION \ + --data_path $DATA_YAML \ + --image_folder $IMAGE_FOLDER \ + --video_folder $VIDEO_FOLDER \ + --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \ + --mm_vision_tower_lr=2e-6 \ + --vision_tower ${VISION_MODEL_VERSION} \ + --mm_projector_type mlp2x_gelu \ + --mm_vision_select_layer -2 \ + --mm_use_im_start_end False \ + --mm_use_im_patch_token False \ + --group_by_modality_length True \ + --image_aspect_ratio anyres_max_9 \ + --image_grid_pinpoints "(1x1),...,(6x6)" \ + --mm_patch_merge_type spatial_unpad \ + --bf16 True \ + --run_name $MID_RUN_NAME \ + --output_dir ./work_dirs/$MID_RUN_NAME \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 4 \ + --gradient_accumulation_steps 2 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 500 \ + --save_total_limit 1 \ + --learning_rate 1e-5 \ + --weight_decay 0. 
\ + --warmup_ratio 0.03 \ + --lr_scheduler_type "cosine" \ + --logging_steps 1 \ + --tf32 True \ + --model_max_length 32768 \ + --gradient_checkpointing True \ + --dataloader_num_workers 2 \ + --lazy_preprocess True \ + --report_to wandb \ + --torch_compile True \ + --torch_compile_backend "inductor" \ + --dataloader_drop_last True \ + --frames_upbound 110 \ + --mm_newline_position grid \ + --add_time_instruction True \ + --force_sample True \ + --mm_spatial_pool_stride 2 + +exit 0; \ No newline at end of file diff --git a/scripts/video/train/exp.yaml b/scripts/video/train/exp.yaml new file mode 100644 index 000000000..f093733d9 --- /dev/null +++ b/scripts/video/train/exp.yaml @@ -0,0 +1,263 @@ +datasets: + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_next_fit_mix_filtered_text_wild_738590.json + sampling_strategy: "first:50%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_wild_4v_39k.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_wild_4v_12k.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_metagen_87358.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mavis_math_rule_geo_100000.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/cambrian_filtered_gpt4vo_sp_token_fltd_max10k_checked.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/VisualWebInstruct_filtered_263589.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/visual_chat_en_26048_gpt4o_coco_checked.json + sampling_strategy: "all" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/gpt4o_combinations_51316.json + # sampling_strategy: "all" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/chrome_writting_train_8835.json + # sampling_strategy: "first:20%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/k12_printing_train_256646.json + # sampling_strategy: "first:1%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/iiit5k_annotations_2000.json + # sampling_strategy: "first:20%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/hme100k_train_clean_74502.json + # sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sroie_data_33626.json + # sampling_strategy: "first:1%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/orand_car_a_train_2009.json + # sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/orand_car_b_train_3000.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llavar_gpt4_20k.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_azuregpt_detailed_understanding_4874.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_vqa_4404.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/infographic_azuregpt4v_1992.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_chart_1787.json 
+ sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/lrv_normal_gpt4v_filtered_10500.json + sampling_strategy: "first:10%" + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/scienceqa_nona_context_19218.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_vflan4v_20000.json + sampling_strategy: "first:30%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/allava_instruct_laion4v_50000.json + sampling_strategy: "first:30%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textocr_gpt4v_train_converted_25114.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/ai2d_train_internvl_single_12413.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/textcaps_train_21952.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_qa_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_cap_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_ie_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/ureader_new/ureader_kg_sft.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/vision_flan_filtered_186070.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/mathqa_29837.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo3k_2101.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_qa_converted_67833.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/geo170k_align_converted_60252.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4o_dataset.jsonl + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-coco-50k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-knowledge-2k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-llava-30k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4v-sam-20k.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_CLEVR-Math_5290.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_FigureQA_17597.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Geometry3K_9734.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GeoQA+_17172.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_GEOS_508.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_IconQA_22599.json + sampling_strategy: "first:5%" + 
- json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_MapQA_5235.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_PlotQA_5485.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_PMC-VQA_35958.json + sampling_strategy: "first:1%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_Super-CLEVR_8652.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TabMWP_22462.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_TQA_10181.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_UniGeo_11959.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VizWiz_6614.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VQA-AS_5907.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/MathV360K_VQA-RAD_2130.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_qwen2_72b_st_300000_sp_token_fltd_299992.json + sampling_strategy: "end:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_st_300000.json + sampling_strategy: "end:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/magpie_pro_l3_80b_mt_300000_sp_token_fltd_299998.json + sampling_strategy: "end:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/image_textualization_dataset_filtered.json + sampling_strategy: "first:20%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/ai2d_llava_format_2434.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/chart2text_26961.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/chartqa_18265_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/diagram_image_to_text_300.json + sampling_strategy: "all" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/hateful_memes_8500_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/hitab_2500_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/iam_5663.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/infographic_vqa_2118_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/intergps_1280_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/mapqa_37417_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/rendered_text_10000.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_sqa_8514.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wikisql_74989.json + sampling_strategy: "first:10%" + # - json_path: 
/mnt/bn/vl-research/data/llava_instruct/cauldron/robut_wtq_38246_llava_format_filtered_4000tokens_38236.json + # sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/screen2words_15730.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tabmwp_22722.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tallyqa_98680_llava_format.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/st_vqa_17247_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/tqa_llava_format_27307.json + sampling_strategy: "first:5%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/visual7w_llava_format_14366.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/visualmrc_3027.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vqarad_313_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vsr_2157_llava_format.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/vistext_9969.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/cauldron/websight_10000.json + sampling_strategy: "first:10%" + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_DEMON-FULL_filtered_311085.json + sampling_strategy: all + - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/llava_ofa_mantis-instruct_reformatted.json + sampling_strategy: all + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/academic_source_30s_v1_all.json + # sampling_strategy: all + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/0718_0_30_s_academic_mc_v0_1_all.json + # sampling_strategy: all + # - json_path: /mnt/bn/vl-research/data/llava_instruct/real_vision_flan/sharegpt4video_255000.json + # sampling_strategy: all + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_academic_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_youtube_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_academic_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_youtube_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_academic_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_youtube_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_academic_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_academic_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_youtube_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_youtube_mc_v0_1_qa.json + sampling_strategy: "all" + - 
json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_activitynetqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_nextqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_nextqa_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/0_30_s_perceptiontest_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_academic_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_academic_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_youtube_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_youtube_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_activitynetqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_nextqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_nextqa_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/30_60_s_perceptiontest_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_academic_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_academic_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_youtube_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_youtube_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_activitynetqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_nextqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/1_2_m_nextqa_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/sharegptvideo_qa_255k.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_academic_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_youtube_v0_1_cap.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_academic_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_academic_mc_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_youtube_oe_v0_1_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_youtube_mc_v0_1_qa.json + 
sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_nextqa_oe_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_nextqa_mc_qa.json + sampling_strategy: "all" + - json_path: /mnt/bn/tiktok-mm-3/aiic/users/wujinming/_training_data/jsons/tos/2_3_m_activitynetqa_oe_qa.json + sampling_strategy: "all" From c4d9ca1413a1679377d6c53836acdddcd67502f0 Mon Sep 17 00:00:00 2001 From: Davidzhangyuanhan <704464079@qq.com> Date: Fri, 4 Oct 2024 10:08:04 +0800 Subject: [PATCH 5/5] update --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6d9bff4b6..3a7432db5 100755 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@

 # LLaVA-NeXT: Open Large Multimodal Models
-[![Static Badge](https://img.shields.io/badge/llava_video-paper-green)](https://github.com/LLaVA-VL/LLaVA-NeXT)
+[![Static Badge](https://img.shields.io/badge/llava_video-paper-green)](https://arxiv.org/abs/2410.02713)
 [![Static Badge](https://img.shields.io/badge/llava_onevision-paper-green)](https://arxiv.org/abs/2408.03326)
 [![llava_next-blog](https://img.shields.io/badge/llava_next-blog-green)](https://llava-vl.github.io/blog/)
@@ -30,7 +30,7 @@
   📄 **Explore more**:
   - [LLaVA-Video-178K Dataset](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K): Download the dataset.
   - [LLaVA-Video Models](https://huggingface.co/collections/lmms-lab/llava-video-661e86f5e8dabc3ff793c944): Access model checkpoints.
-  - [Paper](https://github.com/LLaVA-VL/LLaVA-NeXT): Detailed information about LLaVA-Video.
+  - [Paper](https://arxiv.org/abs/2410.02713): Detailed information about LLaVA-Video.
   - [LLaVA-Video Documentation](https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/docs/LLaVA_Video_1003.md): Guidance on training, inference and evaluation.

- [2024/09/13] 🔥 **🚀 [LLaVA-OneVision-Chat](docs/LLaVA_OneVision_Chat.md)**. The new LLaVA-OV-Chat (7B/72B) significantly improves the chat experience of LLaVA-OV. 📄
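
For readers adapting the data mixture in `scripts/video/train/exp.yaml` (added in PATCH 4/5), the `sampling_strategy` field takes values such as `all`, `first:50%`, `first:10%`, and `end:20%`. The sketch below shows one plausible way such strategies can be interpreted once each JSON annotation file is loaded; it is an illustration under that assumption, not the repository's actual dataloader, and the helper names are ours.

```python
import json
import yaml  # pip install pyyaml


def apply_sampling_strategy(samples, strategy):
    """Select a subset of samples according to a strategy such as 'all', 'first:50%', or 'end:20%'."""
    if strategy == "all":
        return samples
    mode, _, amount = strategy.partition(":")
    # Percentages ("50%") are relative to the dataset size; plain integers are absolute counts.
    count = int(len(samples) * float(amount[:-1]) / 100) if amount.endswith("%") else int(amount)
    if mode == "first":
        return samples[:count]
    if mode == "end":
        return samples[len(samples) - count:] if count else []
    raise ValueError(f"Unknown sampling strategy: {strategy}")


def load_mixture(yaml_path):
    """Load every dataset listed in an exp.yaml-style file and apply its sampling strategy."""
    with open(yaml_path) as f:
        config = yaml.safe_load(f)
    mixture = []
    for entry in config["datasets"]:
        # Assumes each json_path holds a plain JSON list; .jsonl files would need line-by-line parsing.
        with open(entry["json_path"]) as f:
            samples = json.load(f)
        mixture.extend(apply_sampling_strategy(samples, entry.get("sampling_strategy", "all")))
    return mixture
```

In the training scripts, the YAML path is passed via `--data_path $DATA_YAML`; the per-dataset strategies let a single mixture file down-weight large sources (e.g. `first:1%` for `MathV360K_PMC-VQA_35958.json`) without editing the JSON files themselves.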