run_infer.py
""" Inference """
import os
import yaml
import numpy as np
from pprint import pprint
from attrdict import AttrDict
import paddle
import paddle.nn.functional as F
import sys
from ernievil2.transformers.multimodal import ERNIE_ViL2_base
from ernievil2.utils import reader
def do_predict(args):
"""
Inference with a file
"""
if args.device == "gpu":
place = "gpu"
else:
place = "cpu"
paddle.set_device(place)
# Define data loader
test_loader = reader.create_loader(args)
# Define model
model = ERNIE_ViL2_base(args)
model.eval()
out_file = open(args.output_file, "w", encoding="utf-8")
with paddle.no_grad():
for input_data in test_loader:
img_word, input_ids, pos_ids = input_data
img_word = paddle.concat(x=img_word, axis=0)
enc_output_img, enc_output_text = model(img_word=img_word, input_ids=input_ids, pos_ids=pos_ids)
## normalize
text_emb = F.normalize(enc_output_text).numpy()
image_emb = F.normalize(enc_output_img).numpy()
for i in range(len(enc_output_img)):
txt_str = ' '.join([str(x) for x in text_emb[i]])
img_str = ' '.join([str(x) for x in image_emb[i]])
idx_str = '1'
out_file.write("\t".join([txt_str, img_str, idx_str])+'\n')
out_file.close()
if __name__ == "__main__":
with open(sys.argv[1], 'rt') as f:
args = AttrDict(yaml.safe_load(f))
pprint(args)
do_predict(args)
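
# Usage sketch: the script takes a single YAML config path on the command line.
# Only "device" and "output_file" are read directly above; the remaining keys
# are consumed by reader.create_loader and ERNIE_ViL2_base and depend on the
# configs shipped with this repo. The file name and values below are
# illustrative assumptions, not canonical:
#
#   python run_infer.py configs/ernie_vil2_base.yaml
#
# where configs/ernie_vil2_base.yaml (hypothetical) might contain:
#   device: gpu
#   output_file: ./embeddings.tsv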