Skip to content

Commit

Permalink
[Bug] Add missing 'loss_aux' and related unittest (#683)
Browse files Browse the repository at this point in the history
* fix aux_loss bug

* add related unittest
  • Loading branch information
dreamerlin authored Mar 7, 2021
1 parent 471e5ea commit ee34d95
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
8 changes: 3 additions & 5 deletions mmaction/models/recognizers/recognizer2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def forward_train(self, imgs, labels, **kwargs):
each.shape[1:]).transpose(1, 2).contiguous()
for each in x
]
x, _ = self.neck(x, labels.squeeze())
x, loss_aux = self.neck(x, labels.squeeze())
x = x.squeeze(2)
num_segs = 1
losses.update(loss_aux)

cls_score = self.cls_head(x, num_segs)
gt_labels = labels.squeeze()
Expand All @@ -40,18 +41,15 @@ def _do_test(self, imgs):
imgs = imgs.reshape((-1, ) + imgs.shape[2:])
num_segs = imgs.shape[0] // batches

losses = dict()

x = self.extract_feat(imgs)
if hasattr(self, 'neck'):
x = [
each.reshape((-1, num_segs) +
each.shape[1:]).transpose(1, 2).contiguous()
for each in x
]
x, loss_aux = self.neck(x)
x, _ = self.neck(x)
x = x.squeeze(2)
losses.update(loss_aux)
num_segs = 1

# When using `TSNHead` or `TPNHead`, shape is [batch_size, num_classes]
Expand Down
1 change: 1 addition & 0 deletions tests/test_models/test_recognizers/test_recognizer2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def test_tpn():

losses = recognizer(imgs, gt_labels)
assert isinstance(losses, dict)
assert 'loss_aux' in losses and 'loss_cls' in losses

# Test forward test
with torch.no_grad():
Expand Down

0 comments on commit ee34d95

Please sign in to comment.