Introducing VoyageAI's new multimodal embeddings model (#17261)
fzowl authored Dec 15, 2024
1 parent c3fd72a commit f69fba1
Showing 2 changed files with 127 additions and 27 deletions.
@@ -1,18 +1,28 @@
"""Voyage embeddings file."""

import logging
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from llama_index.core.base.embeddings.base import Embedding
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.callbacks.base import CallbackManager

import voyageai
from llama_index.core.embeddings import MultiModalEmbedding
import base64
from io import BytesIO
from pathlib import Path
from llama_index.core.schema import ImageType
from PIL import Image


logger = logging.getLogger(__name__)

MULTIMODAL_MODELS = ["voyage-multimodal-3"]

SUPPORTED_IMAGE_FORMATS = {"png", "jpeg", "jpg", "webp", "gif"}


class VoyageEmbedding(BaseEmbedding):
class VoyageEmbedding(MultiModalEmbedding):
"""Class for Voyage embeddings.
Args:
@@ -68,60 +78,150 @@ def __init__(
     def class_name(cls) -> str:
         return "VoyageEmbedding"

-    def _get_embedding(self, texts: List[str], input_type: str) -> List[List[float]]:
-        return self._client.embed(
-            texts,
+    @staticmethod
+    def _validate_image_format(file_type: str) -> bool:
+        """Validate image format."""
+        return file_type.lower() in SUPPORTED_IMAGE_FORMATS
+
+    def _text_to_content(self, input_str: str) -> dict:
+        return {"type": "text", "text": input_str}
+
+    def _texts_to_content(self, input_strs: List[str]) -> List[dict]:
+        return [{"content": [{"type": "text", "text": x}]} for x in input_strs]
+
+    def _image_to_content(self, image_input: Union[str, Path, BytesIO]) -> dict:
+        """Convert an image to a base64 Data URL."""
+        if isinstance(image_input, (str, Path)):
+            # If it's a string or Path, assume it's a file path
+            content = {"type": "image_url", "image_url": image_input}
+        elif isinstance(image_input, BytesIO):
+            # If it's a BytesIO, use it directly
+            image = Image.open(image_input)
+            file_extension = image.format.lower()
+            image_input.seek(0)  # Reset the BytesIO stream to the beginning
+            image_data = image_input.read()
+
+            if self._validate_image_format(file_extension):
+                enc_img = base64.b64encode(image_data).decode("utf-8")
+                content = {
+                    "type": "image_base64",
+                    "image_base64": f"data:image/{file_extension};base64,{enc_img}",
+                }
+            else:
+                raise ValueError(f"Unsupported image format: {file_extension}")
+        else:
+            raise ValueError("Unsupported input type. Must be a file path or BytesIO.")
+
+        return {"content": [content]}
+
+    def _embed_image(
+        self, image_path: ImageType, input_type: Optional[str] = None
+    ) -> List[float]:
+        """Embed images using VoyageAI."""
+        if self.model_name not in MULTIMODAL_MODELS:
+            raise ValueError(
+                f"{self.model_name} is not a valid multi-modal embedding model. Supported models are {MULTIMODAL_MODELS}"
+            )
+        processed_image = self._image_to_content(image_path)
+        return self._client.multimodal_embed(
             model=self.model_name,
+            inputs=[processed_image],
             input_type=input_type,
             truncation=self.truncation,
-        ).embeddings
+        ).embeddings[0]

-    async def _aget_embedding(
-        self, texts: List[str], input_type: str
-    ) -> List[List[float]]:
-        r = await self._aclient.embed(
-            texts,
-            model=self.model_name,
-            input_type=input_type,
-            truncation=self.truncation,
-        )
+    async def _aembed_image(
+        self, image_path: ImageType, input_type: Optional[str] = None
+    ) -> List[float]:
+        """Embed images using VoyageAI."""
+        if self.model_name not in MULTIMODAL_MODELS:
+            raise ValueError(
+                f"{self.model_name} is not a valid multi-modal embedding model. Supported models are {MULTIMODAL_MODELS}"
+            )
+        processed_image = self._image_to_content(image_path)
+        return (
+            await self._aclient.multimodal_embed(
+                model=self.model_name,
+                inputs=[processed_image],
+                input_type=input_type,
+                truncation=self.truncation,
+            )
+        ).embeddings[0]
+
+    def _get_image_embedding(self, img_file_path: ImageType) -> Embedding:
+        return self._embed_image(img_file_path)
+
+    async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
+        return await self._aembed_image(img_file_path)
+
+    def _embed(self, texts: List[str], input_type: str) -> List[List[float]]:
+        if self.model_name in MULTIMODAL_MODELS:
+            return self._client.multimodal_embed(
+                inputs=self._texts_to_content(texts),
+                model=self.model_name,
+                input_type=input_type,
+                truncation=self.truncation,
+            ).embeddings
+        else:
+            return self._client.embed(
+                texts,
+                model=self.model_name,
+                input_type=input_type,
+                truncation=self.truncation,
+            ).embeddings
+
+    async def _aembed(self, texts: List[str], input_type: str) -> List[List[float]]:
+        if self.model_name in MULTIMODAL_MODELS:
+            r = self._aclient.multimodal_embed(
+                inputs=self._texts_to_content(texts),
+                model=self.model_name,
+                input_type=input_type,
+                truncation=self.truncation,
+            )
+        else:
+            r = await self._aclient.embed(
+                texts,
+                model=self.model_name,
+                input_type=input_type,
+                truncation=self.truncation,
+            )
         return r.embeddings

     def _get_query_embedding(self, query: str) -> List[float]:
         """Get query embedding."""
-        return self._get_embedding([query], input_type="query")[0]
+        return self._embed([query], input_type="query")[0]

     async def _aget_query_embedding(self, query: str) -> List[float]:
         """The asynchronous version of _get_query_embedding."""
-        r = await self._aget_embedding([query], input_type="query")
+        r = await self._aembed([query], input_type="query")
         return r[0]

     def _get_text_embedding(self, text: str) -> List[float]:
         """Get text embedding."""
-        return self._get_embedding([text], input_type="document")[0]
+        return self._embed([text], input_type="document")[0]

     async def _aget_text_embedding(self, text: str) -> List[float]:
         """Asynchronously get text embedding."""
-        r = await self._aget_embedding([text], input_type="document")
+        r = await self._aembed([text], input_type="document")
         return r[0]

     def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Get text embeddings."""
-        return self._get_embedding(texts, input_type="document")
+        return self._embed(texts, input_type="document")

     async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
         """Asynchronously get text embeddings."""
-        return await self._aget_embedding(texts, input_type="document")
+        return await self._aembed(texts, input_type="document")

     def get_general_text_embedding(
         self, text: str, input_type: Optional[str] = None
     ) -> List[float]:
         """Get general text embedding with input_type."""
-        return self._get_embedding([text], input_type=input_type)[0]
+        return self._embed([text], input_type=input_type)[0]

     async def aget_general_text_embedding(
         self, text: str, input_type: Optional[str] = None
     ) -> List[float]:
         """Asynchronously get general text embedding with input_type."""
-        r = await self._aget_embedding([text], input_type=input_type)
+        r = await self._aembed([text], input_type=input_type)
         return r[0]
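For reference, a minimal usage sketch of the class after this change. It is not part of the commit: it assumes the package's public VoyageEmbedding export, the get_text_embedding / get_image_embedding methods inherited from MultiModalEmbedding, and placeholder values for the API key and image path; the model name is the one listed in MULTIMODAL_MODELS above.

# Usage sketch only -- not part of this commit.
from llama_index.embeddings.voyageai import VoyageEmbedding

embed_model = VoyageEmbedding(
    model_name="voyage-multimodal-3",        # listed in MULTIMODAL_MODELS
    voyage_api_key="<your-voyage-api-key>",  # placeholder; or rely on the VOYAGE_API_KEY env var
)

# With a multimodal model configured, text is routed through multimodal_embed
# via the new _embed()/_aembed() helpers.
text_vec = embed_model.get_text_embedding("A photo of a mountain lake at dawn.")

# Images may be a file path (str/Path) or a BytesIO object; BytesIO input must be
# one of the supported formats: png, jpeg, jpg, webp, gif.
image_vec = embed_model.get_image_embedding("lake.png")

print(len(text_vec), len(image_vec))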
@@ -27,11 +27,11 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-voyageai"
readme = "README.md"
version = "0.3.1"
version = "0.3.2"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
voyageai = {python = ">=3.9,<3.13", version = ">=0.2.1,<0.4.0"}
voyageai = {python = ">=3.9,<3.13", version = ">=0.3.2,<0.4.0"}
llama-index-core = "^0.12.0"

[tool.poetry.group.dev.dependencies]