Skip to content

Commit

Permalink
[Feature] Add Resnet2dAudio (#355)
Browse files Browse the repository at this point in the history
* [Feature] Add resnet2daudio

* [Fix] Runnable now; Add SpecAug

* [Feature] Add AVSlowfast.

* [Fix]: Add norm_cfg.

* [Fix]: Runnable now.

* [Fix] Fix some bugs.

* [Fix] Update some performance.

* Add unittest.

Some more unittests.

* Fix code style according to review.

Minor.

Fix bugs in unittest.=

Minor adding.

* Better docstrings.

Minor.

* Remove code about AVSlowfast.

* Fix unittest.

* More docstrings

* Fix README.md.

* Fix.

Minor.

* Fix unittest
  • Loading branch information
su authored Nov 30, 2020
1 parent 81450c8 commit cd46a1b
Show file tree
Hide file tree
Showing 14 changed files with 632 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# model settings
model = dict(
type='AudioRecognizer',
backbone=dict(type='ResNet', depth=50, in_channels=1, norm_eval=False),
backbone=dict(
type='ResNetAudio',
depth=50,
pretrained=None,
in_channels=1,
norm_eval=False),
cls_head=dict(
type='AudioTSNHead',
num_classes=400,
in_channels=2048,
in_channels=1024,
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
Expand Down Expand Up @@ -45,15 +50,15 @@
type='SampleFrames',
clip_len=64,
frame_interval=1,
num_clips=1,
num_clips=10,
test_mode=True),
dict(type='AudioFeatureSelector'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
data = dict(
videos_per_gpu=320,
videos_per_gpu=160,
workers_per_gpu=4,
train=dict(
type=dataset_type,
Expand All @@ -72,7 +77,7 @@
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.1, momentum=0.9,
type='SGD', lr=2.0, momentum=0.9,
weight_decay=0.0001) # this lr is used for 8 gpus
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
Expand All @@ -82,15 +87,16 @@
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'])
log_config = dict(
interval=20,
interval=1,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsn_r50_64x1x1_100e_kinetics400_audio_feature/'
work_dir = ('./work_dirs/' +
'audioonly_r50_64x1x1_100e_kinetics400_audio_feature/')
load_from = None
resume_from = None
workflow = [('train', 1)]
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
**New Features**
- Automatically add modelzoo statistics to readthedocs ([#327](https://github.com/open-mmlab/mmaction2/pull/327))
- Support GYM99 data preparation ([#331](https://github.com/open-mmlab/mmaction2/pull/331))
- Add AudioOnly Pathway from AVSlowFast. ([#355](https://github.com/open-mmlab/mmaction2/pull/355))
- Add GradCAM utils for recognizer ([#324](https://github.com/open-mmlab/mmaction2/pull/324))
- Add print config script ([#345](https://github.com/open-mmlab/mmaction2/pull/345))
- Add online motion vector decoder ([#291](https://github.com/open-mmlab/mmaction2/pull/291))
Expand Down
3 changes: 2 additions & 1 deletion mmaction/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .activitynet_dataset import ActivityNetDataset
from .audio_dataset import AudioDataset
from .audio_feature_dataset import AudioFeatureDataset
from .audio_visual_dataset import AudioVisualDataset
from .ava_dataset import AVADataset
from .base import BaseDataset
from .builder import build_dataloader, build_dataset
Expand All @@ -16,5 +17,5 @@
'VideoDataset', 'build_dataloader', 'build_dataset', 'RepeatDataset',
'RawframeDataset', 'BaseDataset', 'ActivityNetDataset', 'SSNDataset',
'HVUDataset', 'AudioDataset', 'AudioFeatureDataset', 'ImageDataset',
'RawVideoDataset', 'AVADataset'
'RawVideoDataset', 'AVADataset', 'AudioVisualDataset'
]
76 changes: 76 additions & 0 deletions mmaction/datasets/audio_visual_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import os.path as osp

from .rawframe_dataset import RawframeDataset
from .registry import DATASETS


@DATASETS.register_module
class AudioVisualDataset(RawframeDataset):
"""Dataset that reads both audio and visual data, supporting both rawframes
and videos. The annotation file is same as that of the rawframe dataset,
such as:
.. code-block:: txt
some/directory-1 163 1
some/directory-2 122 1
some/directory-3 258 2
some/directory-4 234 2
some/directory-5 295 3
some/directory-6 121 3
Args:
ann_file (str): Path to the annotation file.
pipeline (list[dict | callable]): A sequence of data transforms.
audio_prefix (str): Directory of the audio files.
kwargs (dict): Other keyword args for `RawframeDataset`. `video_prefix`
is also allowed if pipeline is designed for videos.
"""

def __init__(self, ann_file, pipeline, audio_prefix, **kwargs):
self.audio_prefix = audio_prefix
self.video_prefix = kwargs.pop('video_prefix', None)
self.data_prefix = kwargs.get('data_prefix', None)
super().__init__(ann_file, pipeline, **kwargs)

def load_annotations(self):
video_infos = []
with open(self.ann_file, 'r') as fin:
for line in fin:
line_split = line.strip().split()
video_info = {}
idx = 0
# idx for frame_dir
frame_dir = line_split[idx]
if self.audio_prefix is not None:
audio_path = osp.join(self.audio_prefix,
frame_dir + '.npy')
video_info['audio_path'] = audio_path
if self.video_prefix:
video_path = osp.join(self.video_prefix,
frame_dir + '.mp4')
video_info['filename'] = video_path
if self.data_prefix is not None:
frame_dir = osp.join(self.data_prefix, frame_dir)
video_info['frame_dir'] = frame_dir
idx += 1
if self.with_offset:
# idx for offset and total_frames
video_info['offset'] = int(line_split[idx])
video_info['total_frames'] = int(line_split[idx + 1])
idx += 2
else:
# idx for total_frames
video_info['total_frames'] = int(line_split[idx])
idx += 1
# idx for label[s]
label = [int(x) for x in line_split[idx:]]
assert len(label), f'missing label in line: {line}'
if self.multi_class:
assert self.num_classes is not None
video_info['label'] = label
else:
assert len(label) == 1
video_info['label'] = label[0]
video_infos.append(video_info)
return video_infos
8 changes: 4 additions & 4 deletions mmaction/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .backbones import (C3D, X3D, ResNet, ResNet2Plus1d, ResNet3d, ResNet3dCSN,
ResNet3dSlowFast, ResNet3dSlowOnly, ResNetTIN,
ResNetTSM)
ResNet3dSlowFast, ResNet3dSlowOnly, ResNetAudio,
ResNetTIN, ResNetTSM)
from .builder import (build_backbone, build_head, build_localizer, build_loss,
build_model, build_neck, build_recognizer)
from .common import Conv2plus1d
from .common import Conv2plus1d, ConvAudio
from .heads import (AudioTSNHead, BaseHead, I3DHead, SlowFastHead, TPNHead,
TSMHead, TSNHead, X3DHead)
from .localizers import BMN, PEM, TEM
Expand All @@ -25,5 +25,5 @@
'PEM', 'TEM', 'BinaryLogisticRegressionLoss', 'BMN', 'BMNLoss',
'build_model', 'OHEMHingeLoss', 'SSNLoss', 'ResNet3dCSN', 'ResNetTIN',
'TPN', 'TPNHead', 'build_loss', 'build_neck', 'AudioRecognizer',
'AudioTSNHead', 'X3D', 'X3DHead'
'AudioTSNHead', 'X3D', 'X3DHead', 'ResNetAudio', 'ConvAudio'
]
4 changes: 3 additions & 1 deletion mmaction/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from .resnet3d_csn import ResNet3dCSN
from .resnet3d_slowfast import ResNet3dSlowFast
from .resnet3d_slowonly import ResNet3dSlowOnly
from .resnet_audio import ResNetAudio
from .resnet_tin import ResNetTIN
from .resnet_tsm import ResNetTSM
from .x3d import X3D

__all__ = [
'C3D', 'ResNet', 'ResNet3d', 'ResNetTSM', 'ResNet2Plus1d',
'ResNet3dSlowFast', 'ResNet3dSlowOnly', 'ResNet3dCSN', 'ResNetTIN', 'X3D'
'ResNet3dSlowFast', 'ResNet3dSlowOnly', 'ResNet3dCSN', 'ResNetTIN', 'X3D',
'ResNetAudio'
]
Loading

0 comments on commit cd46a1b

Please sign in to comment.