Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add speech recognition training & other improvements #147

Merged
merged 36 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
f5b6ab4
:sparkles: Implement metrics handler for speech recognition
arxyzan Jan 29, 2024
fe7ff00
:bento: Add `train_speech_recognition.py`
arxyzan Jan 29, 2024
63af54b
Merge branch 'main' into asr-training
arxyzan Feb 3, 2024
7a9a7b6
:bug: Fix `resolve_inputs_length_for_padding` bug
arxyzan Feb 4, 2024
e621612
:pencil2: Minor
arxyzan Feb 4, 2024
fa3c006
:pencil2: Minor
arxyzan Feb 4, 2024
4672331
:bug: Fix fields bug in `Config`
arxyzan Feb 4, 2024
f507b11
:bug: Fix bug in `AudioFeatureExtractor.pad()`
arxyzan Feb 4, 2024
56632f9
:sparkles: Add `SpeechRecognitionDataCollator`
arxyzan Feb 4, 2024
7068c2f
:bug: Fix wrong attributes in `WhisperBPEConfig`
arxyzan Feb 4, 2024
bbbf902
:bug: Handle bugs in `WhisperSpeechRecognition`
arxyzan Feb 4, 2024
35a16a4
:sparkles: Add dataset loading script for ASR
arxyzan Feb 5, 2024
d505b09
:sparkles: Add `SpeechRecognitionDataset`
arxyzan Feb 5, 2024
6b1a7c7
:bug: Fix tokenizer `max_length` bug
arxyzan Feb 5, 2024
f3b6988
:pencil2: Improve logging robustness in `Trainer`
arxyzan Feb 5, 2024
bc30588
:pencil2: Update `train_speech_recognition.py`
arxyzan Feb 5, 2024
c4dd743
:pencil2: Rename `padding` -> `padding_type` in `data_utils.resolve_i…
arxyzan Feb 5, 2024
c95c4f5
:test_tube: Add `speech_recognition` to tests for datasets and trainer
arxyzan Feb 5, 2024
19574b6
:test_tube: Improve flexibility of tests in `test_datasets.py`
arxyzan Feb 5, 2024
3f41b11
:test_tube: Ignore errors for `rmtree` in `test_trainer.py`
arxyzan Feb 5, 2024
3eb23d5
:pencil2: Fix minor issues for asr
arxyzan Feb 5, 2024
fae959d
:test_tube: Limit max input lengths to prevent crash in CI
arxyzan Feb 5, 2024
f3c28d1
:sparkles: Add `clean_cache` function to utils
arxyzan Feb 5, 2024
e8eaa29
:test_tube: Clean cache after every train process
arxyzan Feb 5, 2024
b37cd26
:test_tube: Add `CI_MODE` to `tests.yml`
arxyzan Feb 5, 2024
cdea95d
:bug: Fix minor bug in `speech_recognition_dataset.py`
arxyzan Feb 5, 2024
d9a298e
:test_tube: Minor change
arxyzan Feb 5, 2024
045b820
:test_tube: Clean cache in `test_datasets.py`
arxyzan Feb 5, 2024
d00c30a
:test_tube: Minor change
arxyzan Feb 5, 2024
ac6a8db
:test_tube: Clean cache after every test
arxyzan Feb 6, 2024
0e241ee
:test_tube: Limit sizes in `test_trainer.py`
arxyzan Feb 6, 2024
3cd1275
:test_tube: Limit sizes in `test_trainer.py`
arxyzan Feb 6, 2024
fc23537
:ambulance: Fix a silent critical bug in `Tokenizer.__call__`
arxyzan Feb 7, 2024
ea8ef2a
:test_tube: Minor renamings
arxyzan Feb 7, 2024
ae61af3
:test_tube: Minor
arxyzan Feb 7, 2024
fb77c6b
:bug: Fix wrong cache dir in `speech_recognition_dataset.py`
arxyzan Feb 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ jobs:
pip install pytest
- name: Run pytest
run: |
pytest tests/
CI_MODE=TRUE pytest -v tests/
36 changes: 36 additions & 0 deletions examples/train/train_speech_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Example: fine-tune Whisper-small for Persian speech recognition on Common Voice 13."""
from hezar.models import Model
from hezar.data import SpeechRecognitionDataset, SpeechRecognitionDatasetConfig
from hezar.trainer import Trainer, TrainerConfig


DATASET_PATH = "hezarai/common-voice-13-fa"
BASE_MODEL_PATH = "hezarai/whisper-small"

# One shared dataset config; feature extractor and tokenizer both come from the base model repo.
dataset_config = SpeechRecognitionDatasetConfig(
    path=DATASET_PATH,
    feature_extractor_path=BASE_MODEL_PATH,
    tokenizer_path=BASE_MODEL_PATH,
)
train_dataset = SpeechRecognitionDataset(dataset_config, split="train", labels_max_length=64)
eval_dataset = SpeechRecognitionDataset(dataset_config, split="test[:50%]", labels_max_length=64)
model = Model.load(BASE_MODEL_PATH)

# Effective batch size is batch_size * gradient_accumulation_steps = 32.
train_config = TrainerConfig(
    output_dir="whisper-small-fa-commonvoice",
    task="speech_recognition",
    init_weights_from=BASE_MODEL_PATH,
    mixed_precision="bf16",
    gradient_accumulation_steps=8,
    batch_size=4,
    num_epochs=5,
    metrics=["cer", "wer"],
)

trainer = Trainer(
    config=train_config,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=train_dataset.data_collator,
)
trainer.train()
8 changes: 6 additions & 2 deletions hezar/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def __len__(self):
def __iter__(self):
return iter(self.dict())

@classmethod
def fields(cls):
    """Return this config's dataclass fields as a mapping of name -> ``Field``.

    Unlike ``cls.__annotations__`` (which only covers annotations declared on
    the class itself), ``__dataclass_fields__`` also includes fields inherited
    from parent config dataclasses.
    """
    return cls.__dataclass_fields__

def dict(self):
"""
Returns the config object as a dictionary (works on nested dataclasses too)
Expand Down Expand Up @@ -137,7 +141,7 @@ def update(self, d: dict, **kwargs):
"""
d.update(kwargs)
for k, v in d.items():
if k not in self.__annotations__.keys():
if k not in self.fields():
logger.warning(f"`{str(self.__class__.__name__)}` does not take `{k}` as a config parameter!")
setattr(self, k, v)
return self
Expand Down Expand Up @@ -215,7 +219,7 @@ def from_dict(cls, dict_config: Dict | DictConfig, **kwargs):
if config_cls is not None:
dict_config[k] = config_cls.from_dict(v)

dict_config = {k: v for k, v in dict_config.items() if k in cls.__annotations__ and k not in CONFIG_CLASS_VARS}
dict_config = {k: v for k, v in dict_config.items() if k in cls.fields() and cls.fields()[k].init}

config = cls(**dict_config) # noqa

Expand Down
45 changes: 44 additions & 1 deletion hezar/data/data_collators.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import numpy as np
import torch

from ..preprocessors import Tokenizer
from ..preprocessors import Tokenizer, AudioFeatureExtractor
from ..utils import Logger, convert_batch_dict_dtype


__all__ = [
"TextPaddingDataCollator",
"TextGenerationDataCollator",
"ImageCaptioningDataCollator",
"SpeechRecognitionDataCollator",
"SequenceLabelingDataCollator",
"CharLevelOCRDataCollator",
]
Expand Down Expand Up @@ -223,6 +224,48 @@ def __call__(self, encoded_batch):
return padded_batch


class SpeechRecognitionDataCollator:
    """Collate speech recognition samples into a padded batch.

    Labels (token ids) are padded via the tokenizer and audio features via the
    feature extractor, each with its own padding strategy and max length.
    """

    def __init__(
        self,
        feature_extractor: AudioFeatureExtractor,
        tokenizer: Tokenizer,
        inputs_padding_type: str = None,
        inputs_max_length: int = None,
        labels_padding_type: str = None,
        labels_max_length: int = None,
    ):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.inputs_padding_type = inputs_padding_type
        self.inputs_max_length = inputs_max_length
        self.labels_padding_type = labels_padding_type
        self.labels_max_length = labels_max_length

    def __call__(self, input_batch):
        """Merge a list of per-sample dicts into one padded batch of tensors."""
        samples = [convert_batch_dict_dtype(sample, dtype="list") for sample in input_batch]

        # Flatten: one dict whose values are the concatenation of every sample's values.
        inputs = {
            key: [element for sample in samples for element in sample[key]]
            for key in samples[0]
        }

        # Pad the label token ids; audio features are excluded and padded separately below.
        inputs = self.tokenizer.pad_encoded_batch(
            inputs,
            padding=self.labels_padding_type,
            max_length=self.labels_max_length,
            exclude_keys=["input_features"],
            return_tensors="pt"
        )

        inputs = self.feature_extractor.pad(
            inputs,
            padding=self.inputs_padding_type,
            max_length=self.inputs_max_length,
            return_tensors="pt",
        )

        return inputs


class SequenceLabelingDataCollator:
"""
A data collator for sequence labeling.
Expand Down
1 change: 1 addition & 0 deletions hezar/data/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from .image_captioning_dataset import ImageCaptioningDataset, ImageCaptioningDatasetConfig
from .ocr_dataset import OCRDataset, OCRDatasetConfig
from .sequence_labeling_dataset import SequenceLabelingDataset, SequenceLabelingDatasetConfig
from .speech_recognition_dataset import SpeechRecognitionDataset, SpeechRecognitionDatasetConfig
from .text_classification_dataset import TextClassificationDataset, TextClassificationDatasetConfig
from .text_summarization_dataset import TextSummarizationDataset, TextSummarizationDatasetConfig
8 changes: 0 additions & 8 deletions hezar/data/datasets/image_captioning_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@ class ImageCaptioningDatasetConfig(DatasetConfig):

@register_dataset("image_captioning", config_class=ImageCaptioningDatasetConfig)
class ImageCaptioningDataset(Dataset):
"""
General image captioning dataset class.

Args:
config (ImageCaptioningDatasetConfig): The configuration object for the dataset.
split: Dataset split, defaults to None.
**kwargs: Additional keyword arguments.
"""
required_backends = _required_backends

def __init__(self, config: ImageCaptioningDatasetConfig, split=None, **kwargs):
Expand Down
81 changes: 81 additions & 0 deletions hezar/data/datasets/speech_recognition_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

from dataclasses import dataclass

from datasets import Audio, load_dataset

from .dataset import Dataset
from ..data_collators import SpeechRecognitionDataCollator
from ...configs import DatasetConfig
from ...constants import TaskType, Backends, PaddingType
from ...preprocessors import Tokenizer, AudioFeatureExtractor
from ...registry import register_dataset

_required_backends = [Backends.LIBROSA, Backends.DATASETS]


@dataclass
class SpeechRecognitionDatasetConfig(DatasetConfig):
    """Configuration for `SpeechRecognitionDataset` (ASR on HF `datasets` corpora)."""
    # NOTE: `name` and `task` are intentionally unannotated so they stay class
    # attributes rather than dataclass fields.
    name = "speech_recognition"
    task = TaskType.SPEECH_RECOGNITION
    # Hub id or local path of the dataset to load with `datasets.load_dataset`
    path: str = None
    # Hub id or local path the feature extractor is loaded from
    feature_extractor_path: str = None
    # Hub id or local path the tokenizer is loaded from
    tokenizer_path: str = None
    # Target sampling rate; audio is resampled to this via `datasets.Audio`
    sampling_rate: int = 16000
    # Padding strategy / max length for the raw audio arrays (inputs)
    audio_array_padding_type: bool | str | PaddingType = None
    max_audio_array_length: int = None
    # Padding strategy / max length for the tokenized transcripts (labels)
    labels_padding_type: bool | str | PaddingType = None
    labels_max_length: int = None
    # Column names in the source dataset
    audio_file_path_column: str = "path"
    audio_column: str = "audio"
    audio_array_column: str = "array"
    transcript_column: str = "sentence"


@register_dataset("speech_recognition", config_class=SpeechRecognitionDatasetConfig)
class SpeechRecognitionDataset(Dataset):
    """Speech recognition (ASR) dataset.

    Loads an audio dataset from the HF Hub, resamples it to the configured
    sampling rate, and yields feature-extracted audio plus tokenized transcripts.
    """
    required_backends = _required_backends

    def __init__(self, config: SpeechRecognitionDatasetConfig, split=None, **kwargs):
        super().__init__(config, split, **kwargs)
        self.data = self._load(split)
        self.feature_extractor = AudioFeatureExtractor.load(self.config.feature_extractor_path)
        self.tokenizer = Tokenizer.load(self.config.tokenizer_path)
        # Pad to a fixed size when a max length is configured, otherwise to the longest in batch.
        inputs_padding = "longest" if self.config.max_audio_array_length is None else "max_length"
        labels_padding = "longest" if self.config.labels_max_length is None else "max_length"
        self.data_collator = SpeechRecognitionDataCollator(
            self.feature_extractor,
            self.tokenizer,
            inputs_padding_type=inputs_padding,
            inputs_max_length=self.config.max_audio_array_length,
            labels_padding_type=labels_padding,
            labels_max_length=self.config.labels_max_length,
        )

    def _load(self, split):
        """Load the requested split and cast the audio column to the target sampling rate."""
        dataset = load_dataset(self.config.path, split=split, cache_dir=self.cache_dir)
        return dataset.cast_column(
            self.config.audio_column,
            Audio(sampling_rate=self.config.sampling_rate),
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return a dict with `input_features` (audio) and `labels` (token ids) tensors."""
        sample = self.data[index]
        audio_array = sample[self.config.audio_column][self.config.audio_array_column]
        transcript = sample[self.config.transcript_column]

        input_features = self.feature_extractor(
            audio_array,
            sampling_rate=self.config.sampling_rate,
            return_tensors="pt"
        )["input_features"]

        labels = self.tokenizer(
            transcript,
            max_length=self.config.labels_max_length,
            return_tensors="pt"
        )["token_ids"]

        return {
            "input_features": input_features,
            "labels": labels
        }
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import copy
from typing import List

import numpy as np
import torch

from ....constants import Backends
from ....registry import register_model
from ....utils import is_backend_available, load_audio_files
from ....utils import is_backend_available, load_audio_files, shift_tokens_right
from ...model import Model
from ...model_outputs import SpeechRecognitionOutput
from .whisper_speech_recognition_config import WhisperSpeechRecognitionConfig
Expand Down Expand Up @@ -52,11 +53,16 @@ def forward(
encoder_outputs=None,
past_key_values=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
**kwargs,
):
if decoder_input_ids is None or decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
outputs = self.whisper(
input_features=input_features,
attention_mask=attention_mask,
Expand All @@ -68,15 +74,16 @@ def forward(
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
labels=None,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

return outputs
return dict(outputs)

def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
labels = copy.deepcopy(labels)
labels[labels == self.config.pad_token_id] = -100
loss = self.criterion(logits.view(-1, self.config.vocab_size), labels.view(-1))
return loss

Expand Down
12 changes: 10 additions & 2 deletions hezar/models/speech_recognition/whisper/whisper_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,20 +253,28 @@
@dataclass
class WhisperBPEConfig(BPEConfig):
    """Configuration for the Whisper BPE tokenizer.

    Fixes relative to the diffed version:
    - `padding_direction` was defined twice; the duplicate is removed.
    - The stale `bos_token = "<|startoftranscript|>"` line is dropped; Whisper's
      tokenizer uses `<|endoftext|>` for bos/eos/pad/unk alike.
    - `predict_timestamps` was annotated `str` but defaulted to `False`; it is a bool.
    """
    name = "whisper_bpe_tokenizer"
    max_length: int = 448
    truncation_strategy: str = "longest_first"
    truncation_direction: str = "right"
    stride: int = 0
    padding_strategy: str = "longest"
    padding_direction: str = "right"
    pad_to_multiple_of: int = 0
    # Whisper reuses <|endoftext|> for pad/unk/bos/eos (matches the HF tokenizer config).
    pad_token: str = "<|endoftext|>"
    unk_token: str = "<|endoftext|>"
    bos_token: str = "<|endoftext|>"
    eos_token: str = "<|endoftext|>"
    translate_token: str = "<|translate|>"
    transcribe_token: str = "<|transcribe|>"
    notimestamps_token: str = "<|notimestamps|>"
    additional_special_tokens: List = field(default_factory=lambda: ADDITIONAL_SPECIAL_TOKENS)
    add_prefix_space: bool = False
    add_bos_token: bool = False
    model_max_length: int = 1024
    # Language/task hints used to build the decoder prompt; None means unset.
    language: str = None
    task: str = None
    predict_timestamps: bool = False  # was annotated `str` with a bool default
    show_progress: bool = True


@register_preprocessor("whisper_bpe_tokenizer", config_class=WhisperBPEConfig)
Expand Down
10 changes: 7 additions & 3 deletions hezar/preprocessors/audio_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Mapping
from typing import List, Mapping

import numpy as np

from ..builders import build_preprocessor
from ..configs import PreprocessorConfig
from ..constants import DEFAULT_FEATURE_EXTRACTOR_CONFIG_FILE
from ..constants import DEFAULT_FEATURE_EXTRACTOR_CONFIG_FILE, PaddingType
from ..utils import convert_batch_dict_dtype
from .preprocessor import Preprocessor

Expand Down Expand Up @@ -36,7 +38,7 @@ def __call__(self, inputs, **kwargs):
def pad(
self,
processed_features,
padding=None,
padding: bool | str | PaddingType = True,
max_length=None,
truncation=None,
pad_to_multiple_of=None,
Expand All @@ -58,6 +60,7 @@ def pad(
"""
return_attention_mask = return_attention_mask or self.config.return_attention_mask
padding = padding or self.config.padding

if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], Mapping):
processed_features = {
key: np.array([example[key] for example in processed_features]) for key in processed_features[0].keys()
Expand Down Expand Up @@ -121,6 +124,7 @@ def pad(
return_attention_mask=return_attention_mask,
)
for key, value in outputs.items():
value = np.array(value)
if key not in batch_outputs:
batch_outputs[key] = []
if value.dtype is np.dtype(np.float64):
Expand Down
Loading
Loading