Custom 🤗 Transformers for training multi-task wav2vec2 models that perform ASR and speech classification tasks simultaneously as described in Getman, Y., Al-Ghezi, R., Grósz, T., Kurimo, M. (2023) Multi-task wav2vec2 Serving as a Pronunciation Training System for Children.

Multi-task wav2vec2

The best multi-task wav2vec2 models presented in the article are available on the 🤗 Hub: GetmanY1/wav2vec2-large-multitask-swedish-ssd and GetmanY1/wav2vec2-large-multitask-finnish-l2.

[Figure: Multi-task wav2vec2 system overview]
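A single model produces both outputs: the usage example reads utterance-level classification logits from `logits[0]` and frame-level CTC logits from `logits[1]`. The PyTorch sketch below illustrates that idea with two heads over shared encoder states. It is not the fork's implementation. The class name `MultiTaskHeads`, the masked mean pooling, the blank index and the `ctc_weight` interpolation of the two losses are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiTaskHeads(nn.Module):
    """Hypothetical CTC + classification heads over shared wav2vec2 encoder states.

    Not the fork's implementation: pooling and loss weighting are assumptions.
    """

    def __init__(self, hidden_size, vocab_size, num_classes, ctc_weight=0.5, blank_id=0):
        super().__init__()
        self.ctc_head = nn.Linear(hidden_size, vocab_size)   # one token distribution per frame
        self.cls_head = nn.Linear(hidden_size, num_classes)  # one label per utterance
        self.ctc_weight = ctc_weight                          # assumed interpolation weight
        self.blank_id = blank_id                              # assumed CTC blank index

    def forward(self, hidden_states, frame_lengths, ctc_targets=None,
                target_lengths=None, class_labels=None):
        # hidden_states: (batch, frames, hidden_size); frame_lengths: valid frames per utterance
        ctc_logits = self.ctc_head(hidden_states)

        # Masked mean pooling so that padded frames do not affect the utterance label
        frames = hidden_states.size(1)
        mask = torch.arange(frames, device=hidden_states.device)[None, :] < frame_lengths[:, None]
        mask = mask.unsqueeze(-1).to(hidden_states.dtype)
        pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
        cls_logits = self.cls_head(pooled)

        loss = None
        if ctc_targets is not None and class_labels is not None:
            # F.ctc_loss expects log-probabilities shaped (frames, batch, vocab)
            log_probs = F.log_softmax(ctc_logits, dim=-1).transpose(0, 1)
            ctc_loss = F.ctc_loss(log_probs, ctc_targets, frame_lengths, target_lengths,
                                  blank=self.blank_id, zero_infinity=True)
            cls_loss = F.cross_entropy(cls_logits, class_labels)
            loss = self.ctc_weight * ctc_loss + (1.0 - self.ctc_weight) * cls_loss

        # Same order as in the usage example: [0] classification, [1] CTC
        return (cls_logits, ctc_logits), loss


# Toy usage with random "encoder states" for two utterances of 50 and 40 frames
torch.manual_seed(0)
heads = MultiTaskHeads(hidden_size=16, vocab_size=10, num_classes=3)
(cls_logits, ctc_logits), loss = heads(
    torch.randn(2, 50, 16),
    frame_lengths=torch.tensor([50, 40]),
    ctc_targets=torch.tensor([1, 2, 3, 4, 5]),  # targets of both utterances, concatenated
    target_lengths=torch.tensor([3, 2]),
    class_labels=torch.tensor([0, 2]),
)
```

Sharing the encoder means the ASR and classification objectives are optimised jointly through one set of wav2vec2 weights; how the real training script balances the two losses should be checked in examples/pytorch/multitask.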

Usage

To use the multi-task wav2vec2 models or train your own, first install this fork of 🤗 Transformers:

git clone https://github.com/aalto-speech/multitask-wav2vec2
cd multitask-wav2vec2
pip install -e .
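To confirm that Python imports the fork rather than an upstream 🤗 Transformers release, you can check whether the `Wav2Vec2ForMultiTask` class is available. The helper below is our own sketch, not part of the repository, and it assumes that upstream Transformers does not define that class.

```python
import importlib


def has_multitask_fork() -> bool:
    """Return True if the installed `transformers` exposes Wav2Vec2ForMultiTask."""
    try:
        transformers = importlib.import_module("transformers")
    except ImportError:
        # transformers is not installed at all
        return False
    # Assumption: only the fork defines Wav2Vec2ForMultiTask, so its presence
    # indicates that the editable install from this repository is active.
    return hasattr(transformers, "Wav2Vec2ForMultiTask")


print(has_multitask_fork())
```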

The model can then be used directly as follows:

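import torch
import librosa
import datasets
from transformers import Wav2Vec2ForMultiTask, Wav2Vec2Processor

# One of the published checkpoints, or the path to your own trained model
MODEL_PATH = "GetmanY1/wav2vec2-large-multitask-swedish-ssd"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
# Move the model to the same device as the inputs and switch to inference mode
model = Wav2Vec2ForMultiTask.from_pretrained(MODEL_PATH).to(device)
model.eval()

def map_to_array(batch):
    # Load the audio as 16 kHz mono, the sampling rate wav2vec2 expects
    speech, _ = librosa.load(batch["file"], sr=16000, mono=True)
    batch["speech"] = speech
    return batch

def map_to_pred_multitask(batch):
    input_values = processor(batch["speech"], sampling_rate=16000, return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to(device)).logits
    # logits[1]: frame-level CTC logits, greedily decoded into a transcription
    predicted_ids_ctc = torch.argmax(logits[1], dim=-1)
    batch["transcription"] = processor.batch_decode(predicted_ids_ctc)
    # logits[0]: utterance-level classification logits, reduced to a class id
    batch["predictions"] = torch.argmax(logits[0], dim=-1).cpu().tolist()
    return batch

# test_dataset needs a "file" column with audio paths (placeholder path below)
test_dataset = datasets.Dataset.from_dict({"file": ["path/to/audio.wav"]})
test_dataset = test_dataset.map(map_to_array)
result = test_dataset.map(map_to_pred_multitask)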
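The transcription step above (an argmax over `logits[1]` followed by `processor.batch_decode`) is greedy CTC decoding. The plain-Python sketch below shows roughly what it does, assuming a Wav2Vec2CTCTokenizer-style vocabulary in which the padding token acts as the CTC blank and `|` marks word boundaries. The function name and the toy vocabulary are hypothetical, not taken from the repository.

```python
from itertools import groupby


def greedy_ctc_decode(frame_ids, id_to_token, blank_id=0, word_delimiter="|"):
    """Turn per-frame argmax ids into text (hypothetical helper)."""
    # 1. Merge runs of identical consecutive ids, e.g. h h e -> h e
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    # 2. Drop the blank, which separates genuine repeats: l <pad> l -> l l
    tokens = [id_to_token[i] for i in collapsed if i != blank_id]
    # 3. The word delimiter token stands for a space between words
    return "".join(" " if t == word_delimiter else t for t in tokens).strip()


# Toy vocabulary, invented for illustration only
vocab = {0: "<pad>", 1: "|", 2: "h", 3: "e", 4: "l", 5: "o", 6: "j"}
frames = [2, 2, 3, 0, 4, 4, 0, 4, 5, 1, 1, 2, 3, 3, 6]
print(greedy_ctc_decode(frames, vocab))  # hello hej
```

Without the blank between the two `l` frames, the repeats would be merged and the word would come out as "helo", which is why CTC models learn to emit a blank between doubled letters.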

Training a multi-task wav2vec2 model

To train your own multi-task wav2vec2 model, see the example scripts in examples/pytorch/multitask.

Citation

If you use our models or training scripts, please cite our article as:

@inproceedings{getman23_slate,
  author={Yaroslav Getman and Ragheb Al-Ghezi and Tamas Grosz and Mikko Kurimo},
  title={{Multi-task wav2vec2 Serving as a Pronunciation Training System for Children}},
  year=2023,
  booktitle={Proc. 9th Workshop on Speech and Language Technology in Education (SLaTE)},
  pages={36--40},
  doi={10.21437/SLaTE.2023-8}
}
