System Info
pip show transformers
Name: transformers
Version: 4.45.2
Who can help?
@amyeroberts @qubvel
Reproduction
The issue occurs when using a Llava model that has SigLip as its vision tower and loading it with a quantization configuration. For example:
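A minimal sketch of such a setup (the checkpoint name below is a hypothetical placeholder for any SigLip-based Llava model; the original snippet isn't preserved here):

```python
import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration

# Hypothetical checkpoint: any Llava checkpoint whose vision tower is SigLip.
model_id = "your-org/llava-siglip-model"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
```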
Loading the model is fine, but when running inference an exception is thrown:

Input type (torch.cuda.FloatTensor) and weight type (torch.cuda.HalfTensor) should be the same
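A call like the following triggers it (reusing the model and processor from the sketch above; the image and prompt are placeholders, and the exact prompt format depends on the checkpoint):

```python
from PIL import Image

# Placeholder inputs; any valid image/text pair will do.
image = Image.new("RGB", (384, 384), color="white")
prompt = "USER: <image>\nDescribe the image. ASSISTANT:"

inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")

# pixel_values arrive as float32 while the quantized vision tower's conv weights
# are fp16, so the SigLip patch embedding raises the dtype mismatch above.
output = model.generate(**inputs, max_new_tokens=32)
```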
All of this works with a more standard model like llava-hf/llava-1.5-7b-hf, because it uses CLIP as its vision tower.

The issue stems from a difference between CLIPVisionEmbeddings and SiglipVisionEmbeddings. Specifically, CLIPVisionEmbeddings casts the input dtype before running it through the Embedding module (src/transformers/models/clip/modeling_clip.py, lines 247 to 248 at commit 32590b5):
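For reference, the cast in CLIPVisionEmbeddings.forward reads roughly as follows (excerpted; see the permalinked lines for the exact code):

```python
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
```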
Whereas SigLip just hands it straight off (src/transformers/models/siglip/modeling_siglip.py, line 311 at commit 32590b5):
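Roughly, the corresponding line in SiglipVisionEmbeddings.forward is (again excerpted; the permalink has the exact line):

```python
patch_embeds = self.patch_embedding(pixel_values)
```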
Copying CLIPVisionEmbeddings' behavior seems like it would fix this issue.

NOTE: There may be other bugs in the chain; this is just the first breaking issue I ran into when trying to use SigLip in a quantized vision tower.
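A sketch of what mirroring CLIP's cast inside SiglipVisionEmbeddings.forward could look like (method-body excerpt, not a standalone script):

```python
# Cast pixel_values to the patch embedding's weight dtype before the conv,
# exactly as CLIPVisionEmbeddings does.
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
```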
Expected behavior
SigLip should work in quantized llava vision towers, just like CLIP.
Hey! Since VLMs have floating-point pixel inputs, we usually have to cast them to the desired dtype before calling forward, in contrast to LLMs where inputs are always int64. So I think you should cast your inputs to the bnb_4bit_compute_dtype first before feeding them into the model, because bnb doesn't handle that case internally.
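Something like the following (reusing the processor, model, image, and prompt from the snippets above, and assuming fp16 was set as bnb_4bit_compute_dtype):

```python
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda")

# Cast the floating-point pixel values to the compute dtype before calling the model.
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

output = model.generate(**inputs, max_new_tokens=32)
print(processor.decode(output[0], skip_special_tokens=True))
```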