forked from jasonxyliu/Lang2LTL-2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
256 lines (200 loc) · 12.9 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import logging
from collections import defaultdict
import string
import spot
from srer import PROPS
from utils import load_from_file
def eval_srer(true_results_fpath, srer_out_fpath):
    """
    Compute the accuracy of the SRER module against ground-truth results.

    Both files are read with ``load_from_file`` and must yield equally long
    sequences of per-utterance dicts. The keys read here are "utt",
    "sre_to_preds" and "lifted_utt" (both files) and "props" (ground truth only).

    A sample counts as correct only if all of the following hold:
      - predicted and true "sre_to_preds" have the same number of entries;
      - every predicted SRE is a key of the true "sre_to_preds";
      - for each shared SRE, pairing true and predicted predicates in order,
        the relations are equal after stripping or the predicted relation is a
        substring of the true one, and the referring expressions have the same
        count and the same case-insensitive set;
      - after removing punctuation, the lifted utterances are identical, or at
        least have the same length and spaces at the same positions.

    Args:
        true_results_fpath: path of the ground-truth results file.
        srer_out_fpath: path of the SRER output file.

    Returns:
        dict mapping the number of true props of a sample to a tuple
        ``(ncorrects, total)`` for that bucket. Buckets in which no sample was
        correct are included with ``ncorrects == 0``.

    Raises:
        AssertionError: if the two files differ in number of samples or if a
            pair of samples does not refer to the same utterance.
    """
    logging.info("***** Evaluating SRER Module")

    true_outs = load_from_file(true_results_fpath)
    srer_outs = load_from_file(srer_out_fpath)
    ncorrects = 0
    nprops2ncorrects, nprops2total = defaultdict(int), defaultdict(int)
    # Loop-invariant translation table that deletes ASCII punctuation.
    strip_punct = str.maketrans('', '', string.punctuation)

    assert len(srer_outs) == len(true_outs), f"ERROR different numbers of samples:\ntrue: {len(true_outs)}\npred: {len(srer_outs)}"

    for true_out, srer_out in zip(true_outs, srer_outs):
        assert srer_out["utt"].strip() == true_out["utt"].strip(), f"ERROR different utterances:\ntrue: {true_out['utt']}\npred: {srer_out['utt']}"
        logging.info(f"* Command: {srer_out['utt']}")
        is_correct = True
        nprops = len(true_out["props"])

        if len(srer_out["sre_to_preds"]) != len(true_out["sre_to_preds"]):
            is_correct = False
            logging.info(f"ERROR incorrect number of spatial predicates\ntrue: {true_out['sre_to_preds']}\npred: {srer_out['sre_to_preds']}")

        for sre_out, preds_out in srer_out["sre_to_preds"].items():
            if sre_out not in true_out["sre_to_preds"]:
                is_correct = False
                logging.info(f"ERROR incorrect SRE:\ntrue: {list(true_out['sre_to_preds'].keys())}\nnot contain pred: {sre_out}")
                continue

            preds_true = true_out["sre_to_preds"][sre_out]

            # Only logged, not counted as an error by itself; the pairwise
            # comparison below decides correctness.
            # NOTE(review): assumes every true predicate dict is non-empty - confirm.
            if list(preds_true.keys())[0] == "None" and preds_out:  # referring expression with spatial relation
                logging.info(f"ERROR incorrect spatial predicate:\ntrue: {preds_true}\nnot contain pred: {preds_out}")

            # zip() pairs predicates by insertion order and stops at the shorter dict.
            for (rel_true, res_true), (rel_out, res_out) in zip(preds_true.items(), preds_out.items()):
                if rel_out.strip() != rel_true.strip() and rel_out not in rel_true:  # e.g., pred: left of; true: to the left of
                    is_correct = False
                    logging.info(f"ERROR incorrect spatial relation\ntrue: {rel_true}\npred: {rel_out}")

                res_out_lower = [re_out.lower() for re_out in res_out]  # output lowercase e.g., italian resturant
                res_true_lower = [re_true.lower() for re_true in res_true]
                if not (len(res_out) == len(res_true) and set(res_out_lower) == set(res_true_lower)):
                    is_correct = False
                    logging.info(f"ERROR incorrect REs\ntrue: {res_true}\npred: {res_out}\n true lower: {res_true_lower}\npred lower: {res_out_lower}")

        true_lifted_utt = true_out["lifted_utt"].strip().translate(strip_punct)
        srer_lifted_utt = srer_out["lifted_utt"].strip().translate(strip_punct)
        if srer_lifted_utt != true_lifted_utt:
            logging.info(f"WARNING lifted commands do not exactly match\ntrue: {true_out['lifted_utt']}\npred: {srer_out['lifted_utt']}")

            if len(true_lifted_utt) != len(srer_lifted_utt):
                is_correct = False
                logging.info(f"ERROR incorrect lifted utterances\ntrue: {true_out['lifted_utt']}\npred: {srer_out['lifted_utt']}")
            else:
                # NOTE: whitespace check to make sure the lifted utterances are equivalent:
                whitespaces_srer = [i for i, letter in enumerate(srer_lifted_utt) if letter == ' ']
                whitespaces_true = [i for i, letter in enumerate(true_lifted_utt) if letter == ' ']
                if whitespaces_srer != whitespaces_true:
                    is_correct = False
                    logging.info(f"ERROR Non-matching whitespaces:\ntrue: {true_out['lifted_utt']}\npred: {srer_out['lifted_utt']}")

        if is_correct:
            ncorrects += 1
            nprops2ncorrects[nprops] += 1
        else:
            logging.info("Incorrect SRER output")
        nprops2total[nprops] += 1

        logging.info("\n")

    nsamples = len(true_outs)
    # Guard against empty input instead of raising ZeroDivisionError.
    accuracy = ncorrects / nsamples if nsamples else 0.0
    logging.info(f"SRER Accuracy: {ncorrects} / {nsamples} = {accuracy}\n\n")

    # Iterate over the totals so that buckets with zero correct samples are
    # reported as well; every total here is at least 1.
    nprops2acc = {nprops: nprops2ncorrects.get(nprops, 0) / total for nprops, total in nprops2total.items()}
    logging.info(f"SRER nprops vs. acc: {nprops2acc}")
    nprops2acc = {nprops: (nprops2ncorrects.get(nprops, 0), total) for nprops, total in nprops2total.items()}
    return nprops2acc
