-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
147 lines (126 loc) · 6.47 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import argparse
import ipdb
from tqdm import tqdm
import progressbar
import torch
import ipdb
import clip
import itertools
from model.ZeroCLIP import CLIPTextGenerator
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu
# custom quantization module
# from: https://github.com/Secure-AI-Systems-Group/Qu-ANTI-zation
from utils.qutils import VLMQuantizationEnabler
def get_args():
    """Build and parse the command-line arguments for ZeroCap quantized inference."""
    p = argparse.ArgumentParser()

    # Input / output locations
    p.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
    p.add_argument("--test_path", type=str)
    p.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
    p.add_argument("--save_name", type=str, help="the name of the saved file")

    # Model / generation configuration
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    p.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    p.add_argument("--target_seq_length", type=int, default=15)
    p.add_argument("--cond_text", type=str, default="Image of a")
    p.add_argument("--reset_context_delta", action="store_true",
                   help="Should we reset the context at each token gen")

    # Optimization hyper-parameters for the CLIP-guided decoding
    p.add_argument("--num_iterations", type=int, default=5)
    p.add_argument("--clip_loss_temperature", type=float, default=0.01)
    p.add_argument("--clip_scale", type=float, default=1)
    p.add_argument("--ce_scale", type=float, default=0.2)
    p.add_argument("--stepsize", type=float, default=0.3)
    p.add_argument("--grad_norm_factor", type=float, default=0.9)
    p.add_argument("--fusion_factor", type=float, default=0.99)
    p.add_argument("--repetition_penalty", type=float, default=1)
    p.add_argument("--end_token", type=str, default=".", help="Token to end text")
    p.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    p.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    p.add_argument("--beam_size", type=int, default=1)
    p.add_argument("--multi_gpu", action="store_true")

    # Task selection and task-specific inputs
    p.add_argument('--run_type',
                   default='caption',
                   nargs='?',
                   choices=['caption', 'arithmetics'])
    p.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
                   help="Path to image for captioning")
    p.add_argument("--arithmetics_imgs", nargs="+",
                   default=['example_images/arithmetics/woman2.jpg',
                            'example_images/arithmetics/king2.jpg',
                            'example_images/arithmetics/man2.jpg'])
    p.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1])

    return p.parse_args()
def run(args, text_generator, img_path):
    """Caption a single image and rank the candidates by CLIP similarity.

    Args:
        args: parsed CLI namespace; uses ``cond_text`` and ``beam_size``.
        text_generator: a CLIPTextGenerator (or compatible) instance.
        img_path: path to the image to caption.

    Returns:
        The list of generated captions with the best CLIP-scoring caption
        moved to index 0 (callers read ``output[0]``).
    """
    image_features = text_generator.get_img_feature([img_path], None)
    captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)

    # Score every candidate caption against the image in CLIP embedding space.
    encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device))
                        for c in captions]
    encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
    best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()

    # Bug fix: best_clip_idx was previously computed but never used, so the
    # caller's output_text[0] was the first beam rather than the CLIP-best
    # caption. Reorder so the best-scoring caption comes first while still
    # returning the full candidate list (backward-compatible return type).
    return [captions[best_clip_idx]] + captions[:best_clip_idx] + captions[best_clip_idx + 1:]
if __name__ == '__main__':
    # Moved here from mid-script; previously `json` was imported twice.
    import json
    import os

    if torch.cuda.is_available():
        print('Cuda is available.')
    args = get_args()
    # NOTE(review): device is unconditionally 'cuda'; the script will fail
    # downstream without a GPU — confirm CPU fallback is intentionally absent.
    device = torch.device('cuda')

    save_path_prefix = args.save_path_prefix
    # Recursively construct the output directory; exist_ok makes the previous
    # os.path.exists() pre-check redundant.
    os.makedirs(save_path_prefix, exist_ok=True)
    save_name = args.save_name

    print('Loading data...')
    with open(args.test_path) as f:
        item_list = json.load(f)
    print('Data loaded.')
    print('Number of test instances is {}'.format(len(item_list)))

    # All (visual bits, language bits) pairs; the [16:25] slice keeps the nine
    # combinations where the visual bit width is 4 or 2.
    bit_widths = [32, 16, 8, 4, 2]
    bit_combii = list(itertools.product(bit_widths, bit_widths))[16:25]
    print("Bit Combinations:", bit_combii)

    # Run inference under each quantization bit-width combination.
    for (vbits, lbits) in bit_combii:
        # Reload a fresh CLIP model per combination so quantization state from
        # the previous run does not carry over.
        model, preprocess = clip.load("RN50", device=device)
        full_save_path = save_path_prefix + f"v{vbits}l{lbits}" + save_name
        print('full save path is {}'.format(full_save_path))

        with VLMQuantizationEnabler(model, 'per_channel_symmetric', 'per_layer_asymmetric', vbits, lbits, silent=True):
            # ZeroCap generator wrapping the (quantized) CLIP model
            text_generator = CLIPTextGenerator(model, preprocess, **vars(args))
            result_list = []
            invalid_num = 0
            print('----------------------------------------------------------------')
            # NOTE(review): hard-coded to the first 100 instances instead of
            # len(item_list) — confirm this subsetting is intentional.
            test_num = 100
            print('Number of inference instances is {}'.format(test_num))
            p = progressbar.ProgressBar(test_num)
            p.start()
            for p_idx in tqdm(range(test_num)):
                p.update(p_idx)
                one_test_dict = item_list[p_idx]
                one_res_dict = {
                    'split': one_test_dict['split'],
                    'image_name': one_test_dict['image_name'],
                    'captions': one_test_dict['captions']
                }
                image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
                try:
                    output_text = run(args, text_generator, img_path=image_full_path)
                    one_res_dict['prediction'] = output_text[0]
                    result_list.append(one_res_dict)
                except Exception as error:
                    # Best-effort: skip failing images, keep a running count.
                    print(f'[!] ERROR:', error)
                    invalid_num += 1
                    print('invalid number is {}'.format(invalid_num))
                    continue
            p.finish()
            print('Inference completed!')

            with open(full_save_path, 'w') as outfile:
                json.dump(result_list, outfile, indent=4)