Skip to content

Commit

Permalink
Add example for ACM MM.
Browse files Browse the repository at this point in the history
  • Loading branch information
agkphysics committed Jun 9, 2023
1 parent b3f4f66 commit 6ac4a43
Show file tree
Hide file tree
Showing 58 changed files with 461 additions and 118 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ more information about the supported datasets and the required
processing.


# Examples
See the `examples` directory for examples.


# Papers
Papers that we have published will have associated code in the `papers`
directory. See [`papers/README.md`](papers/README.md) for more
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions conf/clf/tf/aldeneh2017/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
learning_rate: 1e-5
2 changes: 2 additions & 0 deletions conf/clf/tf/zhao2019/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
steps: 512
learning_rate: 1e-5
5 changes: 5 additions & 0 deletions examples/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREMA-D/
EMO-DB/
RAVDESS/
results/
logs/
30 changes: 30 additions & 0 deletions examples/CREMA-D.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: CREMA-D
data:
datasets:
CREMA-D:
path: CREMA-D/corpus.yaml
features: wav2vec_c_mean
model:
type: sk/lr
config: ${cwdpath:../conf/clf/sk/lr/default.yaml}
training:
normalise: online
transform: std
seq_transform: feature
tensorflow:
batch_size: 32
epochs: 50
logging:
log_dir: logs/tf
data_fn: null
pytorch:
batch_size: 32
epochs: 50
logging:
log_dir: logs/pt
eval:
cv:
part: speaker
kfold: 10
inner_kfold: 2
inner_part: speaker
30 changes: 30 additions & 0 deletions examples/EMO-DB.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: EMO-DB
data:
datasets:
EMO-DB:
path: EMO-DB/corpus.yaml
features: wav2vec_c_mean
model:
type: sk/lr
config: ${cwdpath:../conf/clf/sk/lr/default.yaml}
training:
normalise: online
transform: std
seq_transform: feature
tensorflow:
batch_size: 32
epochs: 50
logging:
log_dir: logs/tf
data_fn: null
pytorch:
batch_size: 32
epochs: 50
logging:
log_dir: logs/pt
eval:
cv:
part: speaker
kfold: -1
inner_kfold: 2
inner_part: speaker
31 changes: 31 additions & 0 deletions examples/RAVDESS.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: RAVDESS
data:
datasets:
RAVDESS:
path: RAVDESS/corpus.yaml
subset: speech
features: wav2vec_c_mean
model:
type: sk/lr
config: ${cwdpath:conf/clf/sk/lr/default.yaml}
training:
normalise: online
transform: std
seq_transform: feature
tensorflow:
batch_size: 32
epochs: 50
logging:
log_dir: logs/tf
data_fn: null
pytorch:
batch_size: 32
epochs: 50
logging:
log_dir: logs/pt
eval:
cv:
part: speaker
kfold: 8
inner_kfold: 2
inner_part: speaker
16 changes: 16 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Examples

This is a basic within-corpus and cross-corpus experimental setup.

## Datasets
Run `ertk-dataset setup` for each of CREMA-D, RAVDESS, EMO-DB:
```
ertk-dataset setup CREMA-D /path/to/CREMA-D ./CREMA-D
```

## Features
Run the `extract_features.sh` script to extract eGeMAPS, Wav2vec, and
log mel spectrogram features.

## Experiments
Run the `run_exps.sh` script to run experiments.
1 change: 1 addition & 0 deletions examples/aldeneh2017.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
learning_rate: 1e-5
43 changes: 43 additions & 0 deletions examples/exp_loco.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
data:
datasets:
CREMA-D:
path: CREMA-D/corpus.yaml
EMO-DB:
path: EMO-DB/corpus.yaml
RAVDESS:
path: RAVDESS/corpus.yaml
subset: speech
remove_groups:
label:
keep:
- anger
- disgust
- fear
- happiness
- neutral
- sadness
features: # Will be set on command line
model: # Will be set on command line
type: _not_set_
config: {}
training:
normalise: online
transform: std
seq_transform: feature
tensorflow:
batch_size: 32
epochs: 50
logging:
log_dir: logs/tf
data_fn: null
pytorch:
batch_size: 32
epochs: 50
logging:
log_dir: logs/pt
eval:
cv:
part: corpus
kfold: -1
inner_kfold: 2
results: "" # Will be set on command line
34 changes: 34 additions & 0 deletions examples/extract_features.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/bin/bash

for dataset in CREMA-D EMO-DB RAVDESS; do
ertk-dataset process \
--processor opensmile \
--n_jobs -1 \
--sample_rate 16000 \
--corpus $dataset \
$dataset/files_all.txt \
$dataset/features/eGeMAPS.nc \
opensmile_config=eGeMAPS
ertk-dataset process \
--processor fairseq \
--sample_rate 16000 \
--corpus $dataset \
$dataset/files_all.txt \
$dataset/features/wav2vec_c_mean.nc \
model_type=wav2vec \
checkpoint=/path/to/wav2vec_large.pt \
layer=context \
aggregate=MEAN
ertk-dataset process \
--processor spectrogram \
--n_jobs -1 \
--sample_rate 16000 \
--corpus $dataset \
$dataset/files_all.txt \
$dataset/features/logmel-0.05-0.025-80.nc \
kind=mel \
window_size=0.05 \
window_shift=0.025 \
n_mels=80 \
to_log=log
done
27 changes: 27 additions & 0 deletions examples/run_exps.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash
export TF_CPP_MIN_LOG_LEVEL=1

