diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index b6bc45a1..fc903ac2 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -229,9 +229,9 @@ def __init__( self, feature_extractor: AudioFeatureExtractor, tokenizer: Tokenizer, - inputs_padding_type: str = "longest", + inputs_padding_type: str = None, inputs_max_length: int = None, - labels_padding_type: str = "longest", + labels_padding_type: str = None, labels_max_length: int = None, ): self.feature_extractor = feature_extractor diff --git a/hezar/data/datasets/speech_recognition_dataset.py b/hezar/data/datasets/speech_recognition_dataset.py index fd69c926..082fd49e 100644 --- a/hezar/data/datasets/speech_recognition_dataset.py +++ b/hezar/data/datasets/speech_recognition_dataset.py @@ -2,7 +2,7 @@ from dataclasses import dataclass -from datasets import Audio, load_dataset +from datasets import Audio, load_dataset, load_from_disk from .dataset import Dataset from ..data_collators import SpeechRecognitionDataCollator @@ -44,14 +44,14 @@ def __init__(self, config: SpeechRecognitionDatasetConfig, split=None, **kwargs) self.data_collator = SpeechRecognitionDataCollator( self.feature_extractor, self.tokenizer, - inputs_padding_type="max_length" if self.config.audio_array_padding_type is not None else "longest", + inputs_padding_type="max_length" if self.config.max_audio_array_length is not None else "longest", inputs_max_length=self.config.max_audio_array_length, - labels_padding_type=self.config.labels_padding_type, + labels_padding_type="max_length" if self.config.labels_max_length is not None else "longest", labels_max_length=self.config.labels_max_length, ) def _load(self, split): - data = load_dataset(self.config.path, split=split) + data = load_from_disk(self.config.path)[split] data = data.cast_column(self.config.audio_column, Audio(sampling_rate=self.config.sampling_rate)) return data