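"""forward_feature_caption.py

For each (vision encoder, text decoder) pair, load the corresponding
fine-tuned VisionEncoderDecoderModel checkpoint, run the Flickr8k captions
through it, and save the mean-pooled hidden states captured at the decoder's
final LayerNorm to a .npy file.
"""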
import numpy as np
import torch
from torch.utils.data import DataLoader

import transformers
from transformers import VisionEncoderDecoderModel

from datasets.image_caption_dataset_feature import CaptionDataset
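# CaptionDataset is a project-local dataset; from the loop below it is assumed
# to yield (image_tensor, encoded_caption, caption_length) tuples.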
def forward_pass(all_loader, model, encoder_name, decoder_name="bert"):
    """Run `all_loader` through `model`, capturing decoder hidden states via a hook."""
    features = []
    targets = []  # image names (to check that the evaluation order is consistent); unused here
    outputs = []
    print(decoder_name)

    def hook_fn_forward(module, input, output):
        # Store the input to the decoder's final LayerNorm for this batch.
        features.append(input[0].detach().cpu())

    forward_hook = None
    if decoder_name == 'bert':
        forward_hook = model.decoder.bert.encoder.layer[11].output.LayerNorm.register_forward_hook(hook_fn_forward)
    elif decoder_name == "bart":
        forward_hook = model.decoder.model.decoder.layers[5].final_layer_norm.register_forward_hook(hook_fn_forward)
    elif decoder_name == "roberta":
        forward_hook = model.decoder.roberta.encoder.layer[11].output.LayerNorm.register_forward_hook(hook_fn_forward)
    else:
        print('no hook')

    model = model.eval()
    with torch.no_grad():
        for _, (imgs, encoded_captions, length) in enumerate(all_loader):
            imgs, encoded_captions = imgs.to('cuda'), encoded_captions.to('cuda')
            # The loss is discarded; the forward pass only exists to trigger the hook.
            _ = model(pixel_values=imgs, labels=encoded_captions).loss
    if forward_hook is not None:
        forward_hook.remove()
    # Mean-pool over the sequence dimension, then stack all batches.
    features = torch.cat([x.mean(dim=1) for x in features])
    return features.cpu()
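# Sweep every encoder/decoder pairing; each iteration loads the matching
# fine-tuned checkpoint and extracts features for the full dataset.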
encoder_name_list = ['vit', 'swinvit', 'swin2vit']
decoder_name_list = ['bert', 'bart', 'roberta']
for encoder_name_tmp in encoder_name_list:
    for decoder_name_tmp in decoder_name_list:
        if encoder_name_tmp == 'vit':
            encoder_name = "google/vit-base-patch16-224-in21k"
        elif encoder_name_tmp == 'swinvit':
            encoder_name = "microsoft/swin-base-patch4-window7-224-in22k"
        elif encoder_name_tmp == 'swin2vit':
            encoder_name = "microsoft/swinv2-base-patch4-window12-192-22k"
        else:
            print('no encoder')
        if decoder_name_tmp == 'bert':
            decoder_name = "bert-base-uncased"
        elif decoder_name_tmp == 'roberta':
            decoder_name = "roberta-base"
        elif decoder_name_tmp == 'bart':
            decoder_name = "facebook/bart-base"
        else:
            print('no decoder')
        print(encoder_name, decoder_name)
        feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(encoder_name)
        tokenizer = transformers.AutoTokenizer.from_pretrained(decoder_name)
        # Checkpoint paths encode the training setup (dataset and hyperparameters).
        if encoder_name_tmp == 'vit':
            if decoder_name_tmp == 'bert':
                model = VisionEncoderDecoderModel.from_pretrained('/data/ckp/vit_bert_oncoco_tune_1e-05_1e-06_10')
            elif decoder_name_tmp == 'bart':
                model = VisionEncoderDecoderModel.from_pretrained('/data/ckp/vit_bart_oncoco_tune_0.0001_1e-06_10')
            elif decoder_name_tmp == 'roberta':
                model = VisionEncoderDecoderModel.from_pretrained('/data/ckp/vit_roberta_oncoco_tune_1e-05_1e-06_10')
            else:
                print('no decoder')
        elif encoder_name_tmp == 'swinvit':
            if decoder_name_tmp == 'bert':
                print('/data/ckp/swinvit_bert_tune_onflickr8k1e-05_0.0001_10')
                model = VisionEncoderDecoderModel.from_pretrained('/data/ckp/swinvit_bert_tune_onflickr8k1e-05_0.0001_10')
            elif decoder_name_tmp == 'bart':
                model = VisionEncoderDecoderModel.from_pretrained('/data/ckp/swinvit_bart_oncoco_tune_0.0001_1e-06_10')
            elif decoder_name_tmp == 'roberta':
                model = VisionEncoderDecoderModel.from_pretrained('/data/ckp/swinvit_roberta_oncoco_tune_1e-05_1e-05_10')
            else:
                print('no decoder')
        elif encoder_name_tmp == 'swin2vit':
            if decoder_name_tmp == 'bert':
                model = VisionEncoderDecoderModel.from_pretrained('/data/ckp/swin2vit_bert_oncoco_tune_1e-05_1e-06_10')
            elif decoder_name_tmp == 'bart':
                model = VisionEncoderDecoderModel.from_pretrained('/data/ckp/swin2vit_bart_oncoco_tune_1e-05_1e-06_10')
            elif decoder_name_tmp == 'roberta':
                model = VisionEncoderDecoderModel.from_pretrained('/data/ckp/swin2vit_roberta_oncoco_tune_1e-05_0.0001_10')
            else:
                print('no decoder')
        else:
            print('no encoder')
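        # The wrapper config needs the decoder's special tokens set explicitly;
        # the tokenizer's CLS token doubles as the decoder start token.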
        model.config.decoder_start_token_id = tokenizer.cls_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        model = model.cuda()
        model = model.eval()
        train_dataset = CaptionDataset(root_dir="/data/Flicker8k_Dataset/",
                                       annotations_file="/data/flickr8k/all.txt",
                                       feature_extractor=feature_extractor,
                                       tokenizer=tokenizer)
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, num_workers=4)
        X_trainval_feature = forward_pass(train_loader, model, encoder_name_tmp, decoder_name_tmp)
        # Save one feature file per (encoder, decoder) pair so later pairs
        # do not overwrite earlier ones.
        model_npy_feature = f'{encoder_name_tmp}_{decoder_name_tmp}_feature.npy'
        np.save(model_npy_feature, X_trainval_feature.numpy())
        print('saved!!', encoder_name_tmp, decoder_name_tmp)
print('end')