# Within-corpus
for dataset in CREMA-D EMO-DB RAVDESS; do
ertk-cli exp2 ${dataset}.yaml data.features=eGeMAPS model.type=sk/lr model.param_grid=\${cwdpath:../conf/clf/sk/lr/grids/default.yaml} results=results_within/lr/eGeMAPS.csv
ertk-cli exp2 ${dataset}.yaml data.features=wav2vec_c_mean model.type=sk/lr model.param_grid=\${cwdpath:../conf/clf/sk/lr/grids/default.yaml} results=results_within/lr/wav2vec.csv
ertk-cli exp2 ${dataset}.yaml data.features=eGeMAPS model.type=sk/svm model.param_grid=\${cwdpath:../conf/clf/sk/svm/grids/rbf.yaml} results=results_within/svm/eGeMAPS.csv
ertk-cli exp2 ${dataset}.yaml data.features=wav2vec_c_mean model.type=sk/svm model.param_grid=\${cwdpath:../conf/clf/sk/svm/grids/rbf.yaml} results=results_within/svm/wav2vec.csv

ertk-cli exp2 exp_loco.yaml training.normalise=none data.features=logmel-0.05-0.025-80 data.pad_seq=100 data.clip_seq=512 model.type=tf/aldeneh2017 model.config=\${cwdpath:aldeneh2017.yaml} results=results_within/aldeneh2017/melspec.csv
ertk-cli exp2 exp_loco.yaml training.normalise=none data.features=logmel-0.05-0.025-80 data.pad_seq=256 data.clip_seq=256 model.type=tf/zhao2019 model.config=\${cwdpath:zhao2019.yaml} results=results_within/zhao2019/melspec.csv
done

# Cross-corpus

# SVM and logistic regression experiments
ertk-cli exp2 exp_loco.yaml data.features=eGeMAPS model.type=sk/lr model.param_grid=\${cwdpath:../conf/clf/sk/lr/grids/default.yaml} results=results_cross/lr/eGeMAPS.csv
ertk-cli exp2 exp_loco.yaml data.features=wav2vec_c_mean model.type=sk/lr model.param_grid=\${cwdpath:../conf/clf/sk/lr/grids/default.yaml} results=results_cross/lr/wav2vec.csv
ertk-cli exp2 exp_loco.yaml data.features=eGeMAPS model.type=sk/svm model.param_grid=\${cwdpath:../conf/clf/sk/svm/grids/rbf.yaml} results=results_cross/svm/eGeMAPS.csv
ertk-cli exp2 exp_loco.yaml data.features=wav2vec_c_mean model.type=sk/svm model.param_grid=\${cwdpath:../conf/clf/sk/svm/grids/rbf.yaml} results=results_cross/svm/wav2vec.csv

# Sequence models
ertk-cli exp2 exp_loco.yaml training.normalise=none data.features=logmel-0.05-0.025-80 data.pad_seq=100 data.clip_seq=512 model.type=tf/aldeneh2017 model.config=\${cwdpath:aldeneh2017.yaml} results=results_cross/aldeneh2017/melspec.csv
ertk-cli exp2 exp_loco.yaml training.normalise=none data.features=logmel-0.05-0.025-80 data.pad_seq=256 data.clip_seq=256 model.type=tf/zhao2019 model.config=\${cwdpath:zhao2019.yaml} results=results_cross/zhao2019/melspec.csv

# ertk-cli exp2 exp_loco.yaml training.normalise=none data.features=logmel-0.05-0.025-80 data.pad_seq=100 data.clip_seq=512 model.type=pt/aldeneh2017 model.config=\${cwdpath:../conf/clf/pt/aldeneh2017/default.yaml}
2 changes: 2 additions & 0 deletions examples/zhao2019.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
steps: 256
learning_rate: 1e-5
38 changes: 22 additions & 16 deletions src/ertk/cli/cli/exp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
from ertk.config import get_arg_mapping
from ertk.dataset import load_datasets_config
from ertk.sklearn.utils import GridSearchVal
from ertk.train import ExperimentConfig, ValidationSplit, get_cv_splitter
from ertk.train import (
ExperimentConfig,
TransformClass,
ValidationSplit,
get_cv_splitter,
)
from ertk.transform import SequenceTransformWrapper


Expand All @@ -39,15 +44,20 @@ def main(config_path: Path, restargs: Tuple[str], verbose: int):

dataset = load_datasets_config(config.data)

