diff --git a/src/openparse/processing/semantic_transforms.py b/src/openparse/processing/semantic_transforms.py
index 8369035..6af867c 100644
--- a/src/openparse/processing/semantic_transforms.py
+++ b/src/openparse/processing/semantic_transforms.py
@@ -37,27 +37,26 @@ def __init__(
         self.batch_size = batch_size
         self.client = self._create_client()
 
-    def embed_many(self, texts: List[str]) -> List[List[float]]:
-        """
-        Generate embeddings for a list of texts in batches.
-
-        Args:
-            texts (list[str]): The list of texts to embed.
-            batch_size (int): The number of texts to process in each batch.
-
-        Returns:
-            List[List[float]]: A list of embeddings.
-        """
+    def embed_many(self, texts: list[str]) -> list[list[float]]:
         res = []
-        for i in range(0, len(texts), self.batch_size):
-            batch_texts = texts[i : i + self.batch_size]
+        non_empty_texts = [text for text in texts if text]
+
+        embedding_size = 1
+        for i in range(0, len(non_empty_texts), self.batch_size):
+            batch_texts = non_empty_texts[i : i + self.batch_size]
             api_resp = self.client.embeddings.create(
                 input=batch_texts, model=self.model
             )
             batch_res = [val.embedding for val in api_resp.data]
             res.extend(batch_res)
+            embedding_size = len(batch_res[0])
+
+        # Map results back to original indices, adding zero embeddings for empty texts
+        final_res = [
+            [0.0] * embedding_size if not text else res.pop(0) for text in texts
+        ]
 
-        return res
+        return final_res
 
     def _create_client(self):
         try:
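
For illustration, the sketch below shows the behaviour this patch aims for: empty strings are skipped when calling the embeddings API and are backfilled with zero vectors of the same dimensionality as the real embeddings. It uses stubbed client classes (_FakeData, _FakeEmbeddings, _FakeClient) so it runs without an API key; those names, the standalone embed_many function, and the default model string are illustrative assumptions, not part of openparse.

# Stub objects mimicking the shape of an OpenAI-style embeddings response.
class _FakeData:
    def __init__(self, embedding):
        self.embedding = embedding

class _FakeEmbeddings:
    def create(self, input, model):
        # Return a 3-dimensional dummy embedding per input text.
        return type("Resp", (), {"data": [_FakeData([0.1, 0.2, 0.3]) for _ in input]})()

class _FakeClient:
    embeddings = _FakeEmbeddings()

def embed_many(texts, client, model="text-embedding-3-large", batch_size=2):
    # Mirrors the patched method: embed only non-empty texts in batches,
    # remember the embedding size, then backfill zero vectors for empty strings.
    res = []
    non_empty_texts = [text for text in texts if text]
    embedding_size = 1
    for i in range(0, len(non_empty_texts), batch_size):
        batch_texts = non_empty_texts[i : i + batch_size]
        api_resp = client.embeddings.create(input=batch_texts, model=model)
        batch_res = [val.embedding for val in api_resp.data]
        res.extend(batch_res)
        embedding_size = len(batch_res[0])
    return [[0.0] * embedding_size if not text else res.pop(0) for text in texts]

embs = embed_many(["hello", "", "world"], _FakeClient())
assert embs[1] == [0.0, 0.0, 0.0]   # empty text gets a zero vector of matching size
assert embs[0] == [0.1, 0.2, 0.3]   # non-empty texts keep their API embeddings

Note that because `res.pop(0)` consumes results in order, the non-empty embeddings are mapped back to their original positions; the cost is an O(n) pop per non-empty text, which is negligible next to the API calls themselves.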