-
Notifications
You must be signed in to change notification settings - Fork 144
/
human_pose_nn.py
504 lines (371 loc) · 18.5 KB
/
human_pose_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
import tensorflow as tf
import numpy as np
import part_detector
import settings
import utils
import os
from abc import abstractmethod
from functools import lru_cache
from scipy.stats import norm
from inception_resnet_v2 import inception_resnet_v2_arg_scope, inception_resnet_v2
import tensorflow.contrib.layers as layers
slim = tf.contrib.slim

# Root directory for TensorBoard logs, taken from settings.LOGDIR_PATH.
SUMMARY_PATH = settings.LOGDIR_PATH

# Summary collection shared by the network-wide summaries, plus one extra
# collection name per joint ('summary_joint_00' ... 'summary_joint_15').
KEY_SUMMARIES = tf.GraphKeys.SUMMARIES
KEY_SUMMARIES_PER_JOINT = ['summary_joint_{:02d}'.format(joint_idx) for joint_idx in range(16)]
class HumanPoseNN(object):
    """
    The neural network used for pose estimation.

    Base class that builds a TF1 graph around a backbone supplied by a subclass through
    ``pre_process``, ``get_network`` and ``create_summary_from_weights``. It owns:

    * placeholders for input images, per-joint presence / inside-box masks, desired heatmaps
      and desired joint coordinates;
    * sigmoid outputs and a Gaussian-smoothed copy of them used to locate joints;
    * the training loss (``'MSE'`` on sigmoid outputs or ``'SCE'`` sigmoid cross-entropy on logits);
    * Euclidean-distance metrics between predicted and desired joint positions;
    * an RMSProp training op when constructed with ``is_training=True``.

    NOTE(review): the class does not use ``ABCMeta``, so ``@abstractmethod`` is not enforced.
    """

    # Number of body joints; every per-joint placeholder, filter channel and writer is sized by it.
    NUM_JOINTS = 16

    def __init__(self, log_name, heatmap_size, image_size, loss_type = 'SCE', is_training = True):
        """
        Build the graph, open a session and initialize all global variables.

        :param log_name: sub-directory of ``SUMMARY_PATH`` for TensorBoard logs; ``None`` disables summaries.
        :param heatmap_size: height (= width) of the output heatmaps.
        :param image_size: height (= width) of the 3-channel input images.
        :param loss_type: ``'SCE'`` (sigmoid cross-entropy on logits) or ``'MSE'`` (on sigmoid outputs).
        :param is_training: when true, also build ``global_step``, the ``learning_rate`` placeholder
            and the ``optimize`` op.
        :raises NotImplementedError: if ``loss_type`` is neither ``'MSE'`` nor ``'SCE'``.
        """
        tf.set_random_seed(0)

        if loss_type not in {'MSE', 'SCE'}:
            raise NotImplementedError('Loss function should be either MSE or SCE!')

        self.log_name = log_name
        self.heatmap_size = heatmap_size
        self.image_size = image_size
        self.is_train = is_training
        self.loss_type = loss_type

        # Initialize placeholders
        self.input_tensor = tf.placeholder(
            dtype = tf.float32,
            shape = (None, image_size, image_size, 3),
            name = 'input_image')

        self.present_joints = tf.placeholder(
            dtype = tf.float32,
            shape = (None, self.NUM_JOINTS),
            name = 'present_joints')

        self.inside_box_joints = tf.placeholder(
            dtype = tf.float32,
            shape = (None, self.NUM_JOINTS),
            name = 'inside_box_joints')

        self.desired_heatmap = tf.placeholder(
            dtype = tf.float32,
            shape = (None, heatmap_size, heatmap_size, self.NUM_JOINTS),
            name = 'desired_heatmap')

        # Index 0 along axis 1 holds row (y) coordinates, index 1 column (x) coordinates;
        # see _build_euclidean_distance().
        self.desired_points = tf.placeholder(
            dtype = tf.float32,
            shape = (None, 2, self.NUM_JOINTS),
            name = 'desired_points')

        self.network = self.pre_process(self.input_tensor)
        self.network, self.feature_tensor = self.get_network(self.network, is_training)

        self.sigm_network = tf.sigmoid(self.network)
        self.smoothed_sigm_network = self._get_gauss_smoothing_net(self.sigm_network, std = 0.7)

        self.loss_err = self._get_loss_function(loss_type)
        self.euclidean_dist = self._euclidean_dist_err()
        self.euclidean_dist_per_joint = self._euclidean_dist_per_joint_err()

        if is_training:
            self.global_step = tf.Variable(0, name = 'global_step', trainable = False)

            self.learning_rate = tf.placeholder(
                dtype = tf.float32,
                shape = [],
                name = 'learning_rate')

            self.optimize = layers.optimize_loss(loss = self.loss_err,
                                                 global_step = self.global_step,
                                                 learning_rate = self.learning_rate,
                                                 optimizer = tf.train.RMSPropOptimizer(self.learning_rate),
                                                 clip_gradients = 2.0
                                                 )

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

        if log_name is not None:
            self._init_summaries()

    def _cached_op(self, key, builder):
        """
        Return the graph op stored under ``key``, calling ``builder()`` only on first use.

        Inference helpers go through this so repeated calls reuse one set of ops instead of adding
        new nodes to the graph each time (which grows the graph and memory without bound).
        """
        cache = self.__dict__.setdefault('_op_cache', {})
        if key not in cache:
            cache[key] = builder()
        return cache[key]

    def _init_summaries(self):
        """
        Create TensorBoard writers under ``SUMMARY_PATH/log_name/{train,test}``.

        In training mode also create one writer per joint and the merged ``ALL_SUMMARIES`` op.
        """
        if self.is_train:
            logdir = os.path.join(SUMMARY_PATH, self.log_name, 'train')

            self.summary_writer = tf.summary.FileWriter(logdir)
            self.summary_writer_by_points = [tf.summary.FileWriter(os.path.join(logdir, 'point_%02d' % i))
                                             for i in range(self.NUM_JOINTS)]

            # tf.scalar_summary / tf.merge_all_summaries were removed in TF 1.0; use the tf.summary API.
            # Per-joint distances are written directly in write_summary() as tf.Summary protos.
            tf.summary.scalar('Average euclidean distance', self.euclidean_dist, collections = [KEY_SUMMARIES])

            self.create_summary_from_weights()

            self.ALL_SUMMARIES = tf.summary.merge_all(KEY_SUMMARIES)
        else:
            logdir = os.path.join(SUMMARY_PATH, self.log_name, 'test')
            self.summary_writer = tf.summary.FileWriter(logdir)

    def _get_loss_function(self, loss_type):
        """Build and return the scalar loss tensor for ``loss_type`` ('MSE' or 'SCE')."""
        # Map to builders, not results, so only the requested loss subgraph is created.
        loss_builders = {
            'MSE': self._loss_mse,
            'SCE': self._loss_cross_entropy
        }
        return loss_builders[loss_type]()

    @staticmethod
    @lru_cache()
    def _get_gauss_filter(size = 15, std = 1.0, kernel_sum = 1.0):
        """
        Return a 1-D Gaussian kernel laid out as a depthwise filter of shape ``(1, size, NUM_JOINTS, 1)``.

        ``size`` samples of a zero-mean normal pdf with standard deviation ``std`` are taken on
        [-2, 2], normalized to sum to 1 and scaled by ``sqrt(kernel_sum)``, so a horizontal plus a
        vertical pass together have total weight ``kernel_sum``. Every joint channel gets the same kernel.
        The result is cached and shared between callers, so it is returned read-only.
        """
        samples = norm.pdf(np.linspace(-2, 2, size), 0, std)
        samples /= np.sum(samples)
        samples *= kernel_sum ** 0.5

        weights = np.zeros(shape = (1, size, HumanPoseNN.NUM_JOINTS, 1), dtype = np.float32)
        weights[0, :, :, 0] = samples[:, np.newaxis]
        # Guard the lru_cache'd array against accidental in-place modification.
        weights.setflags(write = False)
        return weights

    @staticmethod
    def _get_gauss_smoothing_net(net, size = 15, std = 1.0, kernel_sum = 1.0):
        """Apply the separable Gaussian filter from ``_get_gauss_filter`` to ``net`` (horizontal, then vertical)."""
        filter_h = HumanPoseNN._get_gauss_filter(size, std, kernel_sum)
        filter_v = filter_h.swapaxes(0, 1)

        net = tf.nn.depthwise_conv2d(net, filter = filter_h, strides = [1, 1, 1, 1], padding = 'SAME',
                                     name = 'SmoothingHorizontal')

        net = tf.nn.depthwise_conv2d(net, filter = filter_v, strides = [1, 1, 1, 1], padding = 'SAME',
                                     name = 'SmoothingVertical')

        return net

    def generate_output(self, shape, presented_parts, labels, sigma):
        """
        Build target heatmaps matching ``self.loss_type``.

        'MSE' uses ``utils.get_gauss_heat_map`` (``sigma`` as the Gaussian sigma); 'SCE' uses
        ``utils.get_binary_heat_map`` (``sigma`` as the diameter). Only the requested kind is computed.
        """
        heatmap_builders = {
            'MSE': lambda: utils.get_gauss_heat_map(
                shape = shape, is_present = presented_parts,
                mean = labels, sigma = sigma),
            'SCE': lambda: utils.get_binary_heat_map(
                shape = shape, is_present = presented_parts,
                centers = labels, diameter = sigma)
        }
        return heatmap_builders[self.loss_type]()

    def _adjust_loss(self, loss_err):
        """
        Reduce a per-pixel loss (same shape as ``desired_heatmap``) to a scalar.

        The loss is summed over the two spatial axes, masked by ``present_joints`` and averaged over
        the number of present joints in the batch (NaN/inf when no joint is present).
        """
        # Shape: [batch, joints]
        loss = tf.reduce_sum(loss_err, [1, 2])

        # Stop error propagation of joints that are not presented
        loss = tf.multiply(loss, self.present_joints)

        # Compute average loss of presented joints
        num_of_visible_joints = tf.reduce_sum(self.present_joints)
        loss = tf.reduce_sum(loss) / num_of_visible_joints

        return loss

    def _loss_mse(self):
        """Squared error between the sigmoid output and ``desired_heatmap``."""
        sq = tf.squared_difference(self.sigm_network, self.desired_heatmap)
        return self._adjust_loss(sq)

    def _loss_cross_entropy(self):
        """Sigmoid cross-entropy between the logits and ``desired_heatmap``."""
        ce = tf.nn.sigmoid_cross_entropy_with_logits(logits = self.network, labels = self.desired_heatmap)
        return self._adjust_loss(ce)

    def _heatmap_argmax(self):
        """
        Return ``(x, y)``: per-joint location of the maximum of the smoothed heatmaps.

        Both are float32 tensors of shape [batch, joints] in heatmap pixels. Assuming NHWC layout
        (as the placeholders use), ``x`` indexes columns (axis 2) and ``y`` rows (axis 1).
        """
        def build():
            x = tf.argmax(tf.reduce_max(self.smoothed_sigm_network, 1), 1)
            y = tf.argmax(tf.reduce_max(self.smoothed_sigm_network, 2), 1)
            return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

        return self._cached_op('heatmap_argmax', build)

    def _joint_highest_activations(self):
        """Per-joint maximum of the smoothed sigmoid heatmaps, shape [batch, joints]."""
        return self._cached_op('joint_highest_activations',
                               lambda: tf.reduce_max(self.smoothed_sigm_network, [1, 2]))

    def _joint_positions(self):
        """
        Estimated joint positions, stacked as a [3, batch, joints] tensor of rows (y, x, activation).

        Coordinates are scaled from heatmap to image pixels by ``image_size / heatmap_size``.
        NOTE(review): the activation is the max of the *unsmoothed* sigmoid map while positions come
        from the smoothed map; confirm this is intended.
        """
        def build():
            highest_activation = tf.reduce_max(self.sigm_network, [1, 2])
            x, y = self._heatmap_argmax()
            scale_coef = self.image_size / self.heatmap_size
            return tf.stack([y * scale_coef, x * scale_coef, highest_activation])

        return self._cached_op('joint_positions', build)

    def _euclidean_dist_err(self):
        """Mean Euclidean distance over all joints flagged in ``inside_box_joints``."""
        # Work only with joints that are presented inside frame
        l2_dist = tf.multiply(self.euclidean_distance(), self.inside_box_joints)

        # Compute average loss of presented joints
        num_of_visible_joints = tf.reduce_sum(self.inside_box_joints)
        return tf.reduce_sum(l2_dist) / num_of_visible_joints

    def _euclidean_dist_per_joint_err(self):
        """Per-joint mean Euclidean distance over the batch, counting only inside-box joints; shape [joints]."""
        # Work only with joints that are presented inside frame
        l2_dist = tf.multiply(self.euclidean_distance(), self.inside_box_joints)

        # Average euclidean distance of presented joints
        present_joints = tf.reduce_sum(self.inside_box_joints, 0)
        return tf.reduce_sum(l2_dist, 0) / present_joints

    def _restore(self, checkpoint_path, variables):
        """Restore ``variables`` from ``checkpoint_path`` into the session."""
        saver = tf.train.Saver(variables)
        saver.restore(self.sess, checkpoint_path)

    def _save(self, checkpoint_path, name, variables):
        """Save ``variables`` to ``<checkpoint_path>/<name>.ckpt``, creating the directory if needed."""
        # makedirs also creates missing parents and tolerates an already existing directory.
        os.makedirs(checkpoint_path, exist_ok = True)

        checkpoint_name_path = os.path.join(checkpoint_path, '%s.ckpt' % name)
        saver = tf.train.Saver(variables)
        saver.save(self.sess, checkpoint_name_path)

    def euclidean_distance(self):
        """
        Per-joint Euclidean distance (in heatmap pixels) between the predicted maxima and
        ``desired_points``; shape [batch, joints]. The op is built once and reused.
        """
        return self._cached_op('euclidean_distance', self._build_euclidean_distance)

    def _build_euclidean_distance(self):
        """Construct the op returned by ``euclidean_distance``."""
        x, y = self._heatmap_argmax()

        # Slicing with a scalar index already drops axis 1; no squeeze needed.
        dy = self.desired_points[:, 0, :]
        dx = self.desired_points[:, 1, :]

        sx = tf.squared_difference(x, dx)
        sy = tf.squared_difference(y, dy)

        return tf.sqrt(sx + sy)

    def _run_on_input(self, fetches, x):
        """Evaluate ``fetches`` feeding only the input images ``x``."""
        return self.sess.run(fetches, feed_dict = {
            self.input_tensor: x
        })

    def feed_forward(self, x):
        """Sigmoid heatmaps for input images ``x``."""
        return self._run_on_input(self.sigm_network, x)

    def heat_maps(self, x):
        """Gaussian-smoothed sigmoid heatmaps for input images ``x``."""
        return self._run_on_input(self.smoothed_sigm_network, x)

    def feed_forward_pure(self, x):
        """Raw network output (logits) for input images ``x``."""
        return self._run_on_input(self.network, x)

    def feed_forward_features(self, x):
        """Feature tensor returned by ``get_network`` for input images ``x``."""
        return self._run_on_input(self.feature_tensor, x)

    def test_euclidean_distance(self, x, points, present_joints, inside_box_joints):
        """Mean Euclidean distance (see ``_euclidean_dist_err``) on a labelled batch."""
        err = self.sess.run(self.euclidean_dist, feed_dict = {
            self.input_tensor: x,
            self.desired_points: points,
            self.present_joints: present_joints,
            self.inside_box_joints: inside_box_joints
        })

        return err

    def test_joint_distances(self, x, y):
        """Per-joint Euclidean distances for images ``x`` against desired points ``y``."""
        return self.sess.run(self.euclidean_distance(), feed_dict = {
            self.input_tensor: x,
            self.desired_points: y
        })

    def test_joint_activations(self, x):
        """Per-joint maximum smoothed activation for images ``x``."""
        return self._run_on_input(self._joint_highest_activations(), x)

    def estimate_joints(self, x):
        """Estimated (y, x, activation) per joint for images ``x``; see ``_joint_positions``."""
        return self._run_on_input(self._joint_positions(), x)

    def train(self, x, heatmaps, present_joints, learning_rate, is_inside_box):
        """
        Run one optimization step.

        :raises Exception: if the network was built with ``is_training=False``.
        """
        if not self.is_train:
            raise Exception('Network is not in train mode!')

        self.sess.run(self.optimize, feed_dict = {
            self.input_tensor: x,
            self.desired_heatmap: heatmaps,
            self.present_joints: present_joints,
            self.learning_rate: learning_rate,
            self.inside_box_joints: is_inside_box
        })

    def write_test_summary(self, epoch, loss):
        """Record ``loss`` as 'Average Euclidean Distance' at step ``epoch`` and flush."""
        loss_sum = tf.Summary()
        loss_sum.value.add(
            tag = 'Average Euclidean Distance',
            simple_value = float(loss))
        self.summary_writer.add_summary(loss_sum, epoch)
        self.summary_writer.flush()

    def write_summary(self, inp, desired_points, heatmaps, present_joints, learning_rate, is_inside_box,
                      write_frequency = 20, write_per_joint_frequency = 100):
        """
        Write training summaries for the current global step.

        Every ``write_frequency`` steps the merged ``ALL_SUMMARIES`` op is evaluated on the given batch.
        If the step is also a multiple of ``write_per_joint_frequency``, each joint's mean Euclidean
        distance goes to its own writer. Requires ``_init_summaries`` to have run in training mode.

        Per-joint values are written as ``tf.Summary`` protos (like ``write_test_summary``) so that all
        joints share the tag 'Joint euclidean distance'. ``tf.summary.scalar`` would make repeated
        names unique ('..._1', '..._2', ...).
        """
        step = tf.train.global_step(self.sess, self.global_step)

        if step % write_frequency != 0:
            return

        feed_dict = {
            self.input_tensor: inp,
            self.desired_points: desired_points,
            self.desired_heatmap: heatmaps,
            self.present_joints: present_joints,
            self.learning_rate: learning_rate,
            self.inside_box_joints: is_inside_box
        }

        summary = self.sess.run(self.ALL_SUMMARIES, feed_dict = feed_dict)
        self.summary_writer.add_summary(summary, step)

        if step % write_per_joint_frequency == 0:
            per_joint_dist = self.sess.run(self.euclidean_dist_per_joint, feed_dict = feed_dict)
            for writer, dist in zip(self.summary_writer_by_points, per_joint_dist):
                joint_sum = tf.Summary()
                joint_sum.value.add(
                    tag = 'Joint euclidean distance',
                    simple_value = float(dist))
                writer.add_summary(joint_sum, step)

        for writer in self.summary_writer_by_points:
            writer.flush()

        self.summary_writer.flush()

    @abstractmethod
    def pre_process(self, inp):
        """Transform the raw input image tensor before it enters the network."""
        pass

    @abstractmethod
    def get_network(self, input_tensor, is_training):
        """Build the network; return ``(heatmap_logits, feature_tensor)``."""
        pass

    @abstractmethod
    def create_summary_from_weights(self):
        """Register weight histograms in ``KEY_SUMMARIES`` (may be a no-op)."""
        pass
