train.py
import time
import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import lib.config.config as cfg
from lib.datasets import roidb as rdl_roidb
from lib.datasets.factory import get_imdb
from lib.datasets.imdb import imdb as imdb2
from lib.layer_utils.roi_data_layer import RoIDataLayer
from lib.nets.vgg16 import vgg16
from lib.utils.timer import Timer
try:
    import cPickle as pickle
except ImportError:
    import pickle
import os

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    # Horizontally-flipped copies are always appended to augment the data
    print('Appending horizontally-flipped training examples...')
    imdb.append_flipped_images()
    print('done')

    print('Preparing training data...')
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    return imdb.roidb
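
# Note (a sketch of this codebase's lineage, not a guarantee): each roidb
# entry is a dict describing one image; after rdl_roidb.prepare_roidb it
# carries fields such as 'image', 'width', 'height', 'max_classes' and
# 'max_overlaps' alongside the ground-truth boxes and classes.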

def combined_roidb(imdb_names):
    """Combine multiple roidbs; dataset names are joined with '+'."""

    def get_roidb(imdb_name):
        imdb = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(imdb.name))
        imdb.set_proposal_method("gt")
        print('Set proposal method: {:s}'.format("gt"))
        roidb = get_training_roidb(imdb)
        return roidb

    roidbs = [get_roidb(s) for s in imdb_names.split('+')]
    roidb = roidbs[0]
    if len(roidbs) > 1:
        for r in roidbs[1:]:
            roidb.extend(r)
        tmp = get_imdb(imdb_names.split('+')[1])
        imdb = imdb2(imdb_names, tmp.classes)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, roidb
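
# A minimal usage sketch: names joined with '+' have their roidbs
# concatenated into one training database (the second dataset name below is
# illustrative; any name registered in lib.datasets.factory works):
#   imdb, roidb = combined_roidb("voc_2007_trainval+voc_2012_trainval")
# A single name loads just that dataset, as Train.__init__ does:
#   imdb, roidb = combined_roidb("voc_2007_trainval")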

class Train:

    def __init__(self):
        # Create network
        if cfg.FLAGS.network == 'vgg16':
            self.net = vgg16(batch_size=cfg.FLAGS.ims_per_batch)
        else:
            raise NotImplementedError

        self.imdb, self.roidb = combined_roidb("voc_2007_trainval")
        self.data_layer = RoIDataLayer(self.roidb, self.imdb.num_classes)
        self.output_dir = cfg.get_output_dir(self.imdb, 'default')

    def train(self):
        # Create a session that grows GPU memory allocation as needed
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        sess = tf.Session(config=tfconfig)

        with sess.graph.as_default():
            tf.set_random_seed(cfg.FLAGS.rng_seed)
            layers = self.net.create_architecture(sess, "TRAIN", self.imdb.num_classes, tag='default')
            loss = layers['total_loss']
            lr = tf.Variable(cfg.FLAGS.learning_rate, trainable=False)
            momentum = cfg.FLAGS.momentum
            optimizer = tf.train.MomentumOptimizer(lr, momentum)

            gvs = optimizer.compute_gradients(loss)

            # Double the gradient of the bias terms if configured
            if cfg.FLAGS.double_bias:
                final_gvs = []
                with tf.variable_scope('Gradient_Mult'):
                    for grad, var in gvs:
                        scale = 1.
                        if '/biases:' in var.name:
                            scale *= 2.
                        if not np.allclose(scale, 1.0):
                            grad = tf.multiply(grad, scale)
                        final_gvs.append((grad, var))
                train_op = optimizer.apply_gradients(final_gvs)
            else:
                train_op = optimizer.apply_gradients(gvs)

            # We will handle the snapshots ourselves
            self.saver = tf.train.Saver(max_to_keep=100000)

            # Write the train and validation information to tensorboard
            # writer = tf.summary.FileWriter(self.tbdir, sess.graph)
            # valwriter = tf.summary.FileWriter(self.tbvaldir)

        # Load weights: train fresh, directly from ImageNet weights
        print('Loading initial model weights from {:s}'.format(cfg.FLAGS.pretrained_model))
        variables = tf.global_variables()

        # Initialize all variables first
        sess.run(tf.variables_initializer(variables, name='init'))
        var_keep_dic = self.get_variables_in_checkpoint_file(cfg.FLAGS.pretrained_model)

        # Get the variables to restore, ignoring the variables to fix
        variables_to_restore = self.net.get_variables_to_restore(variables, var_keep_dic)
        restorer = tf.train.Saver(variables_to_restore)
        restorer.restore(sess, cfg.FLAGS.pretrained_model)
        print('Loaded.')

        # Fix the remaining variables after loading: the RGB conv weights are
        # converted to BGR, and for VGG16 the fc6 and fc7 convolutional
        # weights are reshaped into fully connected weights
        self.net.fix_variables(sess, cfg.FLAGS.pretrained_model)
        print('Fixed.')

        sess.run(tf.assign(lr, cfg.FLAGS.learning_rate))
        last_snapshot_iter = 0

        timer = Timer()
        iter = last_snapshot_iter + 1
        last_summary_time = time.time()
        while iter < cfg.FLAGS.max_iters + 1:
            # Reduce the learning rate once, after step_size iterations
            if iter == cfg.FLAGS.step_size + 1:
                # Add snapshot here before reducing the learning rate
                # self.snapshot(sess, iter)
                sess.run(tf.assign(lr, cfg.FLAGS.learning_rate * cfg.FLAGS.gamma))

            timer.tic()
            # Get training data, one batch at a time
            blobs = self.data_layer.forward()

            # Compute the graph without summary
            try:
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = self.net.train_step(sess, blobs, train_op)
            except Exception:
                # If an error is encountered, the image is skipped without
                # increasing the iteration count
                print('image invalid, skipping')
                continue
            timer.toc()

            iter += 1

            # Display training information
            if iter % cfg.FLAGS.display == 0:
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n ' %
                      (iter, cfg.FLAGS.max_iters, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box))
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if iter % cfg.FLAGS.snapshot_iterations == 0:
                self.snapshot(sess, iter)

    def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.")

    def snapshot(self, sess, iter):
        net = self.net
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        # Store the model snapshot
        filename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.ckpt'
        filename = os.path.join(self.output_dir, filename)
        self.saver.save(sess, filename)
        print('Wrote snapshot to: {:s}'.format(filename))

        # Also store some meta information, random state, etc.
        nfilename = 'vgg16_faster_rcnn_iter_{:d}'.format(iter) + '.pkl'
        nfilename = os.path.join(self.output_dir, nfilename)
        # current state of numpy random
        st0 = np.random.get_state()
        # current position in the database
        cur = self.data_layer._cur
        # current shuffled indices of the database
        perm = self.data_layer._perm

        # Dump the meta info
        with open(nfilename, 'wb') as fid:
            pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
            pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

        return filename, nfilename
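
    # A minimal restore sketch (not part of this file; assumes a snapshot
    # written by the method above). The .pkl stores, in dump order, the numpy
    # RNG state, the data layer cursor, its permutation, and the iteration:
    #   self.saver.restore(sess, filename)
    #   with open(nfilename, 'rb') as fid:
    #       st0 = pickle.load(fid)
    #       cur = pickle.load(fid)
    #       perm = pickle.load(fid)
    #       last_iter = pickle.load(fid)
    #   np.random.set_state(st0)
    #   self.data_layer._cur = cur
    #   self.data_layer._perm = perm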

if __name__ == '__main__':
    train = Train()
    train.train()