Commit 0d5ef9b

Fix black linter issues

botirk38 committed Aug 23, 2024
1 parent fc6bf4e commit 0d5ef9b
Showing 2 changed files with 105 additions and 61 deletions.
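Note: every change in this commit is mechanical output of black's default mode (double-quote normalization, magic trailing commas, signatures exploded at the 88-character limit, PEP 8 slice spacing). A minimal sketch of reproducing one such rewrite through black's documented Python API, assuming black is installed:

# Hedged sketch: round-trip one line from the diff below through black.
# format_str and Mode are black's programmatic entry points.
import black

src = "handle_missing: Literal['skip', 'remove', 'fill'] = \"skip\"\n"
print(black.format_str(src, mode=black.Mode()), end="")
# -> handle_missing: Literal["skip", "remove", "fill"] = "skip"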
69 changes: 38 additions & 31 deletions huggingface_pipelines/text.py
@@ -27,6 +27,7 @@ class TextDatasetConfig(DatasetConfig):
     This class inherits from BaseDatasetConfig and can be used for
     text-specific dataset configurations.
     """
+
     pass
 
 
@@ -51,9 +52,10 @@ class TextSegmentationPipelineConfig(PipelineConfig):
             handle_missing='fill'
         )
     """
+
     fill_value: Optional[str] = None
     source_lang: str = "eng_Latn"
-    handle_missing: Literal['skip', 'remove', 'fill'] = "skip"
+    handle_missing: Literal["skip", "remove", "fill"] = "skip"
 
 
 class TextSegmentationPipeline(Pipeline):
@@ -112,8 +114,7 @@ def load_spacy_model(self, lang_code: str) -> Language:
             nlp = pipeline.load_spacy_model('en')
         """
         if lang_code not in self.SPACY_MODELS:
-            raise ValueError(
-                f"No installed model found for language code: {lang_code}")
+            raise ValueError(f"No installed model found for language code: {lang_code}")
         return spacy.load(self.SPACY_MODELS[lang_code])
 
     def segment_text(self, text: Optional[str]) -> List[str]:
@@ -133,16 +134,17 @@ def segment_text(self, text: Optional[str]) -> List[str]:
             sentences = pipeline.segment_text("This is a sample. It has two sentences.")
             print(sentences)  # ['This is a sample.', 'It has two sentences.']
         """
-        if text is None or (isinstance(text, str) and text.strip() == ''):
-            if self.config.handle_missing == 'skip':
+        if text is None or (isinstance(text, str) and text.strip() == ""):
+            if self.config.handle_missing == "skip":
                 return []
-            elif self.config.handle_missing == 'remove':
+            elif self.config.handle_missing == "remove":
                 return []
-            elif self.config.handle_missing == 'fill':
+            elif self.config.handle_missing == "fill":
                 return [self.config.fill_value] if self.config.fill_value else []
             else:
                 raise ValueError(
-                    f"Invalid handle_missing option: {self.config.handle_missing}")
+                    f"Invalid handle_missing option: {self.config.handle_missing}"
+                )
 
         doc = self.nlp(text)
         return [sent.text.strip() for sent in doc.sents]
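Note on the logic being reformatted above: handle_missing selects one of three fallbacks for empty or missing input. A standalone sketch of the same behavior, with the spaCy sentence split stubbed by a naive period split so it runs without an installed model:

# Standalone sketch of the missing-value handling; the spaCy sentence
# split is stubbed with a naive period split.
from typing import List, Optional

def segment_text(
    text: Optional[str],
    handle_missing: str = "skip",
    fill_value: Optional[str] = None,
) -> List[str]:
    if text is None or text.strip() == "":
        if handle_missing in ("skip", "remove"):  # both yield no sentences
            return []
        elif handle_missing == "fill":
            return [fill_value] if fill_value else []
        raise ValueError(f"Invalid handle_missing option: {handle_missing}")
    return [s.strip() + "." for s in text.split(".") if s.strip()]

print(segment_text(None))                                           # []
print(segment_text("  ", handle_missing="fill", fill_value="N/A"))  # ['N/A']
print(segment_text("One. Two."))                                    # ['One.', 'Two.']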
@@ -165,7 +167,8 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         for column in self.config.columns:
             if column in batch:
                 batch[f"{column}_preprocessed"] = [
-                    self.segment_text(text) for text in batch[column]]
+                    self.segment_text(text) for text in batch[column]
+                ]
         return batch
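Note: batch processors shaped like this (dict of column lists in, dict out) are the form expected by Hugging Face datasets' batched mapping, which is presumably how these pipelines are driven. A hedged sketch with an invented column and a stub segmenter:

# Hedged sketch: plugging a batch processor into datasets.Dataset.map;
# the "text" column, stub segmenter, and batch size are invented.
from datasets import Dataset

ds = Dataset.from_dict({"text": ["One. Two.", ""]})

def stub_segment(batch):
    batch["text_preprocessed"] = [t.split(". ") if t else [] for t in batch["text"]]
    return batch

ds = ds.map(stub_segment, batched=True, batch_size=2)
print(ds["text_preprocessed"])  # [['One', 'Two.'], []]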


@@ -218,6 +221,7 @@ class TextToEmbeddingPipelineConfig(PipelineConfig):
             max_seq_len=512
         )
     """
+
     max_seq_len: Optional[int] = None
     encoder_model: str = "text_sonar_basic_encoder"
     source_lang: str = "eng_Latn"
@@ -246,6 +250,7 @@ class EmbeddingToTextPipelineConfig(PipelineConfig):
             device="cuda"
         )
     """
+
     decoder_model: str = "text_sonar_basic_decoder"
     target_lang: str = "eng_Latn"
 
@@ -274,7 +279,7 @@ def __init__(self, config: EmbeddingToTextPipelineConfig):
         self.t2t_model = EmbeddingToTextModelPipeline(
             decoder=self.config.decoder_model,
             tokenizer=self.config.decoder_model,
-            device=self.config.device
+            device=self.config.device,
         )
         logger.info("Model initialized.")
 
@@ -290,19 +295,20 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, List[str]]:
         """
         for column in self.config.columns:
             embeddings: List[np.ndarray] = batch[column]
-            assert all(isinstance(item, list) for item in embeddings), \
-                f"Column {column} must contain only lists of embeddings, not individual embeddings."
+            assert all(
+                isinstance(item, list) for item in embeddings
+            ), f"Column {column} must contain only lists of embeddings, not individual embeddings."
 
-            all_embeddings = np.vstack([np.array(embed)
-                                        for item in embeddings for embed in item])
+            all_embeddings = np.vstack(
+                [np.array(embed) for item in embeddings for embed in item]
+            )
             all_decoded_texts = self.decode_embeddings(all_embeddings)
 
             reconstructed_texts = []
             start_idx = 0
             for item in embeddings:
                 end_idx = start_idx + len(item)
-                reconstructed_texts.append(
-                    all_decoded_texts[start_idx:end_idx])
+                reconstructed_texts.append(all_decoded_texts[start_idx:end_idx])
                 start_idx = end_idx
 
             batch[f"{column}_{self.config.output_column_suffix}"] = reconstructed_texts
@@ -334,12 +340,11 @@ def decode_embeddings(self, embeddings: np.ndarray) -> List[str]:
         decoded_texts = []
 
         for i in range(0, len(embeddings), self.config.batch_size):
-            batch_embeddings = embeddings_tensor[i:i +
-                                                 self.config.batch_size]
+            batch_embeddings = embeddings_tensor[i : i + self.config.batch_size]
             batch_decoded = self.t2t_model.predict(
                 batch_embeddings,
                 target_lang=self.config.target_lang,
-                batch_size=self.config.batch_size
+                batch_size=self.config.batch_size,
             )
             decoded_texts.extend(batch_decoded)
@@ -407,7 +412,7 @@ def __init__(self, config: TextToEmbeddingPipelineConfig):
         self.t2vec_model = TextToEmbeddingModelPipeline(
             encoder=self.config.encoder_model,
             tokenizer=self.config.encoder_model,
-            device=self.config.device
+            device=self.config.device,
         )
 
     def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
@@ -422,24 +427,26 @@ def process_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         """
         for column in self.config.columns:
             if column in batch:
-                assert all(isinstance(item, list) for item in batch[column]), \
-                    f"Column {column} must contain only lists of sentences, not individual strings."
+                assert all(
+                    isinstance(item, list) for item in batch[column]
+                ), f"Column {column} must contain only lists of sentences, not individual strings."
 
-                all_sentences = [sentence for item in batch[column]
-                                 for sentence in item]
+                all_sentences = [
+                    sentence for item in batch[column] for sentence in item
+                ]
                 all_embeddings = self.encode_texts(all_sentences)
 
                 sentence_embeddings = []
                 start_idx = 0
                 for item in batch[column]:
                     end_idx = start_idx + len(item)
-                    sentence_embeddings.append(
-                        all_embeddings[start_idx:end_idx])
+                    sentence_embeddings.append(all_embeddings[start_idx:end_idx])
                     start_idx = end_idx
 
-                batch[f"{column}_{self.config.output_column_suffix}"] = sentence_embeddings
-                logger.debug(
-                    f"{column} column embeddings: {batch[column][:5]}")
+                batch[f"{column}_{self.config.output_column_suffix}"] = (
+                    sentence_embeddings
+                )
+                logger.debug(f"{column} column embeddings: {batch[column][:5]}")
             else:
                 logger.warning(f"Column {column} not found in batch.")
 
@@ -461,12 +468,12 @@ def encode_texts(self, texts: List[str]) -> np.ndarray:
         try:
             embeddings: List[torch.Tensor] = []
             for i in range(0, len(texts), self.config.batch_size):
-                batch_texts = texts[i:i + self.config.batch_size]
+                batch_texts = texts[i : i + self.config.batch_size]
                 batch_embeddings = self.t2vec_model.predict(
                     batch_texts,
                     source_lang=self.config.source_lang,
                     batch_size=self.config.batch_size,
-                    max_seq_len=self.config.max_seq_len
+                    max_seq_len=self.config.max_seq_len,
                 )
 
                 batch_embeddings = batch_embeddings.detach().cpu().numpy()
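Note: the recurring device=...,  edits are black's "magic trailing comma": once a call is exploded one argument per line, black appends a trailing comma so adding an argument later touches a single line. A tiny sketch with an invented predict function:

# Sketch of black's magic trailing comma with an invented function: the
# comma after the last argument keeps the call exploded, so adding
# another keyword later is a one-line diff.
def predict(texts, source_lang=None, batch_size=None, max_seq_len=None):
    return texts

out = predict(
    ["hello"],
    source_lang="eng_Latn",
    batch_size=2,
    max_seq_len=512,
)
print(out)  # ['hello']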
97 changes: 67 additions & 30 deletions tests/unit_tests/huggingface_pipelines/text.py
@@ -39,7 +39,7 @@ def text_to_embedding_config():
         batch_size=2,
         device="cpu",
         source_lang="eng_Latn",
-        output_path="test"
+        output_path="test",
     )
 
 
@@ -52,79 +52,110 @@ def embedding_to_text_config():
         batch_size=2,
         device="cpu",
         target_lang="eng_Latn",
-        output_path="test"
+        output_path="test",
     )
 
 
 @pytest.fixture
 def mock_text_to_embedding_model():
-    with patch('huggingface_pipelines.text.TextToEmbeddingModelPipeline', MockTextToEmbeddingModelPipeline):
+    with patch(
+        "huggingface_pipelines.text.TextToEmbeddingModelPipeline",
+        MockTextToEmbeddingModelPipeline,
+    ):
         yield
 
 
 @pytest.fixture
 def mock_embedding_to_text_model():
-    with patch('huggingface_pipelines.text.EmbeddingToTextModelPipeline', MockEmbeddingToTextModelPipeline):
+    with patch(
+        "huggingface_pipelines.text.EmbeddingToTextModelPipeline",
+        MockEmbeddingToTextModelPipeline,
+    ):
         yield
 
 
-def test_text_to_embedding_pipeline_initialization(text_to_embedding_config, mock_text_to_embedding_model):
+def test_text_to_embedding_pipeline_initialization(
+    text_to_embedding_config, mock_text_to_embedding_model
+):
     pipeline = HFTextToEmbeddingPipeline(text_to_embedding_config)
     assert pipeline.config == text_to_embedding_config
     assert isinstance(pipeline.t2vec_model, MockTextToEmbeddingModelPipeline)
 
 
-def test_embedding_to_text_pipeline_initialization(embedding_to_text_config, mock_embedding_to_text_model):
+def test_embedding_to_text_pipeline_initialization(
+    embedding_to_text_config, mock_embedding_to_text_model
+):
     pipeline = HFEmbeddingToTextPipeline(embedding_to_text_config)
     assert pipeline.config == embedding_to_text_config
    assert isinstance(pipeline.t2t_model, MockEmbeddingToTextModelPipeline)
 
 
-def test_text_to_embedding_process_batch(text_to_embedding_config, mock_text_to_embedding_model):
+def test_text_to_embedding_process_batch(
+    text_to_embedding_config, mock_text_to_embedding_model
+):
     pipeline = HFTextToEmbeddingPipeline(text_to_embedding_config)
     batch = {"text": [["Hello", "World"], ["Test", "Sentence"]]}
     result = pipeline.process_batch(batch)
     assert "text_embedding" in result
     assert len(result["text_embedding"]) == 2
-    assert all(isinstance(item, np.ndarray)
-               for sublist in result["text_embedding"] for item in sublist)
+    assert all(
+        isinstance(item, np.ndarray)
+        for sublist in result["text_embedding"]
+        for item in sublist
+    )
 
 
-def test_embedding_to_text_process_batch(embedding_to_text_config, mock_embedding_to_text_model):
+def test_embedding_to_text_process_batch(
+    embedding_to_text_config, mock_embedding_to_text_model
+):
     pipeline = HFEmbeddingToTextPipeline(embedding_to_text_config)
-    batch = {"embedding": [
-        [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[0.7, 0.8, 0.9]]]}
+    batch = {"embedding": [[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[0.7, 0.8, 0.9]]]}
     result = pipeline.process_batch(batch)
     assert "embedding_text" in result
     assert len(result["embedding_text"]) == 2
     assert all(isinstance(text, list) for text in result["embedding_text"])
-    assert all(isinstance(item, str)
-               for sublist in result["embedding_text"] for item in sublist)
+    assert all(
+        isinstance(item, str)
+        for sublist in result["embedding_text"]
+        for item in sublist
+    )
 
 
-@pytest.mark.parametrize("invalid_batch", [
-    {"text": "Not a list"},
-    {"text": [1, 2, 3]},
-    {"text": ["Not a list of lists"]},
-])
-def test_text_to_embedding_invalid_input(text_to_embedding_config, mock_text_to_embedding_model, invalid_batch):
+@pytest.mark.parametrize(
+    "invalid_batch",
+    [
+        {"text": "Not a list"},
+        {"text": [1, 2, 3]},
+        {"text": ["Not a list of lists"]},
+    ],
+)
+def test_text_to_embedding_invalid_input(
+    text_to_embedding_config, mock_text_to_embedding_model, invalid_batch
+):
     pipeline = HFTextToEmbeddingPipeline(text_to_embedding_config)
     with pytest.raises(AssertionError):
         pipeline.process_batch(invalid_batch)
 
 
-@pytest.mark.parametrize("invalid_batch", [
-    {"embedding": "Not a list"},
-    {"embedding": [1, 2, 3]},
-    {"embedding": ["Not a list of lists"]},
-])
-def test_embedding_to_text_invalid_input(embedding_to_text_config, mock_embedding_to_text_model, invalid_batch):
+@pytest.mark.parametrize(
+    "invalid_batch",
+    [
+        {"embedding": "Not a list"},
+        {"embedding": [1, 2, 3]},
+        {"embedding": ["Not a list of lists"]},
+    ],
+)
+def test_embedding_to_text_invalid_input(
+    embedding_to_text_config, mock_embedding_to_text_model, invalid_batch
+):
     pipeline = HFEmbeddingToTextPipeline(embedding_to_text_config)
     with pytest.raises(AssertionError):
         pipeline.process_batch(invalid_batch)
 
 
-def test_text_to_embedding_large_batch(text_to_embedding_config, mock_text_to_embedding_model):
+def test_text_to_embedding_large_batch(
+    text_to_embedding_config, mock_text_to_embedding_model
+):
     pipeline = HFTextToEmbeddingPipeline(text_to_embedding_config)
     large_batch = {"text": [["Hello"] * 100, ["World"] * 100]}
     result = pipeline.process_batch(large_batch)
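Note on the parametrized tests above: each list entry becomes its own test case with its own ID under pytest -v. A self-contained sketch of the same rejection pattern with invented names:

# Hedged sketch: each parametrize entry runs as a separate test, so
# `pytest -v` reports test_rejects[bad0], test_rejects[bad1], ...
import pytest

@pytest.mark.parametrize(
    "bad",
    [
        "Not a list",
        [1, 2, 3],
        ["Not a list of lists"],
    ],
)
def test_rejects(bad):
    with pytest.raises(AssertionError):
        assert isinstance(bad, list)
        assert all(isinstance(item, list) for item in bad)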
Expand All @@ -133,10 +164,16 @@ def test_text_to_embedding_large_batch(text_to_embedding_config, mock_text_to_em
assert all(len(emb) == 100 for emb in result["text_embedding"])


def test_embedding_to_text_large_batch(embedding_to_text_config, mock_embedding_to_text_model):
def test_embedding_to_text_large_batch(
embedding_to_text_config, mock_embedding_to_text_model
):
pipeline = HFEmbeddingToTextPipeline(embedding_to_text_config)
large_batch = {"embedding": [
[np.array([0.1, 0.2, 0.3])] * 100, [np.array([0.4, 0.5, 0.6])] * 100]}
large_batch = {
"embedding": [
[np.array([0.1, 0.2, 0.3])] * 100,
[np.array([0.4, 0.5, 0.6])] * 100,
]
}
result = pipeline.process_batch(large_batch)
assert "embedding_text" in result
assert len(result["embedding_text"]) == 2
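Note on the mock fixtures in this file: unittest.mock.patch replaces the class where it is looked up (huggingface_pipelines.text, the module under test) for the duration of the with-block, so constructing a pipeline never loads real SONAR weights. A self-contained sketch of the pattern:

# Minimal sketch of the fixtures' patching pattern: patch() swaps the
# name in the module namespace for the duration of the with-block.
from unittest.mock import patch

class RealModel:
    def __init__(self):
        raise RuntimeError("would download and load model weights")

class MockModel:
    def predict(self, xs):
        return [f"mock:{x}" for x in xs]

with patch(f"{__name__}.RealModel", MockModel):
    model = RealModel()  # resolves to MockModel while patched
    print(model.predict(["a", "b"]))  # ['mock:a', 'mock:b']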