From ef4581c7df33f3234bbc8c930899ace1f59a8d18 Mon Sep 17 00:00:00 2001
From: Botir Khaltaev
Date: Mon, 26 Aug 2024 15:44:40 +0100
Subject: [PATCH] Create factory for audio to embedding pipeline

---
 huggingface_pipelines/audio.py | 108 +++++++++++++++++++--------------
 1 file changed, 64 insertions(+), 44 deletions(-)

diff --git a/huggingface_pipelines/audio.py b/huggingface_pipelines/audio.py
index 170e85c..259e665 100644
--- a/huggingface_pipelines/audio.py
+++ b/huggingface_pipelines/audio.py
@@ -1,14 +1,12 @@
-import logging
-from dataclasses import dataclass
-from typing import Any, Dict, List
-
-import numpy as np
 import torch
-
+from typing import Dict, Any
+from dataclasses import dataclass
 from sonar.inference_pipelines.speech import SpeechToEmbeddingModelPipeline
+import logging
+from .pipeline import Pipeline, PipelineConfig
+from .dataset import DatasetConfig
+import numpy as np
 
-from .dataset import DatasetConfig  # type: ignore
-from .pipeline import Pipeline, PipelineConfig  # type:ignore
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -65,7 +63,6 @@ class HFAudioToEmbeddingPipelineConfig(PipelineConfig):
     )
     """
 
-
     encoder_model: str = "text_sonar_basic_encoder"
     fbank_dtype: torch.dtype = torch.float32
     n_parallel: int = 4
@@ -108,7 +105,7 @@ def __init__(self, config: HFAudioToEmbeddingPipelineConfig):
         self.model = SpeechToEmbeddingModelPipeline(
             encoder=self.config.encoder_model,
             device=torch.device(self.config.device),
-            fbank_dtype=self.config.fbank_dtype,
+            fbank_dtype=self.config.fbank_dtype
         )
 
     def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
@@ -134,7 +131,8 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         try:
             for column in self.config.columns:
                 if column not in batch:
-                    logger.warning(f"Column {column} not found in batch. Skipping.")
+                    logger.warning(
+                        f"Column {column} not found in batch. Skipping.")
                     continue
 
                 audio_inputs = []
@@ -145,13 +143,9 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
                     audio_data_list = [audio_data_list]
 
                 for audio_data in audio_data_list:
-                    if (
-                        isinstance(audio_data, dict)
-                        and "array" in audio_data
-                        and "sampling_rate" in audio_data
-                    ):
+                    if isinstance(audio_data, dict) and 'array' in audio_data and 'sampling_rate' in audio_data:
                         # Handle multi-channel audio by taking the mean across channels
-                        audio_array = audio_data["array"]
+                        audio_array = audio_data['array']
                         if audio_array.ndim > 1:
                             audio_array = np.mean(audio_array, axis=0)
 
@@ -163,60 +157,51 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
                             audio_tensor = audio_tensor.unsqueeze(0)
                         elif audio_tensor.dim() > 2:
                             raise ValueError(
-                                f"Unexpected audio tensor shape: {audio_tensor.shape}"
-                            )
+                                f"Unexpected audio tensor shape: {audio_tensor.shape}")
 
                         audio_inputs.append(audio_tensor)
                     else:
                         logger.warning(
-                            f"Invalid audio data format in column {column}: {audio_data}"
-                        )
+                            f"Invalid audio data format in column {column}: {audio_data}")
 
                 if not audio_inputs:
-                    logger.warning(f"No valid audio inputs found in column {column}.")
+                    logger.warning(
+                        f"No valid audio inputs found in column {column}.")
                     continue
 
                 try:
                     # Move tensors to the specified device
-                    audio_inputs = [
-                        tensor.to(self.config.device) for tensor in audio_inputs
-                    ]
+                    audio_inputs = [tensor.to(self.config.device)
+                                    for tensor in audio_inputs]
 
-                    all_embeddings: torch.Tensor = self.model.predict(
+                    all_embeddings = self.model.predict(
                         input=audio_inputs,
                         batch_size=self.config.batch_size,
                         n_parallel=self.config.n_parallel,
-                        pad_idx=self.config.pad_idx,
+                        pad_idx=self.config.pad_idx
                     )
 
                     # Ensure all embeddings are 2D
-                    processed_embeddings: List[torch.Tensor] = [
-                        emb.unsqueeze(0) if emb.dim() == 1 else emb
-                        for emb in all_embeddings
-                    ]
+                    all_embeddings = [emb.unsqueeze(0) if emb.dim(
+                    ) == 1 else emb for emb in all_embeddings]
 
                     # Get the maximum sequence length
-                    max_seq_len = max(emb.shape[0] for emb in processed_embeddings)
+                    max_seq_len = max(emb.shape[0] for emb in all_embeddings)
 
                     # Pad embeddings to have the same sequence length
-                    padded_embeddings = [
-                        torch.nn.functional.pad(
-                            emb, (0, 0, 0, max_seq_len - emb.shape[0])
-                        )
-                        for emb in processed_embeddings
-                    ]
+                    padded_embeddings = [torch.nn.functional.pad(
+                        emb, (0, 0, 0, max_seq_len - emb.shape[0])) for emb in all_embeddings]
 
                     # Stack embeddings into a single tensor
-                    stacked_embeddings = torch.stack(padded_embeddings).unsqueeze(1)
+                    stacked_embeddings = torch.stack(
+                        padded_embeddings).unsqueeze(1)
 
-                    batch[f"{column}_{self.config.output_column_suffix}"] = (
-                        stacked_embeddings.cpu().numpy()
-                    )
+                    batch[f"{column}_{self.config.output_column_suffix}"] = stacked_embeddings.cpu(
+                    ).numpy()
 
                 except Exception as e:
                     logger.error(
-                        f"Error in model.predict for column {column}: {str(e)}"
-                    )
+                        f"Error in model.predict for column {column}: {str(e)}")
                     # Instead of raising, we'll set the output to None and continue processing
                     batch[f"{column}_{self.config.output_column_suffix}"] = None
 
@@ -226,3 +211,38 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
             # Instead of raising, we'll return the batch as is
             return batch
 
+
+
+class AudioToEmbeddingPipelineFactory:
+    """
+    Factory class for creating AudioToEmbedding pipelines.
+
+    This factory creates HFAudioToEmbeddingPipeline instances based on the provided configuration.
+
+    Example:
+        factory = AudioToEmbeddingPipelineFactory()
+        config = {
+            "encoder_model": "sonar_speech_encoder_large",
+            "fbank_dtype": torch.float16,
+            "n_parallel": 4,
+            "pad_idx": 0,
+            "audio_column": "audio",
+            "device": "cuda",
+            "batch_size": 32,
+            "columns": ["audio"],
+            "output_path": "/path/to/output",
+            "output_column_suffix": "embedding"
+        }
+        pipeline = factory.create_pipeline(config)
+    """
+
+    def create_pipeline(self, config: Dict[str, Any]) -> Pipeline:
+        """
+        Create an AudioToEmbedding pipeline based on the provided configuration.
+
+        Returns:
+            Pipeline: An instance of HFAudioToEmbeddingPipeline.
+        """
+        pipeline_config = HFAudioToEmbeddingPipelineConfig(**config)
+        return HFAudioToEmbeddingPipeline(pipeline_config)
+
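
Notes (after the patch, for reviewers):

A minimal end-to-end sketch of how the new factory might be used. It is a
sketch, not a verified run: it assumes the SONAR package and a speech encoder
checkpoint are installed, the config keys mirror the class docstring example
above (including the "sonar_speech_encoder_large" name, which is taken from
that example), and the toy batch and output path are hypothetical.

    import numpy as np
    import torch

    from huggingface_pipelines.audio import AudioToEmbeddingPipelineFactory

    factory = AudioToEmbeddingPipelineFactory()
    config = {
        "encoder_model": "sonar_speech_encoder_large",  # name from the docstring example
        "fbank_dtype": torch.float32,
        "n_parallel": 1,
        "pad_idx": 0,
        "audio_column": "audio",
        "device": "cpu",                 # assumed; "cuda" if a GPU is available
        "batch_size": 2,
        "columns": ["audio"],
        "output_path": "/path/to/output",  # hypothetical
        "output_column_suffix": "embedding",
    }
    pipeline = factory.create_pipeline(config)

    # process_batch expects HF-datasets-style audio entries that carry
    # an "array" and a "sampling_rate" key.
    batch = {"audio": [{"array": np.zeros(16000, dtype=np.float32),
                        "sampling_rate": 16000}]}
    result = pipeline.process_batch(batch)
    embeddings = result["audio_embedding"]  # f"{column}_{output_column_suffix}"

Note that process_batch logs errors and sets the output column to None rather
than raising, so callers should check for None before using the embeddings.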
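The per-sample handling in process_batch downmixes multi-channel audio to mono
and normalizes tensor rank before inference. A self-contained sketch of just
that step; the numpy-to-tensor conversion is an assumption here, since it sits
outside the hunk context shown in the diff:

    import numpy as np
    import torch

    # A stereo clip shaped (channels, samples), as HF datasets may provide.
    audio_array = np.random.randn(2, 16000).astype(np.float32)

    # Downmix to mono by averaging across channels, as process_batch does.
    if audio_array.ndim > 1:
        audio_array = np.mean(audio_array, axis=0)

    # Conversion to a torch tensor is assumed (not visible in the diff).
    audio_tensor = torch.from_numpy(audio_array)
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)   # -> (1, samples)
    elif audio_tensor.dim() > 2:
        raise ValueError(f"Unexpected audio tensor shape: {audio_tensor.shape}")

    print(audio_tensor.shape)  # torch.Size([1, 16000])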
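The post-predict reshaping pads the variable-length embeddings returned by
model.predict to a common sequence length and stacks them into one array. The
same logic replayed with toy tensors standing in for real embeddings:

    import torch

    # Stand-ins for variable-length (seq_len, dim) embeddings from predict.
    all_embeddings = [torch.randn(3, 4), torch.randn(1, 4), torch.randn(2, 4)]

    # Ensure all embeddings are 2D (a 1-D vector becomes a length-1 sequence).
    all_embeddings = [emb.unsqueeze(0) if emb.dim() == 1 else emb
                      for emb in all_embeddings]

    # Pad every sequence to the longest length in the batch, then stack.
    max_seq_len = max(emb.shape[0] for emb in all_embeddings)
    padded = [torch.nn.functional.pad(emb, (0, 0, 0, max_seq_len - emb.shape[0]))
              for emb in all_embeddings]
    stacked = torch.stack(padded).unsqueeze(1)

    print(stacked.shape)  # torch.Size([3, 1, 3, 4])

The pad spec (0, 0, 0, k) leaves the embedding dimension untouched and appends
k zero rows along the sequence dimension, so every tensor becomes
(max_seq_len, dim) before stacking.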