# netvladlayer.py
import tensorflow as tf
from keras import layers
from keras.layers import Conv2D


class NetVLAD(layers.Layer):
    """Keras layer implementing NetVLAD: aggregates dense local descriptors of
    shape (batch, H, W, D) into a single normalized VLAD descriptor per image.
    """
    def __init__(self, num_clusters, assign_weight_initializer=None,
                 cluster_initializer=None, skip_postnorm=False, **kwargs):
        self.K = num_clusters
        self.assign_weight_initializer = assign_weight_initializer
        self.skip_postnorm = skip_postnorm
        # Output dimension is K * D (e.g. 64 clusters * 512-D descriptors = 32768).
        self.outdim = 32768
        super(NetVLAD, self).__init__(**kwargs)
    def build(self, input_shape):
        self.D = input_shape[-1]
        # Cluster centers, shaped so they broadcast against the
        # (batch, H, W, D, K) residual tensor built in call().
        self.C = self.add_weight(name='cluster_centers',
                                 shape=(1, 1, 1, self.D, self.K),
                                 initializer='zeros',
                                 dtype='float32',
                                 trainable=True)
        # 1x1 convolution producing the per-cluster soft-assignment logits.
        self.conv = Conv2D(filters=self.K, kernel_size=1, strides=(1, 1),
                           use_bias=False, padding='valid',
                           kernel_initializer='zeros')
        self.conv.build(input_shape)
        # Might be necessary for older Keras versions where the weights of conv
        # are not automatically added to the trainable_weights of this layer:
        # self._trainable_weights.append(self.conv.trainable_weights[0])
        super(NetVLAD, self).build(input_shape)  # Be sure to call this at the end.
    def call(self, inputs):
        # Soft-assignment of each local descriptor to the K clusters.
        s = self.conv(inputs)
        a = tf.nn.softmax(s)

        # Dims used hereafter: batch, H, W, desc_coeff, cluster.
        # Move the cluster assignment to its own trailing dimension.
        a = tf.expand_dims(a, -2)

        # VLAD core: assignment-weighted residuals, summed over all spatial locations.
        v = tf.expand_dims(inputs, -1) + self.C
        v = a * v
        v = tf.reduce_sum(v, axis=[1, 2])
        v = tf.transpose(v, perm=[0, 2, 1])

        if not self.skip_postnorm:
            # The result seems to be very sensitive to the normalization details,
            # so we stick to matconvnet-style normalization here: intra-normalize
            # each cluster's residual, then L2-normalize the flattened vector.
            v = self.matconvnetNormalize(v, 1e-12)
            v = tf.transpose(v, perm=[0, 2, 1])
            v = self.matconvnetNormalize(tf.layers.flatten(v), 1e-12)

        return v
    def matconvnetNormalize(self, inputs, epsilon):
        # L2-normalize along the last axis; matconvnet adds epsilon inside the sqrt.
        return inputs / tf.sqrt(tf.reduce_sum(inputs ** 2, axis=-1, keepdims=True)
                                + epsilon)
    def compute_output_shape(self, input_shape):
        return tuple([None, self.outdim])
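

# A minimal usage sketch, not part of the original file: it assumes TF 1.x with
# standalone Keras, VGG16-style 512-D descriptor maps, and K=64 clusters so the
# output matches self.outdim (64 * 512 = 32768). The input size and variable
# names below are illustrative only; in practice the weights would be loaded
# from a pretrained checkpoint rather than left at their zero initialization.
if __name__ == '__main__':
    from keras.layers import Input
    from keras.models import Model

    descriptors = Input(shape=(30, 40, 512))       # dense local descriptor map
    vlad = NetVLAD(num_clusters=64)(descriptors)   # (batch, 32768) VLAD descriptor
    model = Model(inputs=descriptors, outputs=vlad)
    model.summary()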