recognizer3d.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn

from ..builder import RECOGNIZERS
from .base import BaseRecognizer


@RECOGNIZERS.register_module()
class Recognizer3D(BaseRecognizer):
    """3D recognizer model framework."""

    def forward_train(self, imgs, labels, **kwargs):
        """Defines the computation performed at every call when training."""
        assert self.with_cls_head
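        # merge the batch and clip dimensions:
        # [N, num_clips, C, T, H, W] -> [N * num_clips, C, T, H, W]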
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        losses = dict()

        x = self.extract_feat(imgs)
        if self.with_neck:
            x, loss_aux = self.neck(x, labels.squeeze())
            losses.update(loss_aux)

        cls_score = self.cls_head(x)
        gt_labels = labels.squeeze()
        loss_cls = self.cls_head.loss(cls_score, gt_labels, **kwargs)
        losses.update(loss_cls)

        return losses

    def _do_test(self, imgs):
        """Defines the computation performed at every call when evaluating,
        testing and running gradcam."""
        batches = imgs.shape[0]
        num_segs = imgs.shape[1]
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
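        # when ``max_testing_views`` is set, run the testing views through
        # the backbone in chunks of that size to bound peak GPU memory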
        if self.max_testing_views is not None:
            total_views = imgs.shape[0]
            assert num_segs == total_views, (
                'max_testing_views is only compatible '
                'with batch_size == 1')
            view_ptr = 0
            feats = []
            while view_ptr < total_views:
                batch_imgs = imgs[view_ptr:view_ptr + self.max_testing_views]
                x = self.extract_feat(batch_imgs)
                if self.with_neck:
                    x, _ = self.neck(x)
                feats.append(x)
                view_ptr += self.max_testing_views
            # the extracted feature may be a tuple (e.g. multi-level
            # features), so concatenate the chunks element-wise in that case
            if isinstance(feats[0], tuple):
                len_tuple = len(feats[0])
                feat = [
                    torch.cat([x[i] for x in feats]) for i in range(len_tuple)
                ]
                feat = tuple(feat)
            else:
                feat = torch.cat(feats)
        else:
            feat = self.extract_feat(imgs)
            if self.with_neck:
                feat, _ = self.neck(feat)

        if self.feature_extraction:
            feat_dim = len(feat[0].size()) if isinstance(feat, tuple) else len(
                feat.size())
            assert feat_dim in [
                5, 2
            ], ('Got feature of unknown architecture, '
                'only 3D-CNN-like ([N, in_channels, T, H, W]) and '
                'transformer-like ([N, in_channels]) features are supported.')
            if feat_dim == 5:  # 3D-CNN architecture
                # perform spatio-temporal pooling
                avg_pool = nn.AdaptiveAvgPool3d(1)
                if isinstance(feat, tuple):
                    feat = [avg_pool(x) for x in feat]
                    # concat them
                    feat = torch.cat(feat, dim=1)
                else:
                    feat = avg_pool(feat)
                # squeeze dimensions
                feat = feat.reshape((batches, num_segs, -1))
                # temporal average pooling
                feat = feat.mean(dim=1)
            return feat

        # should have cls_head if not extracting features
        assert self.with_cls_head
        cls_score = self.cls_head(feat)
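        # fuse the num_segs view-level scores into one prediction per sample
        # (the fusion strategy is set by ``average_clips`` in ``test_cfg``)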
        cls_score = self.average_clip(cls_score, num_segs)
        return cls_score

    def forward_test(self, imgs):
        """Defines the computation performed at every call when evaluating and
        testing."""
        return self._do_test(imgs).cpu().numpy()

    def forward_dummy(self, imgs, softmax=False):
        """Used for computing network FLOPs.

        See ``tools/analysis/get_flops.py``.

        Args:
            imgs (torch.Tensor): Input images.

        Returns:
            Tensor: Class score.
        """
        assert self.with_cls_head
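        # flatten clips into the batch dimension, as in ``forward_train``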
        imgs = imgs.reshape((-1, ) + imgs.shape[2:])
        x = self.extract_feat(imgs)

        if self.with_neck:
            x, _ = self.neck(x)

        outs = self.cls_head(x)
        if softmax:
            # pass an explicit dim; calling softmax without one is deprecated
            outs = nn.functional.softmax(outs, dim=1)
        return (outs, )

    def forward_gradcam(self, imgs):
        """Defines the computation performed at every call when using gradcam
        utils."""
        assert self.with_cls_head
        return self._do_test(imgs)
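

# A minimal usage sketch (hypothetical config/checkpoint paths; assumes the
# mmaction2 high-level API, which builds a registered recognizer such as this
# one from its config):
#
#     from mmaction.apis import inference_recognizer, init_recognizer
#
#     model = init_recognizer('configs/recognition/i3d/i3d_config.py',
#                             'checkpoints/i3d.pth', device='cuda:0')
#     results = inference_recognizer(model, 'demo.mp4')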