
Create factory for audio to embedding pipeline
botirk38 committed Aug 26, 2024
1 parent a461e08 commit ef4581c
Showing 1 changed file with 64 additions and 44 deletions.
huggingface_pipelines/audio.py: 108 changes (64 additions, 44 deletions)
@@ -1,14 +1,12 @@
-import logging
-from dataclasses import dataclass
-from typing import Any, Dict, List
-
-import numpy as np
 import torch
 
+from typing import Dict, Any
+from dataclasses import dataclass
 from sonar.inference_pipelines.speech import SpeechToEmbeddingModelPipeline
+import logging
+from .pipeline import Pipeline, PipelineConfig
+from .dataset import DatasetConfig
+import numpy as np
 
-from .dataset import DatasetConfig  # type: ignore
-from .pipeline import Pipeline, PipelineConfig  # type:ignore
-
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -65,7 +63,6 @@ class HFAudioToEmbeddingPipelineConfig(PipelineConfig):
         )
     """
-
 
     encoder_model: str = "text_sonar_basic_encoder"
     fbank_dtype: torch.dtype = torch.float32
     n_parallel: int = 4
@@ -108,7 +105,7 @@ def __init__(self, config: HFAudioToEmbeddingPipelineConfig):
         self.model = SpeechToEmbeddingModelPipeline(
             encoder=self.config.encoder_model,
             device=torch.device(self.config.device),
-            fbank_dtype=self.config.fbank_dtype,
+            fbank_dtype=self.config.fbank_dtype
         )
 
     def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
@@ -134,7 +131,8 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         try:
             for column in self.config.columns:
                 if column not in batch:
-                    logger.warning(f"Column {column} not found in batch. Skipping.")
+                    logger.warning(
+                        f"Column {column} not found in batch. Skipping.")
                     continue
 
                 audio_inputs = []
@@ -145,13 +143,9 @@
                     audio_data_list = [audio_data_list]
 
                 for audio_data in audio_data_list:
-                    if (
-                        isinstance(audio_data, dict)
-                        and "array" in audio_data
-                        and "sampling_rate" in audio_data
-                    ):
+                    if isinstance(audio_data, dict) and 'array' in audio_data and 'sampling_rate' in audio_data:
                         # Handle multi-channel audio by taking the mean across channels
-                        audio_array = audio_data["array"]
+                        audio_array = audio_data['array']
                         if audio_array.ndim > 1:
                             audio_array = np.mean(audio_array, axis=0)
 
@@ -163,60 +157,51 @@
                             audio_tensor = audio_tensor.unsqueeze(0)
                         elif audio_tensor.dim() > 2:
                             raise ValueError(
-                                f"Unexpected audio tensor shape: {audio_tensor.shape}"
-                            )
+                                f"Unexpected audio tensor shape: {audio_tensor.shape}")
 
                         audio_inputs.append(audio_tensor)
                     else:
                         logger.warning(
-                            f"Invalid audio data format in column {column}: {audio_data}"
-                        )
+                            f"Invalid audio data format in column {column}: {audio_data}")
 
                 if not audio_inputs:
-                    logger.warning(f"No valid audio inputs found in column {column}.")
+                    logger.warning(
+                        f"No valid audio inputs found in column {column}.")
                     continue
 
                 try:
                     # Move tensors to the specified device
-                    audio_inputs = [
-                        tensor.to(self.config.device) for tensor in audio_inputs
-                    ]
+                    audio_inputs = [tensor.to(self.config.device)
+                                    for tensor in audio_inputs]
 
-                    all_embeddings: torch.Tensor = self.model.predict(
+                    all_embeddings = self.model.predict(
                         input=audio_inputs,
                         batch_size=self.config.batch_size,
                         n_parallel=self.config.n_parallel,
-                        pad_idx=self.config.pad_idx,
+                        pad_idx=self.config.pad_idx
                     )
 
                     # Ensure all embeddings are 2D
-                    processed_embeddings: List[torch.Tensor] = [
-                        emb.unsqueeze(0) if emb.dim() == 1 else emb
-                        for emb in all_embeddings
-                    ]
+                    all_embeddings = [emb.unsqueeze(0) if emb.dim(
+                    ) == 1 else emb for emb in all_embeddings]
 
                     # Get the maximum sequence length
-                    max_seq_len = max(emb.shape[0] for emb in processed_embeddings)
+                    max_seq_len = max(emb.shape[0] for emb in all_embeddings)
 
                     # Pad embeddings to have the same sequence length
-                    padded_embeddings = [
-                        torch.nn.functional.pad(
-                            emb, (0, 0, 0, max_seq_len - emb.shape[0])
-                        )
-                        for emb in processed_embeddings
-                    ]
+                    padded_embeddings = [torch.nn.functional.pad(
+                        emb, (0, 0, 0, max_seq_len - emb.shape[0])) for emb in all_embeddings]
 
                     # Stack embeddings into a single tensor
-                    stacked_embeddings = torch.stack(padded_embeddings).unsqueeze(1)
+                    stacked_embeddings = torch.stack(
+                        padded_embeddings).unsqueeze(1)
 
-                    batch[f"{column}_{self.config.output_column_suffix}"] = (
-                        stacked_embeddings.cpu().numpy()
-                    )
+                    batch[f"{column}_{self.config.output_column_suffix}"] = stacked_embeddings.cpu(
+                    ).numpy()
 
                 except Exception as e:
                     logger.error(
-                        f"Error in model.predict for column {column}: {str(e)}"
-                    )
+                        f"Error in model.predict for column {column}: {str(e)}")
                     # Instead of raising, we'll set the output to None and continue processing
                     batch[f"{column}_{self.config.output_column_suffix}"] = None
 
@@ -226,3 +211,38 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
             # Instead of raising, we'll return the batch as is
 
         return batch
+
+
+class AudioToEmbeddingPipelineFactory:
+    """
+    Factory class for creating AudioToEmbedding pipelines.
+
+    This factory creates HFAudioToEmbeddingPipeline instances based on the provided configuration.
+
+    Example:
+        factory = AudioToEmbeddingPipelineFactory()
+        config = {
+            "encoder_model": "sonar_speech_encoder_large",
+            "fbank_dtype": torch.float16,
+            "n_parallel": 4,
+            "pad_idx": 0,
+            "audio_column": "audio",
+            "device": "cuda",
+            "batch_size": 32,
+            "columns": ["audio"],
+            "output_path": "/path/to/output",
+            "output_column_suffix": "embedding"
+        }
+        pipeline = factory.create_pipeline(config)
+    """
+
+    def create_pipeline(self, config: Dict[str, Any]) -> Pipeline:
+        """
+        Create an AudioToEmbedding pipeline based on the provided configuration.
+
+        Returns:
+            Pipeline: An instance of HFAudioToEmbeddingPipeline.
+        """
+        pipeline_config = HFAudioToEmbeddingPipelineConfig(**config)
+        return HFAudioToEmbeddingPipeline(pipeline_config)
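
A minimal usage sketch, assuming the module is importable as huggingface_pipelines.audio, that the encoder name from the docstring example resolves to an available SONAR speech encoder checkpoint, and that sample1.wav and sample2.wav exist locally:

    import torch
    from datasets import Audio, Dataset

    from huggingface_pipelines.audio import AudioToEmbeddingPipelineFactory

    # Config mirrors the docstring example above; the encoder name and output
    # path are placeholders rather than verified values.
    config = {
        "encoder_model": "sonar_speech_encoder_large",
        "fbank_dtype": torch.float32,
        "n_parallel": 2,
        "pad_idx": 0,
        "audio_column": "audio",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "batch_size": 2,
        "columns": ["audio"],
        "output_path": "/tmp/embeddings",
        "output_column_suffix": "embedding",
    }
    pipeline = AudioToEmbeddingPipelineFactory().create_pipeline(config)

    # Two hypothetical local files stand in for a real speech dataset.
    ds = Dataset.from_dict({"audio": ["sample1.wav", "sample2.wav"]})
    ds = ds.cast_column("audio", Audio(sampling_rate=16_000))

    # process_batch receives a dict of column -> list of values, which is what
    # datasets passes to map() when batched=True.
    ds = ds.map(pipeline.process_batch, batched=True, batch_size=config["batch_size"])

    # Each row now carries an "audio_embedding" array (None for rows that failed).
    print(ds[0]["audio_embedding"])

The exact set of accepted config keys depends on PipelineConfig, which is not shown in this diff.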
