Skip to content

Commit

Permalink
Fix black linter issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
botirk38 committed Aug 23, 2024
1 parent b0ea257 commit a461e08
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 48 deletions.
58 changes: 35 additions & 23 deletions huggingface_pipelines/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class HFAudioToEmbeddingPipelineConfig(PipelineConfig):
)
"""

encoder_model: str = "text_sonar_basic_encoder"
fbank_dtype: torch.dtype = torch.float32
n_parallel: int = 4
Expand Down Expand Up @@ -107,7 +108,7 @@ def __init__(self, config: HFAudioToEmbeddingPipelineConfig):
self.model = SpeechToEmbeddingModelPipeline(
encoder=self.config.encoder_model,
device=torch.device(self.config.device),
fbank_dtype=self.config.fbank_dtype
fbank_dtype=self.config.fbank_dtype,
)

def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -133,8 +134,7 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
try:
for column in self.config.columns:
if column not in batch:
logger.warning(
f"Column {column} not found in batch. Skipping.")
logger.warning(f"Column {column} not found in batch. Skipping.")
continue

audio_inputs = []
Expand All @@ -145,9 +145,13 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
audio_data_list = [audio_data_list]

for audio_data in audio_data_list:
if isinstance(audio_data, dict) and 'array' in audio_data and 'sampling_rate' in audio_data:
if (
isinstance(audio_data, dict)
and "array" in audio_data
and "sampling_rate" in audio_data
):
# Handle multi-channel audio by taking the mean across channels
audio_array = audio_data['array']
audio_array = audio_data["array"]
if audio_array.ndim > 1:
audio_array = np.mean(audio_array, axis=0)

Expand All @@ -159,52 +163,60 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
audio_tensor = audio_tensor.unsqueeze(0)
elif audio_tensor.dim() > 2:
raise ValueError(
f"Unexpected audio tensor shape: {audio_tensor.shape}")
f"Unexpected audio tensor shape: {audio_tensor.shape}"
)

audio_inputs.append(audio_tensor)
else:
logger.warning(
f"Invalid audio data format in column {column}: {audio_data}")
f"Invalid audio data format in column {column}: {audio_data}"
)

if not audio_inputs:
logger.warning(
f"No valid audio inputs found in column {column}.")
logger.warning(f"No valid audio inputs found in column {column}.")
continue

try:
# Move tensors to the specified device
audio_inputs = [tensor.to(self.config.device)
for tensor in audio_inputs]
audio_inputs = [
tensor.to(self.config.device) for tensor in audio_inputs
]

all_embeddings: torch.Tensor = self.model.predict(
input=audio_inputs,
batch_size=self.config.batch_size,
n_parallel=self.config.n_parallel,
pad_idx=self.config.pad_idx
pad_idx=self.config.pad_idx,
)

# Ensure all embeddings are 2D
processed_embeddings: List[torch.Tensor] = [emb.unsqueeze(0) if emb.dim(
) == 1 else emb for emb in all_embeddings]
processed_embeddings: List[torch.Tensor] = [
emb.unsqueeze(0) if emb.dim() == 1 else emb
for emb in all_embeddings
]

# Get the maximum sequence length
max_seq_len = max(emb.shape[0]
for emb in processed_embeddings)
max_seq_len = max(emb.shape[0] for emb in processed_embeddings)

# Pad embeddings to have the same sequence length
padded_embeddings = [torch.nn.functional.pad(
emb, (0, 0, 0, max_seq_len - emb.shape[0])) for emb in processed_embeddings]
padded_embeddings = [
torch.nn.functional.pad(
emb, (0, 0, 0, max_seq_len - emb.shape[0])
)
for emb in processed_embeddings
]

# Stack embeddings into a single tensor
stacked_embeddings = torch.stack(
padded_embeddings).unsqueeze(1)
stacked_embeddings = torch.stack(padded_embeddings).unsqueeze(1)

batch[f"{column}_{self.config.output_column_suffix}"] = stacked_embeddings.cpu(
).numpy()
batch[f"{column}_{self.config.output_column_suffix}"] = (
stacked_embeddings.cpu().numpy()
)

except Exception as e:
logger.error(
f"Error in model.predict for column {column}: {str(e)}")
f"Error in model.predict for column {column}: {str(e)}"
)
# Instead of raising, we'll set the output to None and continue processing
batch[f"{column}_{self.config.output_column_suffix}"] = None

Expand Down
66 changes: 41 additions & 25 deletions tests/unit_tests/huggingface_pipelines/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ class MockSpeechToEmbeddingModelPipeline(SpeechToEmbeddingModelPipeline):
def __init__(self, encoder: Any, device: Any, fbank_dtype: Any):
pass

def predict(self, input: Union[Sequence[str], Sequence[torch.Tensor]],
batch_size: int = 32,
n_parallel: int = 1,
pad_idx: int = 0,
n_prefetched_batches: int = 1,
progress_bar: bool = False) -> torch.Tensor:
def predict(
self,
input: Union[Sequence[str], Sequence[torch.Tensor]],
batch_size: int = 32,
n_parallel: int = 1,
pad_idx: int = 0,
n_prefetched_batches: int = 1,
progress_bar: bool = False,
) -> torch.Tensor:
return torch.stack([torch.tensor([[0.1, 0.2, 0.3]]) for _ in input])


Expand All @@ -35,16 +38,13 @@ def pipeline_config():
audio_column="audio",
columns=["test"],
output_path="test",
output_column_suffix="test_embeddings"
output_column_suffix="test_embeddings",
)


@pytest.fixture
def sample_audio_data():
return {
'array': np.random.rand(16000),
'sampling_rate': 16000
}
return {"array": np.random.rand(16000), "sampling_rate": 16000}


def test_pipeline_initialization(pipeline_config, mock_speech_to_embedding_model):
Expand All @@ -53,15 +53,20 @@ def test_pipeline_initialization(pipeline_config, mock_speech_to_embedding_model
assert isinstance(pipeline.model, SpeechToEmbeddingModelPipeline)


def test_process_batch_valid_input(pipeline_config, mock_speech_to_embedding_model, sample_audio_data):
def test_process_batch_valid_input(
pipeline_config, mock_speech_to_embedding_model, sample_audio_data
):
pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
batch: Dict[str, List[Dict[str, Any]]] = {
"audio": [sample_audio_data, sample_audio_data]}
"audio": [sample_audio_data, sample_audio_data]
}
result = pipeline.process_batch(batch)
assert "audio_embedding" in result
assert isinstance(result["audio_embedding"], np.ndarray)
assert result["audio_embedding"].shape == (
2, 3) # 2 samples, 3 embedding dimensions
2,
3,
) # 2 samples, 3 embedding dimensions


def test_process_batch_empty_input(pipeline_config, mock_speech_to_embedding_model):
Expand All @@ -71,41 +76,52 @@ def test_process_batch_empty_input(pipeline_config, mock_speech_to_embedding_mod
assert "audio_embedding" not in result


def test_process_batch_invalid_audio_data(pipeline_config, mock_speech_to_embedding_model):
def test_process_batch_invalid_audio_data(
pipeline_config, mock_speech_to_embedding_model
):
pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
batch: Dict[str, List[Dict[str, Any]]] = {"audio": [{"invalid": "data"}]}
result = pipeline.process_batch(batch)
assert "audio_embedding" not in result


def test_process_batch_mixed_valid_invalid_data(pipeline_config, mock_speech_to_embedding_model, sample_audio_data):
def test_process_batch_mixed_valid_invalid_data(
pipeline_config, mock_speech_to_embedding_model, sample_audio_data
):
pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
batch: Dict[str, List[Dict[str, Any]]] = {"audio": [sample_audio_data, {
"invalid": "data"}, sample_audio_data]}
batch: Dict[str, List[Dict[str, Any]]] = {
"audio": [sample_audio_data, {"invalid": "data"}, sample_audio_data]
}
result = pipeline.process_batch(batch)
assert "audio_embedding" in result
assert isinstance(result["audio_embedding"], np.ndarray)
# 2 valid samples, 3 embedding dimensions
assert result["audio_embedding"].shape == (2, 3)


@patch('huggingface_pipelines.speech.SpeechToEmbeddingModelPipeline')
def test_error_handling_in_model_predict(mock_predict, pipeline_config, sample_audio_data):
mock_predict.return_value.predict.side_effect = Exception(
"Model prediction error")
@patch("huggingface_pipelines.speech.SpeechToEmbeddingModelPipeline")
def test_error_handling_in_model_predict(
mock_predict, pipeline_config, sample_audio_data
):
mock_predict.return_value.predict.side_effect = Exception("Model prediction error")
pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
batch: Dict[str, List[Dict[str, Any]]] = {"audio": [sample_audio_data]}

with pytest.raises(Exception, match="Model prediction error"):
pipeline.process_batch(batch)


def test_process_large_batch(pipeline_config, mock_speech_to_embedding_model, sample_audio_data):
def test_process_large_batch(
pipeline_config, mock_speech_to_embedding_model, sample_audio_data
):
pipeline = HFAudioToEmbeddingPipeline(pipeline_config)
large_batch: Dict[str, List[Dict[str, Any]]] = {
"audio": [sample_audio_data] * 100} # 100 audio samples
"audio": [sample_audio_data] * 100
} # 100 audio samples
result = pipeline.process_batch(large_batch)
assert "audio_embedding" in result
assert isinstance(result["audio_embedding"], np.ndarray)
assert result["audio_embedding"].shape == (
100, 3) # 100 samples, 3 embedding dimensions
100,
3,
) # 100 samples, 3 embedding dimensions

0 comments on commit a461e08

Please sign in to comment.