def eval_reg(true_results_fpath, topk, reg_out_fpath):
    """
    Compute the top K accuracy of Referring Expression Grounding module.

    Both files are read with ``load_from_file`` and must yield equally long
    sequences of per-utterance dicts with the keys "utt" and
    "grounded_sre_to_preds". Only the first value of each per-SRE predicate
    dict is evaluated.
    NOTE(review): the indexing below implies that a true entry is a list whose
    first element is a pair with the grounding at index 1, and that a predicted
    entry is a list of such pairs - confirm against the data files.

    For every k in 1..topk, a referring expression counts as correct at k when
    its true grounding appears among the first k predicted groundings.

    Args:
        true_results_fpath: path of the ground-truth results file.
        topk: largest k for which top-k accuracy is computed.
        reg_out_fpath: path of the REG output file.

    Returns:
        dict mapping ``len(true grounding)`` to a tuple ``(ncorrects, total)``
        measured at k == topk. Buckets in which nothing was correct are
        included with ``ncorrects == 0``.

    Raises:
        AssertionError: if the two files differ in number of samples or if a
            pair of samples does not carry the same utterance.
    """
    logging.info("***** Evaluating REG Module")

    true_outs = load_from_file(true_results_fpath)
    reg_outs = load_from_file(reg_out_fpath)
    topk2acc = defaultdict(int)
    total_res = 0
    len2ncorrects, len2total = defaultdict(int), defaultdict(int)

    assert len(reg_outs) == len(true_outs), f"ERROR different numbers of samples\ntrue: {len(true_outs)}\npred: {len(reg_outs)}"

    for true_out, reg_out in zip(true_outs, reg_outs):
        assert reg_out["utt"] == true_out["utt"], f"ERROR different utterances:\ntrue: {true_out['utt']}\npred: {reg_out['utt']}"
        logging.info(f"* Command: {true_out['utt']}")

        true_ground_sre_to_preds = true_out["grounded_sre_to_preds"]
        reg_ground_sre_to_preds = reg_out["grounded_sre_to_preds"]

        if len(reg_ground_sre_to_preds) != len(true_ground_sre_to_preds):
            logging.info(f"ERROR incorrect number of spatial referring expression:\ntrue: {true_ground_sre_to_preds}\npred: {reg_ground_sre_to_preds}")

        for sre_out, pred_out in reg_ground_sre_to_preds.items():
            # The denominator counts predicted REs, including those of SREs
            # that are skipped below, so skipped SREs lower the accuracy.
            total_res += len(list(pred_out.values())[0])

            if sre_out not in true_ground_sre_to_preds:
                logging.info(f"ERROR incorrect SRE:\ntrue: {list(true_ground_sre_to_preds.keys())}\nnot contain pred: {sre_out}")
                continue

            pred_true = true_ground_sre_to_preds[sre_out]

            if len(pred_out) != len(pred_true):
                logging.info(f"ERROR incorrect size of spatial predicate:\ntrue: {len(pred_true)}\npred: {len(pred_out)}")
                continue

            res_true = [score_re[0][1] for score_re in list(pred_true.values())[0]]
            res_out = [[score_ground[1] for score_ground in grounded_res] for grounded_res in list(pred_out.values())[0]]

            for re_true, res_topk in zip(res_true, res_out):
                for end_idx in range(1, topk+1):
                    if re_true in res_topk[:end_idx]:
                        topk2acc[end_idx] += 1
                        # Per-length statistics are recorded once per RE, at k == topk.
                        if end_idx == topk:
                            len2ncorrects[len(re_true)] += 1
                            len2total[len(re_true)] += 1
                    else:
                        if end_idx == topk:
                            logging.info(f"Incorrect Top-{topk} REG: \n{sre_out}\ntrue: {re_true}\npred: {res_topk}")
                            len2total[len(re_true)] += 1
        logging.info("\n")

    for idx in range(1, topk+1):
        # Guard against an empty evaluation set instead of raising ZeroDivisionError.
        accuracy = topk2acc[idx] / total_res if total_res else 0.0
        logging.info(f"REG Top-{idx} Accuracy: {topk2acc[idx]} / {total_res} = {accuracy}")
    logging.info("\n\n")

    # Iterate over the totals so that buckets with zero correct REs are
    # reported as well; every total here is at least 1.
    len2acc = {length: len2ncorrects.get(length, 0) / total for length, total in len2total.items()}
    logging.info(f"REG length vs. acc: {len2acc}")
    len2acc = {length: (len2ncorrects.get(length, 0), total) for length, total in len2total.items()}
    return len2acc
