From 3eb23d5f2b4d2377e075976588a851d800466187 Mon Sep 17 00:00:00 2001 From: arxyzan Date: Mon, 5 Feb 2024 17:52:44 +0330 Subject: [PATCH] :pencil2: Fix minor issues for asr --- hezar/data/data_collators.py | 4 ++-- hezar/data/datasets/speech_recognition_dataset.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hezar/data/data_collators.py b/hezar/data/data_collators.py index b6bc45a1..fc903ac2 100644 --- a/hezar/data/data_collators.py +++ b/hezar/data/data_collators.py @@ -229,9 +229,9 @@ def __init__( self, feature_extractor: AudioFeatureExtractor, tokenizer: Tokenizer, - inputs_padding_type: str = "longest", + inputs_padding_type: str = None, inputs_max_length: int = None, - labels_padding_type: str = "longest", + labels_padding_type: str = None, labels_max_length: int = None, ): self.feature_extractor = feature_extractor diff --git a/hezar/data/datasets/speech_recognition_dataset.py b/hezar/data/datasets/speech_recognition_dataset.py index fd69c926..082fd49e 100644 --- a/hezar/data/datasets/speech_recognition_dataset.py +++ b/hezar/data/datasets/speech_recognition_dataset.py @@ -2,7 +2,7 @@ from dataclasses import dataclass -from datasets import Audio, load_dataset +from datasets import Audio, load_dataset, load_from_disk from .dataset import Dataset from ..data_collators import SpeechRecognitionDataCollator @@ -44,14 +44,14 @@ def __init__(self, config: SpeechRecognitionDatasetConfig, split=None, **kwargs) self.data_collator = SpeechRecognitionDataCollator( self.feature_extractor, self.tokenizer, - inputs_padding_type="max_length" if self.config.audio_array_padding_type is not None else "longest", + inputs_padding_type="max_length" if self.config.max_audio_array_length is not None else "longest", inputs_max_length=self.config.max_audio_array_length, - labels_padding_type=self.config.labels_padding_type, + labels_padding_type="max_length" if self.config.labels_max_length is not None else "longest", labels_max_length=self.config.labels_max_length, ) def _load(self, split): - data = load_dataset(self.config.path, split=split) + data = load_from_disk(self.config.path)[split] data = data.cast_column(self.config.audio_column, Audio(sampling_rate=self.config.sampling_rate)) return data