Vision_tower is not updated as expected #130

Open
ChenFicha opened this issue Sep 26, 2024 · 7 comments

Comments

@ChenFicha

I am trying to continue fine-tuning the model, but I found that the vision_tower is not updated.
So I tried "Recipe-2" in Bunny-v1.1-4B.md to fine-tune Bunny with your pretrained mm_projector, using a large lr and 10 images from "bunny_695k.json":

PRETRAIN_DIR=bunny-pretrain-phi-3-siglip-s2
...
deepspeed train.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 1e-2 \
    --deepspeed ./script/deepspeed/zero3.json \
    --model_name_or_path weights/Phi-3-mini-4k-instruct/ \
    ...
    --data_path data/examples.json \
    --image_folder data/images/ \
    --vision_tower weights/siglip-so400m-patch14-384/ \
    --use_s2 True \
    --unfreeze_vision_tower True \
    ...
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    ...

I added some code in "train.py" to save the parameters before and after training:

    ...
    # sanity check: list every parameter that is marked as trainable
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    # snapshot the non-LoRA weights (vision tower, mm_projector, ...) before training
    state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
    torch.save(state_dict, "weights/before.bin")

    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    # snapshot the same weights again after training
    state_dict = get_peft_state_non_lora_maybe_zero_3(model.named_parameters())
    torch.save(state_dict, "weights/after.bin")
    ...

I also extracted the Bunny-v1_1-4B parameters from your released weights:

    model_path = "weights/Bunny-v1_1-4B/"
    model_name = 'bunny_phi3'
    model_type = 'phi-3'
    _, model, _, _ = load_pretrained_model(model_path, None, model_name, model_type)

    vision_encoder_state_dict = model.model.vision_tower.vision_tower.vision_model.state_dict()
    torch.save(vision_encoder_state_dict, "weights/bunny_phi3_vision_encoder.bin")

    mm_projector_state_dict = model.model.mm_projector.state_dict()
    torch.save(mm_projector_state_dict, "weights/bunny_phi3_mm_projector.bin")

Then, I used the following code to compare the parameters:

import torch


def compare_dict(model1, model2, atol=1e-5, rtol=1e-3):
    # both state dicts come from the same module structure, so zip pairs matching keys
    total_difference = 0.0
    for (name1, param1), (name2, param2) in zip(model1.items(), model2.items()):
        assert name1 == name2
        assert param1.shape == param2.shape
        assert param1.numel() != 0
        if not torch.allclose(param1.data.float(), param2.data.float(), atol=atol, rtol=rtol):
            diff = torch.sum(torch.abs(param1.data.float() - param2.data.float()))
            total_difference += diff.item()
            print(f"Layer {name1} has a difference of {diff}")

    print(f"Total weight difference: {total_difference}")
    
def main():
    # keys are prefixed with "base_model.model.model.vision_tower.vision_tower.vision_model."
    # (62 chars) and "base_model.model.model.mm_projector." (36 chars); strip the prefixes so
    # the keys match the state dicts extracted from the released checkpoint
    before = torch.load("weights/before.bin", map_location="cpu")
    before_vision_encoder = {k[62:]: v for k, v in before.items() if "vision_tower" in k}
    before_mm_projector = {k[36:]: v.half() for k, v in before.items() if "mm_projector" in k}

    after = torch.load("weights/after.bin", map_location="cpu")
    after_vision_encoder = {k[62:]:v for k,v in after.items() if "vision_tower" in k}
    after_mm_projector = {k[36:]:v for k,v in after.items() if "mm_projector" in k}

    bunny_vision_encoder = torch.load("weights/bunny_phi3_vision_encoder.bin", map_location="cpu")
    bunny_mm_projector = torch.load("weights/bunny_phi3_mm_projector.bin", map_location="cpu")

    print("\n" + "="*50)
    print(f"Vision encoder before and after training:")
    compare_dict(before_vision_encoder, after_vision_encoder)
    print("="*50 + "\n")

    print("\n" + "="*50)
    print(f"Vision encoder before training and bunny_phi3:")
    compare_dict(before_vision_encoder, bunny_vision_encoder)
    print("="*50 + "\n")

    print("\n" + "="*50)
    print(f"mm_projector before and after training:")
    compare_dict(before_mm_projector, after_mm_projector)
    print("="*50 + "\n")

    print("\n" + "="*50)
    print(f"mm_projector before training and bunny_phi3:")
    compare_dict(before_mm_projector, bunny_mm_projector)
    print("="*50 + "\n")

The results show that the vision_tower is not updated even though param.requires_grad = True:

==================================================
Vision encoder before and after training:
Total weight difference: 0.0
==================================================


==================================================
Vision encoder before training and bunny_phi3:
Layer embeddings.patch_embedding.weight has a difference of 2.6709694862365723
...
Layer head.mlp.fc2.bias has a difference of 0.219451904296875
Total weight difference: 2629.1700118714944
==================================================


==================================================
mm_projector before and after training:
Layer 0.weight has a difference of 391775.25
Layer 0.bias has a difference of 109.53714752197266
Layer 2.weight has a difference of 341328.8125
Layer 2.bias has a difference of 122.5583267211914
Total weight difference: 733336.1579742432
==================================================


==================================================
mm_projector before training and bunny_phi3:
Layer 0.weight has a difference of 75773.1484375
Layer 0.bias has a difference of 1.0425692796707153
Layer 2.weight has a difference of 74393.125
Layer 2.bias has a difference of 1.6810407638549805
Total weight difference: 150168.99704754353
==================================================

I am confused why the vision_tower is not updated even though I set --unfreeze_vision_tower True. Is there anything I missed?

@Isaachhh
Collaborator

Isaachhh commented Oct 8, 2024

Have you followed the instructions here?

@ChenFicha
Author

@Isaachhh
Yes, I did follow the instructions. Here I am testing continued fine-tuning of the visual instruction tuning stage, using "Recipe-2" in Bunny-v1.1-4B.md.

@Isaachhh
Copy link
Collaborator

Isaachhh commented Oct 9, 2024

You may try printing all the parameters that are supposed to be optimized here.
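
Something like the sketch below, for example. This is only a rough sanity check: trainer and model are the objects built in train.py, the HF Trainer normally creates the optimizer lazily inside train(), and under DeepSpeed the optimizer actually used may be wrapped differently.

    # sketch: verify that the vision tower's parameters end up in the optimizer,
    # not merely that requires_grad is True
    trainer.create_optimizer()  # normally done lazily inside trainer.train()
    name_of = {id(p): n for n, p in model.named_parameters()}

    for i, group in enumerate(trainer.optimizer.param_groups):
        names = [name_of.get(id(p), "<unknown>") for p in group["params"]]
        vt = [n for n in names if "vision_tower" in n]
        print(f"group {i}: lr={group.get('lr')}, {len(names)} params, {len(vt)} from vision_tower")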

@zycoldness

https://github.com/BAAI-DCAI/Bunny/blob/main/bunny/model/multimodal_encoder/siglip/siglip_encoder.py#L37

@ChenFicha
Author

@zycoldness
You're exactly right. The problem was resolved after I commented out all the @torch.no_grad() lines. Thank you very much!
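
For anyone who runs into the same thing, here is a minimal, self-contained sketch (a toy module, not Bunny code) of why a @torch.no_grad()-decorated forward blocks updates even when requires_grad is True:

    import torch
    import torch.nn as nn

    class Tower(nn.Module):
        # toy stand-in for the vision tower
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        @torch.no_grad()  # same pattern as the decorator in siglip_encoder.py
        def forward(self, x):
            return self.proj(x)

    tower = Tower()
    head = nn.Linear(4, 1)  # stands in for the mm_projector / LLM on top

    for p in tower.parameters():
        p.requires_grad_(True)  # requires_grad is True, but that alone is not enough

    out = head(tower(torch.randn(2, 4)))
    out.sum().backward()

    print(tower.proj.weight.grad)        # None: no graph was recorded under no_grad,
                                         # so the optimizer can never update the tower
    print(head.weight.grad is not None)  # True: the rest of the model still trains

Removing the decorator makes tower.proj.weight.grad non-None again.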

@Isaachhh
The problem has been solved, thanks to @zycoldness' advice. I had overlooked the @torch.no_grad() in siglip_encoder.py. Thank you for your help as well!

BTW, here is what I got: both the vision_tower and mm_projector are in the trainable-parameter list.

['base_model.model.model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding.weight', 'base_model.model.model.vision_tower.vision_tower.vision_model.embeddings.patch_embedding.bias', 
... 
'base_model.model.model.vision_tower.vision_tower.vision_model.head.mlp.fc2.weight', 'base_model.model.model.vision_tower.vision_tower.vision_model.head.mlp.fc2.bias', 'base_model.model.model.mm_projector.0.weight', 'base_model.model.model.mm_projector.0.bias', 'base_model.model.model.mm_projector.2.weight', 'base_model.model.model.mm_projector.2.bias']

@Isaachhh
Collaborator

Isaachhh commented Oct 10, 2024

That's pretty weird. As you showed, the weights of the vision encoder before training differ from those of bunny_phi3, which means the vision encoder was tuned during visual instruction tuning.

So the current code worked when I trained Bunny? It may be related to the versions of the packages.

Isaachhh reopened this Oct 10, 2024
@ChenFicha
Author

ChenFicha commented Oct 15, 2024

It might be. Here is my package setup for reference:

Docker:

nvcr.io/nvidia/pytorch:23.12-py3

Python:

Python 3.9.17

pip:

torch==2.3.1
torchvision==0.18.1
torchaudio==2.3.1
deepspeed==0.14.4
transformers==4.42.3
notebook==7.2.1
einops==0.8.0
accelerate==0.31.0
sentencepiece==0.2.0
timm==1.0.7
peft==0.11.1
datasets==2.20.0
evaluate==0.4.2
openpyxl==3.1.5
prettytable==3.10.0
openai==1.35.13
protobuf==5.27.2
gdown==5.2.0
spacy==3.7.5
nltk==3.8.1
bitsandbytes==0.43.3

ds_report:

DeepSpeed general environment info:
torch install path ............... ['/usr/local/lib/python3.9/site-packages/torch']
torch version .................... 2.3.1+cu121
deepspeed install path ........... ['/usr/local/lib/python3.9/site-packages/deepspeed']
deepspeed info ................... 0.14.4, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.3
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 31.30 GB
