-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils.py
110 lines (82 loc) · 3.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from torch_scatter import scatter
from prettytable import PrettyTable
class AccuracyAnalysis():
    """Accumulate ground-truth / prediction batches and summarise their error.

    Rows are grouped by an integer ``scatter_idx``: each appended batch carries
    one group index that is broadcast to all of that batch's rows. Every
    ``get_*`` metric returns one value per group, reduced with
    ``torch_scatter.scatter`` along dim 0, and ``accuracy_table`` reports the
    mean and standard deviation of those per-group values.

    Call ``lists_to_tensors`` (``accuracy_table`` does so itself) before using
    the ``get_*`` methods directly.
    NOTE(review): group indices with no rows presumably yield 0 / NaN entries
    in the scatter outputs and skew the table statistics — keep indices
    contiguous; confirm against callers.
    """

    def __init__(self):
        # Each entry holds a list of per-batch tensors until lists_to_tensors()
        # replaces it with one tensor concatenated along dim 0.
        self.values_dict = {
            'ground_truth': [],
            'prediction': [],
            'scatter_idx': []
        }

    def append_values(self, value_dict):
        """Record one batch of values.

        Args:
            value_dict: mapping that may contain ``'ground_truth'``,
                ``'prediction'`` and ``'scatter_idx'`` tensors; any other key
                is ignored. ``'scatter_idx'`` is expanded to
                ``value_dict['ground_truth'].size(0)`` entries, so it must be
                expandable to that length (e.g. a one-element tensor) and
                ``'ground_truth'`` must be present in the same mapping.

        Raises:
            KeyError: if ``'scatter_idx'`` is supplied without ``'ground_truth'``.
        """
        for key, value in value_dict.items():
            if key not in self.values_dict:
                continue
            if key == 'scatter_idx':
                # Give every row of the batch the batch's group index.
                value = value.expand(value_dict['ground_truth'].size(0))
            stored = self.values_dict[key]
            if not isinstance(stored, list):
                # Already concatenated by lists_to_tensors(); keep accumulating
                # instead of failing on Tensor.append.
                stored = [stored]
                self.values_dict[key] = stored
            stored.append(value)

    def lists_to_tensors(self):
        """Concatenate each accumulated list of batches along dim 0.

        Entries that are already tensors are kept as they are, so calling this
        repeatedly (e.g. through several ``accuracy_table`` calls) is safe.
        ``torch.cat`` raises if a list is still empty.
        """
        self.values_dict = {
            key: torch.cat(value, dim=0) if isinstance(value, list) else value
            for key, value in self.values_dict.items()
        }

    def _error_norm(self):
        """Return the per-row Euclidean norm (last dim) of ground_truth - prediction."""
        return torch.linalg.norm(
            self.values_dict['ground_truth'] - self.values_dict['prediction'],
            dim=-1
        )

    def get_nmae(self, mae=None):
        """Return the per-group MAE divided by the largest ground-truth norm.

        The normaliser is the maximum row norm of ``ground_truth`` over *all*
        rows, not per group.

        Args:
            mae: optional precomputed result of ``get_mae()``; computed when
                omitted.
        """
        if mae is None:
            mae = self.get_mae()
        return mae / torch.max(torch.linalg.norm(self.values_dict['ground_truth'], dim=-1))

    def get_mae(self):
        """Return, per group, the mean of the row-wise Euclidean error norms."""
        return scatter(
            self._error_norm(),
            self.values_dict['scatter_idx'],
            dim=0,
            reduce='mean'
        )

    def get_approximation_error(self):
        """Return the per-group relative L2 error.

        For each group: sqrt(sum ||gt - pred||^2 / sum ||gt||^2), with both sums
        taken over the group's rows. A group whose ground truth is all zeros
        gives a zero denominator (inf or NaN).
        """
        scatter_idx = self.values_dict['scatter_idx']
        squared_error = scatter(
            self._error_norm() ** 2,
            scatter_idx,
            dim=0,
            reduce='sum'
        )
        squared_reference = scatter(
            torch.linalg.norm(self.values_dict['ground_truth'], dim=-1) ** 2,
            scatter_idx,
            dim=0,
            reduce='sum'
        )
        return torch.sqrt(squared_error / squared_reference)

    def get_mean_cosine_similarity(self):
        """Return, per group, the mean row-wise cosine similarity of ground truth and prediction."""
        cosine_similarity = torch.nn.functional.cosine_similarity(
            self.values_dict['ground_truth'],
            self.values_dict['prediction'],
            dim=-1
        )
        return scatter(
            cosine_similarity,
            self.values_dict['scatter_idx'],
            dim=0,
            reduce='mean'
        )

    def accuracy_table(self):
        """Build a PrettyTable of per-group metric means and standard deviations.

        Rows: NMAE and approximation error (formatted as percentages), mean
        cosine similarity and MAE (three decimals). Concatenates the
        accumulated batches first; that step is idempotent, so the table can be
        rebuilt after further ``append_values`` calls.
        NOTE(review): ``torch.std`` is the sample std, so a single group
        presumably reports NaN for the standard deviation.
        """
        self.lists_to_tensors()
        mae = self.get_mae()  # computed once, reused for NMAE
        rows = [
            ("NMAE", self.get_nmae(mae), '.2%'),
            ("Approximation Error", self.get_approximation_error(), '.2%'),
            ("Mean Cosine Similarity", self.get_mean_cosine_similarity(), '.3f'),
            ("MAE", mae, '.3f'),
        ]
        table = PrettyTable(["Metric", "Mean", "Standard Deviation"])
        for name, values, spec in rows:
            table.add_row([
                name,
                format(torch.mean(values).item(), spec),
                format(torch.std(values).item(), spec),
            ])
        return table