transform = config.training.transform.name
transformer = {"std": StandardScaler, "minmax": MinMaxScaler}[transform]()
normalise = config.training.normalise
if normalise == "none":
transformer = None
elif normalise == "online" and len(dataset.x[0].shape) > 1:
transform = config.training.transform
if isinstance(transform, str):
transform = TransformClass[transform]
transformer = {
TransformClass.std: StandardScaler,
TransformClass.minmax: MinMaxScaler,
}[transform]()
if len(dataset.x[0].shape) > 1:
transformer = SequenceTransformWrapper(
transformer, config.training.seq_transform
)
normalise = config.training.normalise
if normalise == "none":
transformer = None
elif normalise != "online":
dataset.normalise(normaliser=transformer, partition=normalise)
transformer = None
Expand Down Expand Up @@ -142,9 +152,7 @@ def main(config_path: Path, restargs: Tuple[str], verbose: int):
fit_params["groups"] = dataset.get_group_indices(evaluation.inner_part)
elif clf_lib == "tf":
from keras.callbacks import TensorBoard
from keras.optimizers import get as get_optimizer

from ertk.tensorflow.classification import tf_classification_metrics
from ertk.tensorflow.models import TFModelConfig, get_tf_model
from ertk.tensorflow.train import TFTrainConfig

Expand Down Expand Up @@ -181,11 +189,6 @@ def main(config_path: Path, restargs: Tuple[str], verbose: int):

def model_fn():
model = get_tf_model(clf_type, **model_config)
model.compile(
get_optimizer("adam", learning_rate=model_config.learning_rate),
loss="sparse_categorical_crossentropy",
weighted_metrics=tf_classification_metrics(),
)
return Pipeline([("transform", transformer), ("clf", model)])

params = {
Expand All @@ -209,8 +212,11 @@ def model_fn():
n_jobs = 1
fit_params.update({"train_config": pt_config})

model_cls = ERTKPyTorchModel.get_model_class(clf_type)
model_config = model_cls.get_config_type().from_config(config.model.config)
model_config = (
ERTKPyTorchModel.get_model_class(clf_type)
.get_config_type()
.from_config(config.model.config)
)
if model_config.n_features == -1:
model_config.n_features = dataset.n_features
elif model_config.n_features != dataset.n_features:
Expand Down
12 changes: 6 additions & 6 deletions src/ertk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@
T = TypeVar("T", bound="ERTKConfig")


def resolve_files(key):
def _resolve_files(key): # pragma: no cover
return key


def resolve_files_load(key):
def _resolve_files_load(key): # pragma: no cover
return OmegaConf.load(key)


OmegaConf.register_resolver("file", resolve_files)
OmegaConf.register_resolver("cwdpath", resolve_files_load)
OmegaConf.register_resolver("file", _resolve_files)
OmegaConf.register_resolver("cwdpath", _resolve_files_load)


@dataclass
Expand Down Expand Up @@ -144,7 +144,7 @@ def merge_with_args(self: T, args: Optional[List[str]] = None) -> T:
return cast(T, OmegaConf.merge(self, OmegaConf.from_cli(args)))


def get_arg_mapping(s: Union[Path, str]) -> Dict[str, Any]:
def get_arg_mapping(s: Union[Path, str]) -> Dict[str, str]:
"""Given a mapping on the command-line, returns a dict representing
that mapping. Mapping can be a string or a more complex YAML file.
Expand All @@ -160,7 +160,7 @@ def get_arg_mapping(s: Union[Path, str]) -> Dict[str, Any]:
Returns
-------
mapping: dict
dict
A dictionary mapping keys to values from the string.
"""
if isinstance(s, Path) or Path(s).exists():
Expand Down
2 changes: 1 addition & 1 deletion src/ertk/dataset/predefined/AESDD/corpus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ subsets:
clips: files_all
description: All files
default_subset: all
features_dir: ../../features/AESDD
features_dir: features
2 changes: 1 addition & 1 deletion src/ertk/dataset/predefined/ASED/corpus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ subsets:
clips: files_all
description: All files
default_subset: all
features_dir: ../../features/ASED
features_dir: features
2 changes: 1 addition & 1 deletion src/ertk/dataset/predefined/BAVED/corpus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ subsets:
clips: files_all
description: All files
default_subset: all
features_dir: ../../features/BAVED
features_dir: features
2 changes: 1 addition & 1 deletion src/ertk/dataset/predefined/CMU-MOSEI/corpus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ subsets:
clips: files_labels
description: Clips with emotion and sentiment information
default_subset: labels
features_dir: ../../features/CMU-MOSEI
features_dir: features
2 changes: 1 addition & 1 deletion src/ertk/dataset/predefined/CREMA-D/corpus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ subsets:
clips: files_all
description: All files
default_subset: all
features_dir: ../../features/CREMA-D
features_dir: features
2 changes: 1 addition & 1 deletion src/ertk/dataset/predefined/CaFE/corpus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ subsets:
clips: files_all
description: All files
default_subset: all
features_dir: ../../features/CaFE
features_dir: features
Loading

0 comments on commit 6ac4a43

Please sign in to comment.