-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathitq.py
107 lines (90 loc) · 2.62 KB
/
itq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import numpy as np
from sklearn.decomposition import PCA
from utils.evaluate import mean_average_precision, pr_curve
def train(
    train_data,
    query_data,
    query_targets,
    retrieval_data,
    retrieval_targets,
    code_length,
    max_iter,
    device,
    topk,
):
    """
    Learn an ITQ rotation on PCA-projected training data and evaluate it.

    The training data is reduced to ``code_length`` dimensions with PCA, then
    an orthogonal rotation ``R`` is refined for ``max_iter`` iterations by
    alternating between binarizing the rotated data and re-solving for the
    rotation. Query and retrieval codes are then generated and scored.

    Args
        train_data(torch.Tensor): Training data.
        query_data(torch.Tensor): Query data.
        query_targets(torch.Tensor): Query targets.
        retrieval_data(torch.Tensor): Retrieval data.
        retrieval_targets(torch.Tensor): Retrieval targets.
        code_length(int): Hash code length.
        max_iter(int): Number of iterations.
        device(torch.device): GPU or CPU.
        topk(int): Calculate top k data points map.

    Returns
        checkpoint(dict): Checkpoint with keys 'qB'/'rB' (query/retrieval
            codes), 'qL'/'rL' (query/retrieval targets), 'pca' (fitted PCA),
            'rotation_matrix', 'P'/'R' (precision/recall returned by
            ``pr_curve``) and 'map' (value returned by
            ``mean_average_precision``).
    """
    # Only the targets are consumed on `device` (by the evaluation helpers).
    # The raw query/retrieval data is only ever fed to scikit-learn on the
    # CPU, so it is not copied to `device` first.
    query_targets = query_targets.to(device)
    retrieval_targets = retrieval_targets.to(device)

    # PCA
    # detach().cpu() so that GPU-resident or grad-tracking inputs can be
    # converted to NumPy.
    pca = PCA(n_components=code_length)
    V = torch.from_numpy(
        pca.fit_transform(train_data.detach().cpu().numpy())
    ).to(device)

    # Initialization: random orthogonal matrix, created with the same dtype
    # as the PCA output so that `V @ R` never mixes float32 and float64.
    R = torch.randn(code_length, code_length, dtype=V.dtype).to(device)
    U, _, _ = torch.svd(R)
    R = U[:, :code_length]

    # Training
    for _ in range(max_iter):
        # Fix R, update the binary codes.
        B = (V @ R).sign()
        # Fix B, update R (orthogonal Procrustes). torch.svd returns W with
        # B^T V = U diag(S) W^T -- i.e. W itself, NOT its transpose -- so the
        # minimizer of ||B - V R||_F is R = W U^T.
        U, _, W = torch.svd(B.t() @ V)
        R = W @ U.t()

    # Evaluate
    # Generate query code and retrieval code
    query_code = generate_code(query_data.cpu(), code_length, R, pca)
    retrieval_code = generate_code(retrieval_data.cpu(), code_length, R, pca)

    # Compute map
    mAP = mean_average_precision(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
        topk,
    )

    # P-R curve
    P, Recall = pr_curve(
        query_code,
        retrieval_code,
        query_targets,
        retrieval_targets,
        device,
    )

    # Save checkpoint
    checkpoint = {
        'qB': query_code,
        'rB': retrieval_code,
        'qL': query_targets,
        'rL': retrieval_targets,
        'pca': pca,
        'rotation_matrix': R,
        'P': P,
        'R': Recall,
        'map': mAP,
    }
    return checkpoint
def generate_code(data, code_length, R, pca):
    """
    Generate hashing code.

    The data is projected with ``pca.transform``, rotated by ``R`` and
    binarized with ``sign``.

    Args
        data(torch.Tensor): Data.
        code_length(int): Hashing code length. Not used by this function;
            kept so existing callers keep working.
        R(torch.Tensor): Rotation matrix.
        pca(object): Fitted PCA object exposing a ``transform`` method that
            accepts and returns a NumPy array.

    Returns
        code(torch.Tensor): Element-wise sign of the rotated PCA projection,
            on the same device and with the same dtype as ``R``.
    """
    # detach().cpu() lets GPU-resident or grad-tracking tensors be handed to
    # NumPy; for plain CPU tensors both calls are no-ops.
    projected = pca.transform(data.detach().cpu().numpy())
    # Match R's dtype as well as its device: the projection may come back as
    # float64 while R is float32, and matmul rejects mixed dtypes.
    projected = torch.from_numpy(projected).to(device=R.device, dtype=R.dtype)
    return (projected @ R).sign()