[Feature] MultiModality: Audio (open-mmlab#205)
* First commit.

* Refactor data pipeline and audio backbone.

* Add an audio mel-spectrogram extraction tool (see the sketch after this commit log).

* Enable splitting into parts and use multi-task multiprocessing.

* Add audio feature loader and selector.

* Add basic tools and configs; minor fixes.

* Add tutorial.

* Add unittest.

* Fix unittest bugs.

* Revise unittest.

* Add soundfile lib.

* Revise according to review.

* Improve coverage.

* Update Aliyun link.

* Fix typos and add docstrings.

* Fix typo.

* Fix cases when the audio file is not found; change padding to constant padding to avoid a try-catch when the array is empty.

Fix a default value.

* Modify the config structure and fix a bug.

* Fix random input and fix typo.

* Decrease the sample rate of test.wav to reduce its file size.

* Better padding strategy.

* Add average-clips to default config.
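
As context for the mel-spectrogram tool added above, extracting a log-mel feature could look like the following minimal sketch. It assumes librosa and illustrative parameters; it is not the repository's actual tool.

```python
import librosa
import numpy as np

# Illustrative parameters: n_fft=1024 matches the README's model zoo table,
# the rest are assumptions, not the tool's actual defaults.
def extract_logmel(wav_path, out_path, sr=16000, n_fft=1024,
                   hop_length=512, n_mels=80):
    y, sr = librosa.load(wav_path, sr=sr)
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
    logmel = librosa.power_to_db(mel)  # log-scale the power spectrogram
    # Store as (time, frequency) float32, one .npy per clip.
    np.save(out_path, logmel.T.astype(np.float32))

extract_logmel('test.wav', 'test.npy')
```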
su authored Oct 20, 2020
1 parent 360bac7 commit cfd257a
Showing 31 changed files with 1,966 additions and 31 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/build.yml
@@ -45,8 +45,12 @@ jobs:
- name: Install Pillow
run: pip install Pillow==6.2.2
if: ${{matrix.torchvision == '0.4.2'}}
- name: Install soundfile lib
run: sudo apt-get install -y libsndfile1
- name: Install onnx
run: pip install onnx
- name: Install librosa and soundfile
run: pip install librosa soundfile
- name: Install TurboJpeg lib
run: sudo apt-get install -y libturbojpeg
- name: Install PyTorch
@@ -117,6 +121,10 @@ jobs:
if: ${{matrix.torchvision < 0.5}}
- name: Install TurboJpeg lib
run: sudo apt-get install -y libturbojpeg
- name: Install soundfile lib
run: sudo apt-get install -y libsndfile1
- name: Install librosa and soundfile
run: pip install librosa soundfile
- name: Install PyTorch
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install mmaction dependencies
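A quick local import check (illustrative, not part of the workflow) confirms that the newly added audio dependencies resolve:

```python
# Sanity check that libsndfile-backed soundfile and librosa import cleanly.
import librosa
import soundfile

print(librosa.__version__, soundfile.__version__)
```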
63 changes: 63 additions & 0 deletions configs/recognition_audio/resnet/README.md
@@ -0,0 +1,63 @@
# ResNet for Audio

## Introduction
```
@article{xiao2020audiovisual,
title={Audiovisual SlowFast Networks for Video Recognition},
author={Xiao, Fanyi and Lee, Yong Jae and Grauman, Kristen and Malik, Jitendra and Feichtenhofer, Christoph},
journal={arXiv preprint arXiv:2001.08740},
year={2020}
}
```

## Model Zoo

### Kinetics-400

|config | n_fft | gpus | backbone |pretrain| top1 acc| top5 acc | inference_time(video/s) | gpu_mem(M)| ckpt | log| json|
|:--|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
|[tsn_r18_64x1x1_100e_kinetics400_audio_feature](/configs/recognition_audio/resnet/tsn_r18_64x1x1_100e_kinetics400_audio_feature.py)|1024|8| ResNet18 | None |19.7|35.75|x|1897|[ckpt](https://download.openmmlab.com/mmaction/recognition/audio_recognition/tsn_r18_64x1x1_100e_kinetics400_audio_feature/tsn_r18_64x1x1_100e_kinetics400_audio_feature_20201012-bf34df6c.pth)|[log](https://download.openmmlab.com/mmaction/recognition/audio_recognition/tsn_r18_64x1x1_100e_kinetics400_audio_feature/20201010_144630.log)|[json](https://download.openmmlab.com/mmaction/recognition/audio_recognition/tsn_r18_64x1x1_100e_kinetics400_audio_feature/20201010_144630.log.json)|
|[tsn_r18_64x1x1_100e_kinetics400_audio_feature](/configs/recognition_audio/resnet/tsn_r18_64x1x1_100e_kinetics400_audio_feature.py) + [tsn_r50_1x1x3_100e_kinetics400_rgb](/configs/recognition/tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py)|1024|8| ResNet(18+50) | None |70.01|88.71|x|x|x|x|x|

Notes:

1. The **gpus** column indicates the number of GPUs used to train the checkpoint. Note that the provided configs assume 8 GPUs by default.
According to the [Linear Scaling Rule](https://arxiv.org/abs/1706.02677), you should scale the learning rate proportionally to the total batch size if you use a different number of GPUs or videos per GPU,
e.g., lr=0.01 for 4 GPUs * 2 videos/GPU and lr=0.08 for 16 GPUs * 4 videos/GPU (see the sketch after these notes).
2. The **inference_time** is measured with this [benchmark script](/tools/analysis/benchmark.py), using the frame-sampling strategy of the test setting and timing only model inference,
excluding IO and pre-processing. For each setting, we use 1 GPU with a batch size (videos per GPU) of 1.
3. Values in columns named "reference" are results obtained by training in the original repository with the same model settings.
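
For note 1, the scaling arithmetic is simple. A minimal sketch, with an illustrative helper name rather than mmaction2 API:

```python
# Linear Scaling Rule: lr grows proportionally to the total batch size.
def scaled_lr(base_lr, base_gpus, base_videos_per_gpu, gpus, videos_per_gpu):
    return base_lr * (gpus * videos_per_gpu) / (base_gpus * base_videos_per_gpu)

# The audio configs in this PR use lr=0.1 for 8 GPUs x 320 videos/GPU, so e.g.:
print(scaled_lr(0.1, 8, 320, 16, 320))  # 0.2 for 16 GPUs x 320 videos/GPU
```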

For more details on data preparation, you can refer to Kinetics400 in [Data Preparation](/docs/data_preparation.md).

## Train

You can use the following command to train a model.
```shell
python tools/train.py ${CONFIG_FILE} [optional arguments]
```

Example: train the ResNet model on the Kinetics-400 audio dataset deterministically, with periodic validation.
```shell
python tools/train.py configs/recognition_audio/resnet/tsn_r50_64x1x1_100e_kinetics400_audio_feature.py \
--work-dir work_dirs/tsn_r50_64x1x1_100e_kinetics400_audio_feature \
--validate --seed 0 --deterministic
```

For more details, you can refer to **Training setting** part in [getting_started](/docs/getting_started.md#training-setting).

## Test

You can use the following command to test a model.
```shell
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]
```

Example: test the ResNet model on the Kinetics-400 audio dataset and dump the result to a JSON file.
```shell
python tools/test.py configs/recognition_audio/resnet/tsn_r50_64x1x1_100e_kinetics400_audio_feature.py \
checkpoints/SOME_CHECKPOINT.pth --eval top_k_accuracy mean_class_accuracy \
--out result.json
```

For more details, you can refer to **Test a dataset** part in [getting_started](/docs/getting_started.md#test-a-dataset).
96 changes: 96 additions & 0 deletions configs/recognition_audio/resnet/tsn_r18_64x1x1_100e_kinetics400_audio_feature.py
@@ -0,0 +1,96 @@
# model settings
model = dict(
type='AudioRecognizer',
backbone=dict(type='ResNet', depth=18, in_channels=1, norm_eval=False),
cls_head=dict(
type='AudioTSNHead',
num_classes=400,
in_channels=512,
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips='prob')
# dataset settings
dataset_type = 'AudioFeatureDataset'
data_root = 'data/kinetics400/audio_feature_train'
data_root_val = 'data/kinetics400/audio_feature_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_audio_feature.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_audio_feature.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_audio_feature.txt'
train_pipeline = [
dict(type='LoadAudioFeature'),
dict(type='SampleFrames', clip_len=64, frame_interval=1, num_clips=1),
dict(type='AudioFeatureSelector'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
val_pipeline = [
dict(type='LoadAudioFeature'),
dict(
type='SampleFrames',
clip_len=64,
frame_interval=1,
num_clips=1,
test_mode=True),
dict(type='AudioFeatureSelector'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
test_pipeline = [
dict(type='LoadAudioFeature'),
dict(
type='SampleFrames',
clip_len=64,
frame_interval=1,
num_clips=1,
test_mode=True),
dict(type='AudioFeatureSelector'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
data = dict(
videos_per_gpu=320,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.1, momentum=0.9,
weight_decay=0.0001) # this lr is used for 8 gpus
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0)
total_epochs = 100
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsn_r18_64x1x1_100e_kinetics400_audio_feature/'
load_from = None
resume_from = None
workflow = [('train', 1)]
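
For intuition, `FormatAudioShape` with `input_format='NCTF'` arranges the sampled clips as a (clips, channels, time, frequency) array. A rough numpy illustration of the expected layout, not the mmaction2 implementation:

```python
import numpy as np

# One 64-frame clip of mel features; 80 frequency bins is an assumption.
num_clips, clip_len, n_mels = 1, 64, 80
clips = [np.random.rand(clip_len, n_mels).astype(np.float32)
         for _ in range(num_clips)]

# NCTF: (num_clips, channels=1, time, frequency); channels=1 matches
# in_channels=1 in the ResNet backbone above.
audios = np.stack(clips)[:, np.newaxis, :, :]
print(audios.shape)  # (1, 1, 64, 80)
```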
102 changes: 102 additions & 0 deletions configs/recognition_audio/resnet/tsn_r50_64x1x1_100e_kinetics400_audio.py
@@ -0,0 +1,102 @@
# model settings
model = dict(
type='AudioRecognizer',
backbone=dict(type='ResNet', depth=50, in_channels=1, norm_eval=False),
cls_head=dict(
type='AudioTSNHead',
num_classes=400,
in_channels=2048,
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips='prob')
# dataset settings
dataset_type = 'AudioDataset'
data_root = 'data/kinetics400/audios'
data_root_val = 'data/kinetics400/audios'
ann_file_train = 'data/kinetics400/kinetics400_train_list_audio.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_audio.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_audio.txt'
train_pipeline = [
dict(type='AudioDecodeInit'),
dict(type='SampleFrames', clip_len=64, frame_interval=1, num_clips=1),
dict(type='AudioDecode'),
dict(type='AudioAmplify', ratio=1.5),
dict(type='MelLogSpectrogram'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
val_pipeline = [
dict(type='AudioDecodeInit'),
dict(
type='SampleFrames',
clip_len=64,
frame_interval=1,
num_clips=1,
test_mode=True),
dict(type='AudioDecode'),
dict(type='AudioAmplify', ratio=1.5),
dict(type='MelLogSpectrogram'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
test_pipeline = [
dict(type='AudioDecodeInit'),
dict(
type='SampleFrames',
clip_len=64,
frame_interval=1,
num_clips=1,
test_mode=True),
    dict(type='AudioDecode'),
dict(type='AudioAmplify', ratio=1.5),
dict(type='MelLogSpectrogram'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
data = dict(
videos_per_gpu=320,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.1, momentum=0.9,
weight_decay=0.0001) # this lr is used for 8 gpus
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0)
total_epochs = 100
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsn_r50_64x1x1_100e_kinetics400_audio/'
load_from = None
resume_from = None
workflow = [('train', 1)]
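
`test_cfg = dict(average_clips='prob')` fuses the scores of multiple test clips in probability space. A minimal sketch of the idea, with hypothetical logits rather than the recognizer's actual code:

```python
import torch
import torch.nn.functional as F

# Hypothetical logits for 3 test clips over 400 classes.
clip_logits = torch.randn(3, 400)

# average_clips='prob': softmax each clip, then average the probabilities
# to get one score vector per video.
video_score = F.softmax(clip_logits, dim=1).mean(dim=0)
print(video_score.argmax().item())
```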
96 changes: 96 additions & 0 deletions configs/recognition_audio/resnet/tsn_r50_64x1x1_100e_kinetics400_audio_feature.py
@@ -0,0 +1,96 @@
# model settings
model = dict(
type='AudioRecognizer',
backbone=dict(type='ResNet', depth=50, in_channels=1, norm_eval=False),
cls_head=dict(
type='AudioTSNHead',
num_classes=400,
in_channels=2048,
dropout_ratio=0.5,
init_std=0.01))
# model training and testing settings
train_cfg = None
test_cfg = dict(average_clips='prob')
# dataset settings
dataset_type = 'AudioFeatureDataset'
data_root = 'data/kinetics400/audio_feature_train'
data_root_val = 'data/kinetics400/audio_feature_val'
ann_file_train = 'data/kinetics400/kinetics400_train_list_audio_feature.txt'
ann_file_val = 'data/kinetics400/kinetics400_val_list_audio_feature.txt'
ann_file_test = 'data/kinetics400/kinetics400_val_list_audio_feature.txt'
train_pipeline = [
dict(type='LoadAudioFeature'),
dict(type='SampleFrames', clip_len=64, frame_interval=1, num_clips=1),
dict(type='AudioFeatureSelector'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
val_pipeline = [
dict(type='LoadAudioFeature'),
dict(
type='SampleFrames',
clip_len=64,
frame_interval=1,
num_clips=1,
test_mode=True),
dict(type='AudioFeatureSelector'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
test_pipeline = [
dict(type='LoadAudioFeature'),
dict(
type='SampleFrames',
clip_len=64,
frame_interval=1,
num_clips=1,
test_mode=True),
dict(type='AudioFeatureSelector'),
dict(type='FormatAudioShape', input_format='NCTF'),
dict(type='Collect', keys=['audios', 'label'], meta_keys=[]),
dict(type='ToTensor', keys=['audios'])
]
data = dict(
videos_per_gpu=320,
workers_per_gpu=4,
train=dict(
type=dataset_type,
ann_file=ann_file_train,
data_prefix=data_root,
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=ann_file_val,
data_prefix=data_root_val,
pipeline=val_pipeline),
test=dict(
type=dataset_type,
ann_file=ann_file_test,
data_prefix=data_root_val,
pipeline=test_pipeline))
# optimizer
optimizer = dict(
type='SGD', lr=0.1, momentum=0.9,
weight_decay=0.0001) # this lr is used for 8 gpus
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
lr_config = dict(policy='CosineAnnealing', min_lr=0)
total_epochs = 100
checkpoint_config = dict(interval=5)
evaluation = dict(
interval=5, metrics=['top_k_accuracy', 'mean_class_accuracy'], topk=(1, 5))
log_config = dict(
interval=20,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook'),
])
# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/tsn_r50_64x1x1_100e_kinetics400_audio_feature/'
load_from = None
resume_from = None
workflow = [('train', 1)]
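
The annotation lists referenced by these configs (e.g., kinetics400_train_list_audio_feature.txt) are plain text with one sample per line. A plausible sketch of the format, with made-up paths and labels; verify against [Data Preparation](/docs/data_preparation.md):

```python
# Plausible annotation line format (an assumption, not confirmed by this diff):
# <relative feature path> <total frames> <label>
sample_lines = [
    'playing_guitar/video_0001.npy 160 0',
    'barking/video_0002.npy 142 1',
]
```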
3 changes: 3 additions & 0 deletions docs/changelog.md
@@ -2,6 +2,9 @@

### v0.8.0 (master)

**Highlights**
- Support video recognition with audio modality

**Improvements**
- Set default values of 'average_clips' in each config file so that there is no need to set it explicitly during testing in most cases ([#232](https://github.com/open-mmlab/mmaction2/pull/232))
- Support data preparation for Kinetics-600 and Kinetics-700 ([#254](https://github.com/open-mmlab/mmaction2/pull/254))