class HumanPoseIRNetwork(HumanPoseNN):
    """
    Spatial-feature extractor - the first part of our network.

    It is derived from the Inception-ResNet-v2 architecture and modified to generate heatmaps,
    i.e. dense predictions of the 16 body joints.
    """

    FEATURES = 32        # output channels of the 1x1 scoring layer (returned as the feature tensor)
    IMAGE_SIZE = 299     # input resolution
    HEATMAP_SIZE = 289   # 289 = 17 * 17; the upsampling layer below has kernel = stride = 17
    POINT_DIAMETER = 15
    SMOOTH_SIZE = 21     # length of the separable Gaussian smoothing kernel

    def __init__(self, log_name = None, loss_type = 'SCE', is_training = False):
        super().__init__(log_name, self.HEATMAP_SIZE, self.IMAGE_SIZE, loss_type, is_training)

    def pre_process(self, inp):
        """Rescale pixel values from [0, 255] to [-1, 1]."""
        scaled = inp / 255
        return (scaled - 0.5) * 2.0

    def get_network(self, input_tensor, is_training):
        """
        Build Inception-ResNet-v2 followed by the new scoring/upsampling head.

        :return: ``(heatmaps, features)`` - the Gaussian-smoothed 16-channel output of the
            transposed convolution, and the ``FEATURES``-channel output of the 1x1 scoring layer.
        """
        # Pre-trained Inception-ResNet-v2 backbone
        backbone_arg_scope = inception_resnet_v2_arg_scope(batch_norm_decay = 0.999, weight_decay = 0.0001)
        with slim.arg_scope(backbone_arg_scope):
            backbone_out, _ = inception_resnet_v2(input_tensor, is_training = is_training)

        # New head replacing the scoring of the auxiliary tower
        head_weight_decay = 0.0005

        with tf.variable_scope('NewInceptionResnetV2'), tf.variable_scope('AuxiliaryScoring'):
            with slim.arg_scope([layers.convolution2d, layers.convolution2d_transpose],
                                weights_regularizer = slim.l2_regularizer(head_weight_decay),
                                biases_regularizer = slim.l2_regularizer(head_weight_decay),
                                activation_fn = None):
                tf.summary.histogram('Last_layer/activations', backbone_out, [KEY_SUMMARIES])

                # Scoring: dropout, then a 1x1 convolution down to FEATURES channels
                dropped = slim.dropout(backbone_out, 0.7, is_training = is_training, scope = 'Dropout')
                features = layers.convolution2d(dropped, num_outputs = self.FEATURES, kernel_size = 1, stride = 1,
                                                scope = 'Scoring_layer')
                tf.summary.histogram('Scoring_layer/activations', features, [KEY_SUMMARIES])

                # Upsampling: non-overlapping 17x17 transposed convolution to one channel per joint
                upsampled = layers.convolution2d_transpose(features, num_outputs = 16, kernel_size = 17, stride = 17,
                                                           padding = 'VALID', scope = 'Upsampling_layer')
                tf.summary.histogram('Upsampling_layer/activations', upsampled, [KEY_SUMMARIES])

                # Smoothing layer - separable gaussian filters
                heatmaps = super()._get_gauss_smoothing_net(upsampled, size = self.SMOOTH_SIZE, std = 1.0,
                                                            kernel_sum = 0.2)

                return heatmaps, features

    def _checkpoint_variables(self, include_head):
        """Model variables of the backbone and, if ``include_head``, of the new scoring head."""
        model_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'InceptionResnetV2')
        if include_head:
            model_vars += tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'NewInceptionResnetV2/AuxiliaryScoring')
        return model_vars

    def restore(self, checkpoint_path, is_pre_trained_imagenet_checkpoint = False):
        """Restore the backbone, plus the new head unless restoring an ImageNet checkpoint."""
        head_in_checkpoint = not is_pre_trained_imagenet_checkpoint
        super()._restore(checkpoint_path, self._checkpoint_variables(include_head = head_in_checkpoint))

    def save(self, checkpoint_path, name):
        """Save backbone and head variables to ``<checkpoint_path>/<name>.ckpt``."""
        super()._save(checkpoint_path, name, self._checkpoint_variables(include_head = True))

    def create_summary_from_weights(self):
        """Register histograms of head weights/biases and selected AuxLogits parameters."""
        head_params = (
            'Scoring_layer/biases',
            'Upsampling_layer/biases',
            'Scoring_layer/weights',
            'Upsampling_layer/weights',
        )
        with tf.variable_scope('NewInceptionResnetV2/AuxiliaryScoring', reuse = True):
            for param_name in head_params:
                tf.summary.histogram(param_name, tf.get_variable(param_name), [KEY_SUMMARIES])

        aux_logits_params = (
            ('Last_layer/weights', 'Conv2d_2a_5x5/weights'),
            ('Last_layer/beta', 'Conv2d_2a_5x5/BatchNorm/beta'),
            ('Last_layer/moving_mean', 'Conv2d_2a_5x5/BatchNorm/moving_mean'),
        )
        with tf.variable_scope('InceptionResnetV2/AuxLogits', reuse = True):
            for summary_tag, var_name in aux_logits_params:
                tf.summary.histogram(summary_tag, tf.get_variable(var_name), [KEY_SUMMARIES])
