
SiglipVisionEmbeddings doesn't cast pixel_values like CLIPVisionEmbeddings does #34294

fpgaminer opened this issue Oct 21, 2024 · 1 comment

@fpgaminer (Contributor)

System Info

pip show transformers
Name: transformers
Version: 4.45.2

Who can help?

@amyeroberts @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The issue occurs when using a Llava model that uses SigLip as its vision tower, and loading the model using a quantization configuration. For example:

import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration

MODEL_NAME = "fancyfeast/llama-joycaption-alpha-two-hf-llava"

qnt_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
llava_model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, quantization_config=qnt_config, device_map=0)
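
A minimal inference call along these lines triggers the error (the processor usage, prompt format, and image file are illustrative assumptions, not taken verbatim from the report):

from PIL import Image
from transformers import AutoProcessor

# Hypothetical inference snippet; processor, prompt, and image are assumptions.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
image = Image.open("example.jpg")
prompt = "USER: <image>\nDescribe this image. ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(llava_model.device)
# pixel_values arrive as float32, while the quantized SigLip patch_embedding weights are float16
output = llava_model.generate(**inputs, max_new_tokens=64)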

Loading the model is fine, but running inference throws the exception "Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same":

File ~/miniconda3/envs/tmpenv5/lib/python3.11/site-packages/transformers/models/siglip/modeling_siglip.py:311, in SiglipVisionEmbeddings.forward(self, pixel_values, interpolate_pos_encoding)
    309 def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
    310     _, _, height, width = pixel_values.shape
--> 311     patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
    312     embeddings = patch_embeds.flatten(2).transpose(1, 2)
    314     if interpolate_pos_encoding:

All of this works with a more standard model like llava-hf/llava-1.5-7b-hf, because it uses CLIP as its vision tower.

The issue stems from a difference between CLIPVisionEmbeddings and SiglipVisionEmbeddings. Specifically, CLIPVisionEmbeddings casts the input to the patch embedding's dtype before running it through the patch_embedding module:

target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]

Whereas SiglipVisionEmbeddings passes pixel_values through unchanged:

patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]

It seems that simply copying CLIPVisionEmbeddings' behavior would fix this issue.
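
For illustration, a sketch of what SiglipVisionEmbeddings.forward could look like with the CLIP-style cast applied (the surrounding lines follow the existing method; only the target_dtype cast is new):

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
    _, _, height, width = pixel_values.shape
    # Proposed change: cast pixel_values to the patch embedding's dtype, mirroring
    # CLIPVisionEmbeddings, so half-precision/quantized weights accept float32 inputs.
    target_dtype = self.patch_embedding.weight.dtype
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
    embeddings = patch_embeds.flatten(2).transpose(1, 2)

    if interpolate_pos_encoding:
        embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
    else:
        embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings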

NOTE: There may be other bugs further down the chain; this is just the first breaking issue I ran into when trying to use SigLip in a quantized vision tower.

Expected behavior

SigLip should work in quantized llava vision towers, just like CLIP.

fpgaminer added the bug label on Oct 21, 2024
fpgaminer added a commit to fpgaminer/transformers that referenced this issue on Oct 21, 2024: "Update SiglipVisionEmbeddings.forward to cast input to correct dtype before embedding it."
@zucchini-nlp (Member)

Hey! Since VLMs take floating-point input tensors, we usually have to cast them to the desired dtype before calling forward, in contrast to LLMs, where inputs are always int64. So I think you should cast your inputs to the bnb_4bit_compute_dtype before feeding them into the model, because bnb doesn't handle that case internally.
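
A minimal sketch of that user-side workaround, building on the hypothetical inference snippet above (names are illustrative):

# Workaround sketch: cast only the floating-point pixel_values to the bnb compute dtype;
# integer tensors such as input_ids and attention_mask stay as they are.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(llava_model.device)
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)  # match bnb_4bit_compute_dtype
output = llava_model.generate(**inputs, max_new_tokens=64)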
