HawkLlama

🤗Huggingface Model | 🗂️Github | 📖Technical Report | 🎮️Demo

Zhejiang University, China

This is the official implementation of HawkLlama, an open-source multimodal large language model designed for real-world vision and language understanding applications. Our model features the following highlights.

  1. HawkLlama-8B is constructed utilizing:

    • Llama3-8B, the latest open-source large language model, trained on over 15 trillion tokens.
    • SigLIP, an enhancement over CLIP employing sigmoid loss, which achieves superior performance in image recognition.
    • An efficient vision-language connector, designed to capture high-resolution details without increasing the number of visual tokens, which reduces the training overhead associated with high-resolution images.
  2. For model training, we use the LLaVA-Pretrain dataset for pretraining and a curated mixed dataset for instruction tuning, which contains both multimodal and language-only data for supervised fine-tuning.

  3. HawkLlama-8B is built on the NeMo framework, which supports 3D parallelism and leaves room to scale in future extensions.

Our model is open-source and reproducible. Please check our technical report for more details.

Setup

  1. Create the environment and activate it.
conda create -n hawkllama python=3.10 -y
conda activate hawkllama
  2. Clone and install this repo.
git clone https://github.com/aim-uofa/VLModel.git
cd VLModel
pip install -e .
pip install -e third_party/VLMEvalKit
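
After the two setup steps, a quick sanity check can confirm that the environment is usable. The sketch below is only an illustration: the module names are assumptions inferred from this README (`HawkLlama` from the inference example, `vlmeval` as the import name of VLMEvalKit), and the minimum Python version mirrors the `python=3.10` used when creating the environment.

```python
import importlib.util
import sys

# Assumption: these import names are inferred from this README, not confirmed
# against the repository ("HawkLlama" from the inference example, "vlmeval"
# as the import name of VLMEvalKit).
REQUIRED_MODULES = ("torch", "PIL", "HawkLlama", "vlmeval")


def check_environment(modules=REQUIRED_MODULES, min_python=(3, 10)):
    """Return a dict mapping each check to True (passed) or False (failed)."""
    version_key = "python>=%d.%d" % min_python
    report = {version_key: sys.version_info[:2] >= min_python}
    for name in modules:
        # find_spec returns None when a top-level module is not installed.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    for item, ok in check_environment().items():
        print(f"{'OK     ' if ok else 'MISSING'}  {item}")
```

Any entry reported as `MISSING` suggests that the corresponding `pip install -e` step did not complete in the active environment.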

Model Weights

Please refer to our HuggingFace repository to download the pretrained model weights.
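
If you prefer to fetch the weights ahead of time, a minimal sketch using `snapshot_download` from the `huggingface_hub` package is shown below. It assumes the weights live in the `AIM-ZJU/HawkLlama_8b` repository (the id used in the inference example) and that `huggingface_hub` is installed; the helper name `download_weights` is hypothetical.

```python
from pathlib import Path

# Repository id taken from the inference example in this README.
REPO_ID = "AIM-ZJU/HawkLlama_8b"


def download_weights(repo_id=REPO_ID, local_dir=None):
    """Download every file of a Hub repository and return the local path.

    `download_weights` is a hypothetical helper, not part of this repo.
    """
    # Imported lazily so this module loads even without huggingface_hub.
    from huggingface_hub import snapshot_download

    kwargs = {"repo_id": repo_id}
    if local_dir is not None:
        # Store the files in an explicit directory instead of the HF cache.
        kwargs["local_dir"] = str(Path(local_dir).expanduser())
    return snapshot_download(**kwargs)


if __name__ == "__main__":
    print(download_weights())
```

Without `local_dir`, the files go to the default Hugging Face cache directory.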

Inference

We provide example code for inference.

import torch
from PIL import Image
from HawkLlama.model import LlavaNextProcessor, LlavaNextForConditionalGeneration
from HawkLlama.utils.conversation import conv_llava_llama_3, DEFAULT_IMAGE_TOKEN

# Load the processor and the model weights (bfloat16), then move the model to the GPU.
processor = LlavaNextProcessor.from_pretrained("AIM-ZJU/HawkLlama_8b")

model = LlavaNextForConditionalGeneration.from_pretrained("AIM-ZJU/HawkLlama_8b", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
model.to("cuda:0")

# Load the input image as RGB.
image_file = "assets/coin.png"
image = Image.open(image_file).convert('RGB')

# Prepend the image placeholder token to the question.
prompt = "what coin is that?"
prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt

# Build the chat prompt: one user turn, then an empty assistant turn for the model to complete.
conversation = conv_llava_llama_3.copy()
user_role_ind = 0
bot_role_ind = 1
conversation.append_message(conversation.roles[user_role_ind], prompt)
conversation.append_message(conversation.roles[bot_role_ind], "")
prompt = conversation.get_prompt()

# Tokenize the prompt, preprocess the image, and match the model's device and dtype.
inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)

# Generate with greedy decoding (do_sample=False).
output = model.generate(**inputs, eos_token_id=processor.tokenizer.eos_token_id, max_new_tokens=2048, do_sample=False, use_cache=True)

print(processor.decode(output[0], skip_special_tokens=True))
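
The decoded string above may include the prompt as well as the answer. The sketch below shows one way to keep only the newly generated tokens. It assumes that `model.generate` returns the prompt token ids followed by the new ones, which is the usual behaviour of decoder-only models but is not confirmed here for this model; `strip_prompt` is a hypothetical helper.

```python
def strip_prompt(output_ids, prompt_length):
    """Return only the tokens generated after the prompt.

    Works on any sliceable sequence: a list here, a 1-D tensor in practice.
    `strip_prompt` is a hypothetical helper, not part of this repo.
    """
    if prompt_length < 0 or prompt_length > len(output_ids):
        raise ValueError("prompt_length must be between 0 and len(output_ids)")
    return output_ids[prompt_length:]


# Toy illustration with plain integers standing in for token ids.
prompt_ids = [101, 7, 8, 9]
generated = prompt_ids + [42, 43, 2]
print(strip_prompt(generated, len(prompt_ids)))  # [42, 43, 2]
```

With the inference example, this would correspond to something like `strip_prompt(output[0], inputs["input_ids"].shape[1])` before calling `processor.decode`, under the same assumption.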

Evaluation

Our evaluation code is modified from the VLMEvalKit codebase.

# single GPU
python third_party/VLMEvalKit/run.py --data MMBench_DEV_EN MMMU_DEV_VAL SEEDBench_IMG --model hawkllama_llama3_vlm --verbose
# multiple GPUs
torchrun --nproc-per-node=8 third_party/VLMEvalKit/run.py --data MMBench_DEV_EN MMMU_DEV_VAL SEEDBench_IMG --model hawkllama_llama3_vlm --verbose
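
When scripting evaluation runs, the two commands above can be assembled programmatically. The following sketch only recombines the script path, flags, and model name shown in this section; `build_eval_command` is a hypothetical helper, and the choice of `python` for one GPU versus `torchrun` for several simply mirrors the two examples.

```python
import shlex

# Values copied from the commands in this section.
RUN_SCRIPT = "third_party/VLMEvalKit/run.py"
MODEL_NAME = "hawkllama_llama3_vlm"


def build_eval_command(datasets, num_gpus=1, verbose=True):
    """Return the evaluation command as an argument list.

    `build_eval_command` is a hypothetical helper, not part of this repo.
    """
    if not datasets:
        raise ValueError("at least one dataset name is required")
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    # One GPU: plain python. Several GPUs: torchrun with one process per GPU.
    if num_gpus == 1:
        launcher = ["python"]
    else:
        launcher = ["torchrun", f"--nproc-per-node={num_gpus}"]
    cmd = launcher + [RUN_SCRIPT, "--data", *datasets, "--model", MODEL_NAME]
    if verbose:
        cmd.append("--verbose")
    return cmd


if __name__ == "__main__":
    benchmarks = ["MMBench_DEV_EN", "MMMU_DEV_VAL", "SEEDBench_IMG"]
    # Print the command; pass the list to subprocess.run to execute it.
    print(shlex.join(build_eval_command(benchmarks, num_gpus=8)))
```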

The results are shown below:

| Benchmark | Our Method | LLaVA-Llama3-v1.1 | LLaVA-Next |
| --- | --- | --- | --- |
| MMMU val | 37.8 | 36.8 | 36.9 |
| SEEDBench img | 71.0 | 70.1 | 70.0 |
| MMBench-EN dev | 70.6 | 70.4 | 68.0 |
| MMBench-CN dev | 64.4 | 64.2 | 60.6 |
| CCBench | 33.9 | 31.6 | 24.7 |
| AI2D test | 65.6 | 70.0 | 67.1 |
| ScienceQA test | 76.1 | 72.9 | 70.4 |
| HallusionBench | 41.0 | 47.7 | 35.2 |
| MMStar | 43.0 | 45.1 | 38.1 |
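
To read the comparison at a glance, the short sketch below tallies which method scores highest on each benchmark. The numbers are copied from the results in this section, and the helper names are hypothetical.

```python
from collections import Counter

METHODS = ("Our Method", "LLaVA-Llama3-v1.1", "LLaVA-Next")

# Scores copied from the results in this section, in METHODS order.
RESULTS = {
    "MMMU val": (37.8, 36.8, 36.9),
    "SEEDBench img": (71.0, 70.1, 70.0),
    "MMBench-EN dev": (70.6, 70.4, 68.0),
    "MMBench-CN dev": (64.4, 64.2, 60.6),
    "CCBench": (33.9, 31.6, 24.7),
    "AI2D test": (65.6, 70.0, 67.1),
    "ScienceQA test": (76.1, 72.9, 70.4),
    "HallusionBench": (41.0, 47.7, 35.2),
    "MMStar": (43.0, 45.1, 38.1),
}


def best_method_per_benchmark(results=RESULTS, methods=METHODS):
    """Map each benchmark to the method with the highest score."""
    return {
        bench: methods[max(range(len(scores)), key=scores.__getitem__)]
        for bench, scores in results.items()
    }


def count_wins(results=RESULTS, methods=METHODS):
    """Count how many benchmarks each method leads."""
    return Counter(best_method_per_benchmark(results, methods).values())


if __name__ == "__main__":
    wins = count_wins()
    for method in METHODS:
        print(f"{method}: best on {wins[method]} of {len(RESULTS)} benchmarks")
```

On these numbers, our method leads on 6 of the 9 benchmarks, while LLaVA-Llama3-v1.1 leads on AI2D, HallusionBench, and MMStar.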

Training

See the instructions for training with NeMo.

Demo

You are welcome to try our demo!

License

For non-commercial academic use, this project is licensed under the 2-clause BSD License. For commercial use, please contact Chunhua Shen.

Acknowledgements

We express our appreciation to the following projects for their outstanding contributions in academia and code development: LLaVA, NeMo, VLMEvalKit and xtuner.