class PartDetector(HumanPoseNN):
    """
    Architecture of Part Detector network, as was described in https://arxiv.org/abs/1609.01743
    """

    IMAGE_SIZE = 256
    HEATMAP_SIZE = 256
    POINT_DIAMETER = 11

    def __init__(self, log_name = None, init_from_checkpoint = None, loss_type = 'SCE', is_training = False):
        """
        :param init_from_checkpoint: if given, passed to ``part_detector.init_model_variables`` before
            the graph is built, and the network is then constructed with ``reuse=True``.
        Remaining parameters are forwarded to ``HumanPoseNN.__init__``.
        """
        if init_from_checkpoint is not None:
            part_detector.init_model_variables(init_from_checkpoint, is_training)
            self.reuse = True
        else:
            self.reuse = False

        super().__init__(log_name, self.HEATMAP_SIZE, self.IMAGE_SIZE, loss_type, is_training)

    def pre_process(self, inp):
        """Scale pixel values from [0, 255] to [0, 1]."""
        return inp / 255

    def create_summary_from_weights(self):
        pass

    @staticmethod
    def _scoring_variables():
        """Model variables of the scoring head ('NewHumanPoseResnet/Scoring' scope)."""
        return tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'NewHumanPoseResnet/Scoring')

    def restore(self, checkpoint_path):
        """Restore backbone (all global variables in 'HumanPoseResnet') and scoring-head variables."""
        # GraphKeys.VARIABLES is a deprecated alias of GLOBAL_VARIABLES.
        # NOTE(review): save() stores only MODEL_VARIABLES of 'HumanPoseResnet' while this restores
        # all global ones; confirm checkpoints written by save() contain everything requested here.
        all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope = 'HumanPoseResnet')
        all_vars += self._scoring_variables()
        super()._restore(checkpoint_path, all_vars)

    def save(self, checkpoint_path, name):
        """Save backbone and scoring-head model variables to ``<checkpoint_path>/<name>.ckpt``."""
        all_vars = tf.get_collection(tf.GraphKeys.MODEL_VARIABLES, scope = 'HumanPoseResnet')
        all_vars += self._scoring_variables()
        super()._save(checkpoint_path, name, all_vars)

    def get_network(self, input_tensor, is_training):
        """Build the network via ``part_detector.human_pose_resnet``; return its output and ``end_points['features']``."""
        net_end, end_points = part_detector.human_pose_resnet(input_tensor, reuse = self.reuse, training = is_training)

        return net_end, end_points['features']