Merge pull request NVlabs#43 from Efficient-Large-Model/dev/modulized_init

Refactor model initialization.
Efficient-Large-Language-Model authored Apr 4, 2024
2 parents a111a2b + bbe444e commit eabc869
Showing 54 changed files with 10,761 additions and 4,743 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -74,13 +74,16 @@ conda create -n vila python=3.10 -y
conda activate vila
pip install --upgrade pip # enable PEP 660 support
# this step is optional if you prefer to use the system's built-in nvcc
conda install -c nvidia cuda-toolkit -y
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.4.2/flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install -e .
pip install -e ".[train]"
pip install git+https://github.com/huggingface/transformers@v4.38.1
cp -r ./llava/train/transformers_replace/* ~/anaconda3/envs/vila/lib/python3.10/site-packages/transformers/
pip install git+https://github.com/huggingface/transformers@v4.36.2
site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
cp -rv ./llava/train/transformers_replace/* $site_pkg_path/transformers/
```

## Training
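A quick sanity check after running the updated steps (not part of the diff; it uses only what the commands above install) confirms the pinned transformers build and the destination of the patched files:

```bash
python -c "import transformers; print(transformers.__version__)"  # expect 4.36.2
python -c 'import site; print(site.getsitepackages()[0])'         # where transformers_replace was copied
```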
2 changes: 1 addition & 1 deletion environment_setup.sh
@@ -10,6 +10,6 @@ pip install flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64
pip install -e .
pip install -e ".[train]"

pip install git+https://github.com/huggingface/transformers@v4.38.1
pip install git+https://github.com/huggingface/transformers@v4.36.2
site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
cp -rv ./llava/train/transformers_replace/* $site_pkg_path/transformers/
7 changes: 2 additions & 5 deletions inference_test/inference_test.py
@@ -103,11 +103,8 @@ def eval_model(args, model, tokenizer, image_processor):
# use_cache=True,
stopping_criteria=[stopping_criteria],
)
input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()

print(f"Question: {query_text}")
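The same simplification recurs in the eval scripts below: the prompt-length slicing and the input/output mismatch warning are dropped everywhere. A minimal before/after sketch, assuming the refactored generate() now returns only the newly generated tokens (which would make the slice and the comparison dead code):

```python
# Before: output_ids began with an echo of the prompt, so the prompt was
# sliced off, and any mismatch with the original input_ids raised a warning.
input_token_len = input_ids.shape[1]
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]

# After: output_ids is assumed to hold only the generated continuation.
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
```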
2 changes: 1 addition & 1 deletion llava/__init__.py
@@ -1 +1 @@
from .model import LlavaLlamaForCausalLM

3 changes: 2 additions & 1 deletion llava/data/dataset.py
@@ -126,9 +126,10 @@ def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> D
return sources

for source in sources:
concat_values = "".join([sentence["value"] for sentence in source])
for sid, sentence in enumerate(source):
# In multimodal conversations, we automatically prepend '<image>' to the first sentence if no sentence in the conversation contains one yet.
if sid == 0 and DEFAULT_IMAGE_TOKEN not in sentence["value"]:
if sid == 0 and DEFAULT_IMAGE_TOKEN not in concat_values:
sentence["value"] = f"{DEFAULT_IMAGE_TOKEN}\n" + sentence["value"]
if DEFAULT_IMAGE_TOKEN in sentence["value"]:
sentence_chunks = [chunk.strip() for chunk in sentence["value"].split(DEFAULT_IMAGE_TOKEN)]
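The membership test now covers the concatenated conversation instead of only the first sentence; on toy data (assuming DEFAULT_IMAGE_TOKEN is "<image>", per LLaVA convention), the difference looks like this:

```python
DEFAULT_IMAGE_TOKEN = "<image>"  # assumed value, matching LLaVA's constants

source = [
    {"value": "Describe the scene."},
    {"value": "<image>\nWhat changed here?"},
]
concat_values = "".join(sentence["value"] for sentence in source)

# Old check: the token is missing from the *first* sentence, so it would be
# prepended, leaving two image tokens in the conversation.
assert DEFAULT_IMAGE_TOKEN not in source[0]["value"]
# New check: the token already appears somewhere in the conversation, so
# nothing is prepended.
assert DEFAULT_IMAGE_TOKEN in concat_values
```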
17 changes: 17 additions & 0 deletions llava/data/datasets_mixture.py
@@ -193,6 +193,15 @@ def register_datasets_mixtures():
image_path="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/LLaVA-CC3M-Pretrain-595K/images",
)
add_dataset(llava_1_5_mm_align)

llava_1_5_pretrain = Dataset(
dataset_name="llava_1_5_pretrain",
dataset_type="torch",
data_path="/home/yunhaof/workspace/datasets/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json",
image_path="/home/yunhaof/workspace/datasets/LLaVA-Pretrain/images",
)
add_dataset(llava_1_5_pretrain)

llava_1_5_sft = Dataset(
dataset_name="llava_1_5_sft",
dataset_type="torch",
@@ -256,6 +265,14 @@ def register_datasets_mixtures():
image_path="/home/yunhaof/workspace/datasets/DVQA/images",
)
add_dataset(dvqa)

dvqa_subset = Dataset(
dataset_name="dvqa_subset",
dataset_type="torch",
data_path="/home/yunhaof/workspace/datasets/DVQA/processed/DVQA_train_qa_subset100K.json",
image_path="/home/yunhaof/workspace/datasets/DVQA/images",
)
add_dataset(dvqa_subset)

ai2d = Dataset(
dataset_name="ai2d",
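Both new entries follow the existing registration pattern; as a standalone sketch (field names taken from the diff, paths purely illustrative):

```python
my_dataset = Dataset(
    dataset_name="my_dataset",              # key the mixture is referenced by
    dataset_type="torch",
    data_path="/path/to/annotations.json",  # hypothetical annotation file
    image_path="/path/to/images",           # hypothetical image root
)
add_dataset(my_dataset)
```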
6 changes: 1 addition & 5 deletions llava/data_aug/video_inference.py
@@ -123,11 +123,7 @@ def get_model_output(model, image_processor, tokenizer, video_path, qs, conv_mod
stopping_criteria=[stopping_criteria]
)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
12 changes: 1 addition & 11 deletions llava/eval/eval_mathvista.py
@@ -137,17 +137,7 @@ def eval_model(args):
stopping_criteria=stopping_criteria,
)

input_token_len = input_ids.shape[1]
n_diff_input_output = (
(input_ids != output_ids[:, :input_token_len]).sum().item()
)
if n_diff_input_output > 0:
print(
f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
)
outputs = tokenizer.batch_decode(
output_ids[:, input_token_len:], skip_special_tokens=True
)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
res = extract_answer(outputs, d)
d["extraction"] = res
4 changes: 3 additions & 1 deletion llava/eval/evaluate_vqa.py
@@ -126,6 +126,7 @@ def evaluate_exact_match_accuracy(entries):
parser.add_argument("--image-folder", type=str, default="")
parser.add_argument('--dataset', type=str, default='')
parser.add_argument("--conv-mode", type=str, default="llava_v1")
parser.add_argument("--answer-dir", type=str, default="")
args = parser.parse_args()

disable_torch_init()
@@ -197,7 +198,8 @@ def evaluate_exact_match_accuracy(entries):

print(f"Evaluating {args.dataset} ...")
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
results_file = f'{args.dataset}_{time_prefix}.json'
results_file = os.path.join(args.answer_dir, f'{args.dataset}_{time_prefix}.json')
os.makedirs(os.path.dirname(results_file), exist_ok=True)
json.dump(outputs, open(results_file, 'w'), ensure_ascii=False)

if ds_collections[args.dataset]['metric'] == 'relaxed_accuracy':
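With the new --answer-dir flag, the timestamped results file lands in a chosen directory rather than the working directory; a hypothetical invocation (flag names from the diff, values illustrative, any other required flags omitted):

```bash
python llava/eval/evaluate_vqa.py \
    --dataset my_vqa_dataset \
    --answer-dir ./results
```

Note that with the default empty --answer-dir, os.path.dirname(results_file) evaluates to an empty string and os.makedirs("") raises, so passing a directory appears to be required.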
13 changes: 5 additions & 8 deletions llava/eval/mmmu_utils/model_utils.py
@@ -70,21 +70,18 @@ def deal_with_prompt(input_text, mm_use_im_start_end):
stopping_criteria=stopping_criteria,
)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
return outputs
else: # multiple images actually
raise ValueError("INVALID GENERATION FOR MULTIPLE IMAGE INPUTS")
# default behavior (randomly sample an answer) from MMMU's official implementation
if sample["question_type"] == "multiple-choice":
all_choices = sample["all_choices"]
response = random.choice(all_choices)
outputs = random.choice(all_choices)
else:
response = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"
outputs = "INVALID GENERATION FOR MULTIPLE IMAGE INPUTS"

return response
return outputs


def llava_image_processor(raw_image, vis_processors=None):
1 change: 1 addition & 0 deletions llava/eval/model_qa.py
@@ -60,6 +60,7 @@ def eval_model(model_name, questions_file, answers_file):
temperature=0.7,
max_new_tokens=1024,
stopping_criteria=[stopping_criteria])

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
try:
index = outputs.index(conv.sep, len(prompt))
6 changes: 1 addition & 5 deletions llava/eval/model_vqa.py
@@ -76,11 +76,7 @@ def eval_model(args):
max_new_tokens=1024,
use_cache=True)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
6 changes: 1 addition & 5 deletions llava/eval/model_vqa_loader.py
@@ -113,11 +113,7 @@ def eval_model(args):
use_cache=True,
stopping_criteria=stopping_criteria,)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()

ans_id = shortuuid.uuid()
6 changes: 1 addition & 5 deletions llava/eval/model_vqa_mmbench.py
@@ -127,11 +127,7 @@ def eval_model(args):
use_cache=True,
stopping_criteria=stopping_criteria,)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
6 changes: 1 addition & 5 deletions llava/eval/model_vqa_qbench.py
@@ -97,11 +97,7 @@ def eval_model(args):
use_cache=True,
stopping_criteria=[stopping_criteria])

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
12 changes: 2 additions & 10 deletions llava/eval/model_vqa_science.py
@@ -84,11 +84,7 @@ def eval_model(args):
stopping_criteria=stopping_criteria,
)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
@@ -110,11 +106,7 @@ def eval_model(args):
use_cache=True,
stopping_criteria=[stopping_criteria])

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
6 changes: 1 addition & 5 deletions llava/eval/model_vqa_video.py
@@ -114,11 +114,7 @@ def get_model_output(model, image_processor, tokenizer, video_path, qs, args):
stopping_criteria=[stopping_criteria]
)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[:-len(stop_str)]
6 changes: 1 addition & 5 deletions llava/eval/run_llava.py
@@ -113,11 +113,7 @@ def eval_model(args):
stopping_criteria=[stopping_criteria],
)

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
print(f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids")
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[: -len(stop_str)]
2 changes: 1 addition & 1 deletion llava/model/__init__.py
@@ -1,4 +1,4 @@
from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
from .language_model.llava_llama import LlavaLlamaModel, LlavaLlamaConfig
# from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
from .language_model.llava_mixtral import LlavaMixtralForCausalLM, LlavaMixtralConfig
from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
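Given the renamed exports, downstream imports would change accordingly; a sketch, assuming the package __init__ re-exports these names as shown above:

```python
# Before the refactor:
# from llava.model import LlavaLlamaForCausalLM, LlavaConfig

# After the refactor:
from llava.model import LlavaLlamaModel, LlavaLlamaConfig
```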
(Diffs for the remaining changed files are omitted here.)
