-
Notifications
You must be signed in to change notification settings - Fork 0
/
gnn_reranking.py
60 lines (46 loc) · 2.11 KB
/
gnn_reranking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
Paper: https://arxiv.org/abs/2012.07620v2
======================================================================
On the Market-1501 dataset, we accelerate re-ranking from 89.2 s to 9.4 ms
on a single K40m GPU, enabling real-time post-processing. We also observe that
our method achieves comparable or even better retrieval results on four other
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
at limited time cost.
"""
import torch
import numpy as np
import build_adjacency_matrix
import gnn_propagate
from sklearn.metrics.pairwise import cosine_similarity
def gnn_reranking(X_q, X_g, k1, k2):
    """Re-rank gallery samples for each query by propagating over neighbour graphs.

    Args:
        X_q: Query features. Concatenated with ``X_g`` along dim 0, so both
            must agree on every other dimension (presumably ``(num, dim)``
            embedding matrices -- TODO confirm against callers).
        X_g: Gallery features.
        k1: Number of most-similar samples kept per row in the initial ranking
            (``topk`` over the joint query+gallery similarity matrix).
        k2: Number of those top neighbours fed to the propagation stage.
            ``k2 == 1`` skips propagation entirely. Presumably ``k2 <= k1``.

    Returns:
        numpy array whose row ``i`` lists gallery indices ordered by
        decreasing cosine similarity to query ``i`` after re-ranking; shape
        ``(num_queries, num_gallery)`` assuming ``build_adjacency_matrix``
        returns one row per input sample.

    Note:
        Intermediate tensors are moved with ``.cuda()``, so a CUDA device is
        required. ``build_adjacency_matrix`` and ``gnn_propagate`` are
        external modules whose exact semantics are not visible here.
    """
    query_num = X_q.shape[0]

    # Joint similarity of every sample (queries first, then gallery) against
    # every other sample. With Y omitted, sklearn compares X with itself, so
    # the features are converted and normalised only once.
    features = torch.cat((X_q, X_g), dim=0).cpu().numpy()
    original_score = torch.from_numpy(cosine_similarity(features))

    # Initial ranking list: top-k1 scores and their column indices, best first.
    S, initial_rank = original_score.topk(k=k1, dim=-1, largest=True, sorted=True)

    # Stage 1: adjacency built from the initial ranking (external extension).
    A = build_adjacency_matrix.forward(initial_rank.float().cuda())
    # Squared scores are what gnn_propagate receives below (presumably as
    # edge weights -- confirm against the extension).
    S = S * S

    # Stage 2: two rounds of propagation over the k2 nearest neighbours,
    # L2-normalising each row of A after every round.
    if k2 != 1:
        for _ in range(2):
            A = gnn_propagate.forward(
                A,
                initial_rank[:, :k2].contiguous().float().cuda(),
                S[:, :k2].contiguous().float().cuda(),
            )
            A_norm = torch.norm(A, p=2, dim=1, keepdim=True)
            A = A.div(A_norm.expand_as(A))

    # Query rows vs. gallery rows of A. sklearn's cosine_similarity normalises
    # rows itself, which matters when k2 == 1 and A was never normalised above.
    similarity = cosine_similarity(
        A[:query_num].cpu().numpy(), A[query_num:].cpu().numpy()
    )
    # Argsort of the negated scores gives indices in descending-similarity order.
    return np.argsort(-similarity, axis=1)