-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathinfer.sh
91 lines (88 loc) · 2.44 KB
/
infer.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Parse command-line options:
#   -d <dataset>   dataset name under data/finetune/ (sets DATASET)
#   -s <strategy>  decoding strategy: greedy | beam | sampling
#                  (sets DECODING_STRATEGY)
# The leading ':' in the optstring selects silent error reporting: getopts
# then sets opt to ':' (with OPTARG = the option letter) when an option's
# argument is missing, and to '?' for an unrecognized option.
while getopts ":d:s:" opt; do
  case "$opt" in
    d)
      DATASET="$OPTARG"
      ;;
    s)
      DECODING_STRATEGY="$OPTARG"
      ;;
    :)
      # Missing argument, e.g. a trailing "-d" with no value.
      echo "option -${OPTARG} requires an argument" >&2
      exit 1
      ;;
    \?)
      # Unknown option (message: "unknown option"). The '?' is escaped so it
      # matches a literal '?' rather than acting as a one-character glob.
      echo "未知参数" >&2
      exit 1
      ;;
  esac
done
# Derive the directory layout for the selected dataset. PROJECT_PATH is '.',
# so every path below is relative to the current working directory: run this
# script from the project root.
#   USER_DIR : ./src (passed to fairseq-generate as --user-dir)
#   DATA_DIR : ./data/finetune/<DATASET> (binarized data in binary/, outputs)
#   SAVE_DIR : <DATA_DIR>/checkpoints (where the checkpoint is read from)
PROJECT_PATH='.'
# Fail fast when -d was omitted: an empty DATASET would silently point
# DATA_DIR at ./data/finetune/ itself.
if [[ -z "${DATASET:-}" ]]; then
  echo 'missing required option: -d <dataset>' >&2
  exit 1
fi
USER_DIR=${PROJECT_PATH}/src
DATA_DIR=${PROJECT_PATH}/data/finetune/${DATASET}
SAVE_DIR=${DATA_DIR}/checkpoints
echo '-------- inference on dataset: '"$DATASET"'--------'
# Run fairseq-generate on the test split of ${DATA_DIR}/binary using the
# strategy selected with -s, then extract the hypotheses into ${PRED_FILE}.
#   greedy   : --beam 1
#   beam     : --beam 5
#   sampling : top-k sampling (k=100) with --nbest 1 --beam 1
# All strategies share the same checkpoint, task, batch size, worker count,
# length penalty and --no-repeat-ngram-size 3; only the flags collected in
# DECODING_ARGS differ. Relies on DATA_DIR, SAVE_DIR and USER_DIR set above.
LENPEN=1
CHECK_POINT=${SAVE_DIR}/checkpoint_best.pt
OUTPUT_FILE=${DATA_DIR}/output.txt
PRED_FILE=${DATA_DIR}/pred.txt
TASK=ved_translate

case "$DECODING_STRATEGY" in
  greedy)
    echo '-------- decoding strategy: greedy --------'
    BEAM=1
    DECODING_ARGS=(--beam "${BEAM}")
    ;;
  beam)
    echo '-------- decoding strategy: beam search --------'
    BEAM=5
    DECODING_ARGS=(--beam "${BEAM}")
    ;;
  sampling)
    echo '-------- decoding strategy: sampling --------'
    TOP_K=100
    DECODING_ARGS=(--sampling --sampling-topk "${TOP_K}" --nbest 1 --beam 1)
    ;;
  *)
    echo 'decoding strategy '"$DECODING_STRATEGY"' not found!' >&2
    exit 1
    ;;
esac

# Redirections are applied left to right: stderr is first duplicated onto the
# script's current stdout, then stdout alone is sent to OUTPUT_FILE. Only the
# command's stdout therefore lands in OUTPUT_FILE; its stderr stays visible.
# A failed run must stop here, otherwise the extraction below would quietly
# produce an empty or partial pred.txt and the script would still exit 0.
if ! fairseq-generate "${DATA_DIR}/binary" \
  --path "${CHECK_POINT}" \
  --user-dir "${USER_DIR}" \
  --task "${TASK}" \
  --batch-size 64 \
  --gen-subset test \
  --num-workers 4 \
  --no-repeat-ngram-size 3 \
  --lenpen "${LENPEN}" \
  "${DECODING_ARGS[@]}" \
  2>&1 >"${OUTPUT_FILE}"; then
  echo "fairseq-generate failed (strategy: ${DECODING_STRATEGY}); see ${OUTPUT_FILE}" >&2
  exit 1
fi

# Extract predictions. The pipeline:
#   - keeps lines starting with "H";
#   - strips their 2-character prefix (presumably "H-");
#   - sorts numerically on the id that follows (presumably the test-set
#     sentence index, restoring input order);
#   - keeps tab-separated field 3 onward (the hypothesis text; field 2 is
#     presumably the score);
#   - removes " ##" so marked subword pieces are joined (assumes
#     WordPiece-style tokens — confirm against the tokenizer).
grep '^H' "${OUTPUT_FILE}" \
  | cut -c 3- \
  | sort -n \
  | cut -f3- \
  | sed 's/ ##//g' >"${PRED_FILE}"

if [[ ! -s "${PRED_FILE}" ]]; then
  echo "warning: no hypotheses were extracted into ${PRED_FILE}" >&2
fi