def eval_spg(true_results_fpath, topk, spg_out_fpath):
    """
    Compute the top K accuracy of Spatial Predicate Grounding module.

    Both files are read with ``load_from_file`` and must yield equally long
    sequences of per-utterance dicts with the keys "utt" and "grounded_sps".
    "grounded_sps" maps each SRE to a list of candidate groundings; the first
    candidate in the ground truth is taken as the reference. Each candidate is
    a dict from landmark type to grounding.

    For every k in 1..topk, an SRE counts as correct at k when one of its first
    k predicted candidates has the same number of entries as the reference and
    matches it pairwise, in order, on both landmark type and grounding.
    Accuracies are logged; nothing is returned.

    Args:
        true_results_fpath: path of the ground-truth results file.
        topk: largest k for which top-k accuracy is computed.
        spg_out_fpath: path of the SPG output file.

    Raises:
        AssertionError: if the two files differ in number of samples or if a
            pair of samples does not carry the same utterance.
    """
    logging.info("***** Evaluating SPG Module")

    true_outs = load_from_file(true_results_fpath)
    spg_outs = load_from_file(spg_out_fpath)
    topk2acc = defaultdict(int)
    total_sps = 0

    assert len(spg_outs) == len(true_outs), f"ERROR different numbers of samples\ntrue: {len(true_outs)}\npred: {len(spg_outs)}"

    for true_out, spg_out in zip(true_outs, spg_outs):
        assert spg_out["utt"] == true_out["utt"], f"ERROR different utterances:\ntrue: {true_out['utt']}\npred: {spg_out['utt']}"
        logging.info(f"* Command: {true_out['utt']}")

        # The denominator counts ground-truth SREs, so missing or skipped
        # predictions lower the accuracy.
        total_sps += len(true_out["grounded_sps"])

        true_ground_sps = true_out["grounded_sps"]
        spg_ground_sps = spg_out["grounded_sps"]

        if len(spg_ground_sps) != len(true_ground_sps):
            logging.info(f"ERROR incorrect number of spatial referring expression:\ntrue: {true_ground_sps}\npred: {spg_ground_sps}")

        for sre_out, sps_topk_out in spg_ground_sps.items():
            if sre_out not in true_ground_sps:
                logging.info(f"ERROR incorrect SRE:\ntrue: {list(true_ground_sps.keys())}\nnot contain pred: {sre_out}")
                continue

            sp_true = true_ground_sps[sre_out][0]

            if not sps_topk_out:
                logging.info(f"ERROR incorrect spatila predicate grounding size empty:\n{sre_out}\n{spg_ground_sps}")
                continue

            for end_idx in range(1, topk+1):
                for sp_out in sps_topk_out[:end_idx]:
                    if len(sp_true) != len(sp_out):
                        logging.info(f"ERROR spatial predicates have different sizes\n{sre_out}\ntrue: {sp_true}\npred: {sp_out}")
                        continue

                    is_correct = True
                    for (lmk_type_true, ground_true), (lmk_type_out, ground_out) in zip(sp_true.items(), sp_out.items()):
                        # Exact match on the grounding is required. A looser criterion
                        # (any overlap between the two groundings) was considered upstream:
                        # if lmk_type_out != lmk_type_true or not (set(ground_out) & set(ground_true)):
                        if lmk_type_out != lmk_type_true or ground_out != ground_true:
                            is_correct = False
                            if end_idx == 1:
                                logging.info(f"Incorrect Top-1 SPG:\n{sre_out}\ntrue: ({lmk_type_true}) {ground_true}\npred: ({lmk_type_out}) {ground_out}")
                    if is_correct:
                        # One hit within the first end_idx candidates is enough.
                        topk2acc[end_idx] += 1
                        break
        logging.info("\n")

    for idx in range(1, topk+1):
        # Guard against an empty ground truth instead of raising ZeroDivisionError.
        accuracy = topk2acc[idx] / total_sps if total_sps else 0.0
        logging.info(f"SPG Top-{idx} Accuracy: {topk2acc[idx]} / {total_sps} = {accuracy}")
    logging.info("\n\n")
