Skip to content

Commit

Permalink
✏️ Fix minor issues for asr
Browse files (browse the repository at this point in the history)
  • Loading branch information
arxyzan committed Feb 5, 2024
1 parent 3f41b11 commit 3eb23d5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions hezar/data/data_collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,9 @@ def __init__(
self,
feature_extractor: AudioFeatureExtractor,
tokenizer: Tokenizer,
- inputs_padding_type: str = "longest",
+ inputs_padding_type: str = None,
inputs_max_length: int = None,
- labels_padding_type: str = "longest",
+ labels_padding_type: str = None,
labels_max_length: int = None,
):
self.feature_extractor = feature_extractor
Expand Down
8 changes: 4 additions & 4 deletions hezar/data/datasets/speech_recognition_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass

- from datasets import Audio, load_dataset
+ from datasets import Audio, load_dataset, load_from_disk

from .dataset import Dataset
from ..data_collators import SpeechRecognitionDataCollator
Expand Down Expand Up @@ -44,14 +44,14 @@ def __init__(self, config: SpeechRecognitionDatasetConfig, split=None, **kwargs)
self.data_collator = SpeechRecognitionDataCollator(
self.feature_extractor,
self.tokenizer,
- inputs_padding_type="max_length" if self.config.audio_array_padding_type is not None else "longest",
+ inputs_padding_type="max_length" if self.config.max_audio_array_length is not None else "longest",
inputs_max_length=self.config.max_audio_array_length,
- labels_padding_type=self.config.labels_padding_type,
+ labels_padding_type="max_length" if self.config.labels_max_length is not None else "longest",
labels_max_length=self.config.labels_max_length,
)

def _load(self, split):
- data = load_dataset(self.config.path, split=split)
+ data = load_from_disk(self.config.path)[split]
data = data.cast_column(self.config.audio_column, Audio(sampling_rate=self.config.sampling_rate))
return data

Expand Down

0 comments on commit 3eb23d5

Please sign in to comment.