def eval_lt(true_results_fpath, lt_out_fpath):
    """
    Compute the accuracy of the lifted translation (LT) module.

    Both files are read with ``load_from_file`` and must yield equally long
    sequences of per-utterance dicts with the keys "utt" and "lifted_ltl".
    When the fallback comparison below is needed, "props" and "sre_to_preds"
    are read as well, and the optional "lifted_symbol_map" of the prediction.

    Per sample:
      1. Propositions in the predicted formula are renamed so that the ones
         that occur are the first entries of ``PROPS``, in ``PROPS`` order.
      2. The prediction is correct if ``spot.are_equivalent`` accepts it
         against the true formula.
      3. Otherwise, if both formula strings have the same length, propositions
         are re-mapped through their SREs (prediction prop -> SRE -> true prop)
         and equivalence is checked once more, which makes the comparison
         invariant to the order in which propositions were assigned.
      4. A formula that raises SyntaxError when parsed counts as incorrect.

    The accuracy is logged; nothing is returned.

    Args:
        true_results_fpath: path of the ground-truth results file.
        lt_out_fpath: path of the LT output file.

    Raises:
        AssertionError: if the two files differ in number of samples or if a
            pair of samples does not carry the same utterance.
    """
    logging.info("***** Evaluating LT")

    true_outs = load_from_file(true_results_fpath)
    lt_outs = load_from_file(lt_out_fpath)
    ncorrects = 0

    assert len(lt_outs) == len(true_outs), f"ERROR different numbers of samples\ntrue: {len(true_outs)}\npred: {len(lt_outs)}"

    for true_out, lt_out in zip(true_outs, lt_outs):
        assert lt_out["utt"] == true_out["utt"], f"ERROR different utterances:\ntrue: {true_out['utt']}\npred: {lt_out['utt']}"
        logging.info(f"* Command: {lt_out['utt']}")
        ltl_true, ltl_out = true_out["lifted_ltl"], lt_out["lifted_ltl"]

        props_out = [prop for prop in PROPS if prop in ltl_out]
        for prop_out, prop in zip(props_out, PROPS):  # replace out of order props, e.g., G i h X G ! a -> G i b X G ! a
            ltl_out = ltl_out.replace(prop_out, prop)

        is_correct = True
        # Reset for every sample so that a parse failure can never reuse the
        # verdict of the previous sample (or hit an unbound name on the first one).
        spot_correct = False
        try:
            spot_correct = spot.are_equivalent(spot.formula(ltl_out), spot.formula(ltl_true))
            if not spot_correct and len(ltl_out) == len(ltl_true):  # invariant to order of propositions
                if "lifted_symbol_map" in lt_out:  # exp_full input previous module, SRER
                    lifted_symbol_map = lt_out["lifted_symbol_map"]
                else:  # exp_modular input ground truth (does not have "lifted_symbol_map" key)
                    lifted_symbol_map = {prop: sre for prop, sre in zip(lt_out["props"], lt_out["sre_to_preds"].keys())}
                sre2prop_true = {sre.lower(): prop for prop, sre in zip(true_out["props"], true_out["sre_to_preds"].keys())}
                try:
                    prop_out2true = {f"<{prop}>": sre2prop_true[sre.lower()] for prop, sre in lifted_symbol_map.items()}
                    ltl_out_reorder = ltl_out
                    # Two passes through "<prop>" placeholders so that a renamed
                    # proposition is not renamed again by a later replacement.
                    for prop in lifted_symbol_map.keys():
                        ltl_out_reorder = ltl_out_reorder.replace(prop, f"<{prop}>")
                    for prop_out, prop in prop_out2true.items():
                        ltl_out_reorder = ltl_out_reorder.replace(prop_out, prop)
                    spot_correct = spot.are_equivalent(spot.formula(ltl_out_reorder), spot.formula(ltl_true))
                except KeyError:  # SRER extracted incorrect SRE
                    spot_correct = False
        except SyntaxError:
            is_correct = False
            logging.info(f"ERROR incorrect lifted translation Syntax Error\ntrue: {ltl_true}\npred: {ltl_out}")
        else:
            # Only reached when both formulas parsed, so re-parsing them for
            # the log message cannot raise.
            if not spot_correct:
                is_correct = False
                logging.info(f"ERROR incorrect lifted translation:\ntrue: {spot.formula(ltl_true)}\npred: {spot.formula(ltl_out)}")

        if is_correct:
            ncorrects += 1

    nsamples = len(true_outs)
    # Guard against empty input instead of raising ZeroDivisionError.
    accuracy = ncorrects / nsamples if nsamples else 0.0
    logging.info(f"LT Accuracy: {ncorrects} / {nsamples} = {accuracy}\n\n")