fix: add sentence trimming to OpenAIWrapper #1526

Open · wants to merge 8 commits into main
48 changes: 42 additions & 6 deletions mteb/models/openai_models.py
@@ -15,27 +15,48 @@
 
 
 class OpenAIWrapper(Wrapper):
-    def __init__(self, model_name: str, embed_dim: int | None = None, **kwargs) -> None:
+    def __init__(
+        self,
+        model_name: str,
+        max_tokens: int,
+        tokenizer_name: str = "cl100k_base",  # since all models use this tokenizer now
+        embed_dim: int | None = None,
+        **kwargs,
+    ) -> None:
         requires_package(self, "openai", "Openai text embedding")
         from openai import OpenAI
 
         self._client = OpenAI()
         self._model_name = model_name
         self._embed_dim = embed_dim
+        self._max_tokens = max_tokens
+        self._tokenizer_name = tokenizer_name
 
     def encode(self, sentences: list[str], **kwargs: Any) -> np.ndarray:
         requires_package(self, "openai", "Openai text embedding")
+        requires_package(self, "tiktoken", "Tiktoken package")
+        import tiktoken
         from openai import NotGiven
 
         if self._model_name == "text-embedding-ada-002" and self._embed_dim is not None:
             logger.warning(
                 "Reducing embedding size available only for text-embedding-3-* models"
             )
 
+        trimmed_sentences = []
+        for sentence in sentences:
+            encoding = tiktoken.get_encoding(self._tokenizer_name)
+            encoded_sentence = encoding.encode(sentence)
+            if len(encoded_sentence) > self._max_tokens:
+                trimmed_sentence = encoding.decode(encoded_sentence[: self._max_tokens])
+                trimmed_sentences.append(trimmed_sentence)
+            else:
+                trimmed_sentences.append(sentence)
+
         max_batch_size = 2048
         sublists = [
-            sentences[i : i + max_batch_size]
-            for i in range(0, len(sentences), max_batch_size)
+            trimmed_sentences[i : i + max_batch_size]
+            for i in range(0, len(trimmed_sentences), max_batch_size)
         ]
 
         all_embeddings = []
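For context, the trimming and batching steps can be exercised outside the wrapper. Below is a minimal standalone sketch (the helper names trim_to_max_tokens and batch are illustrative, not part of this PR); unlike the loop above, it resolves the tiktoken encoding once per call rather than once per sentence, since the encoding never changes between sentences:

import tiktoken


def trim_to_max_tokens(
    sentences: list[str],
    max_tokens: int,
    tokenizer_name: str = "cl100k_base",
) -> list[str]:
    # Resolve the encoding once; it is the same for every sentence.
    encoding = tiktoken.get_encoding(tokenizer_name)
    trimmed = []
    for sentence in sentences:
        tokens = encoding.encode(sentence)
        if len(tokens) > max_tokens:
            # Keep the first max_tokens tokens and decode back to text.
            trimmed.append(encoding.decode(tokens[:max_tokens]))
        else:
            trimmed.append(sentence)
    return trimmed


def batch(items: list[str], size: int = 2048) -> list[list[str]]:
    # Mirrors the wrapper's max_batch_size of 2048 inputs per API request.
    return [items[i : i + size] for i in range(0, len(items), size)]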
@@ -60,7 +81,12 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
     revision="1",
     release_date="2024-01-25",
     languages=None,  # supported languages not specified
-    loader=partial(OpenAIWrapper, model_name="text-embedding-3-small"),
+    loader=partial(
+        OpenAIWrapper,
+        model_name="text-embedding-3-small",
+        tokenizer_name="cl100k_base",
+        max_tokens=8192,
+    ),
     max_tokens=8191,
     embed_dim=1536,
     open_weights=False,
@@ -77,7 +103,12 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
     revision="1",
     release_date="2024-01-25",
     languages=None,  # supported languages not specified
-    loader=partial(OpenAIWrapper, model_name="text-embedding-3-large"),
+    loader=partial(
+        OpenAIWrapper,
+        model_name="text-embedding-3-large",
+        tokenizer_name="cl100k_base",
+        max_tokens=8192,
+    ),
     max_tokens=8191,
     embed_dim=3072,
     open_weights=False,
@@ -91,7 +122,12 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
     revision="1",
     release_date="2022-12-15",
     languages=None,  # supported languages not specified
-    loader=partial(OpenAIWrapper, model_name="text-embedding-ada-002"),
+    loader=partial(
+        OpenAIWrapper,
+        model_name="text-embedding-ada-002",
+        tokenizer_name="cl100k_base",
+        max_tokens=8192,
+    ),
     max_tokens=8191,
     embed_dim=1536,
     open_weights=False,
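Each loader above is a functools.partial over OpenAIWrapper, so the model registry can build a configured wrapper on demand. A hypothetical usage sketch (assumes the openai and tiktoken packages are installed and OPENAI_API_KEY is set; the variable names are illustrative):

from functools import partial

loader = partial(
    OpenAIWrapper,
    model_name="text-embedding-3-small",
    tokenizer_name="cl100k_base",
    max_tokens=8192,
)

model = loader()  # OpenAI() reads OPENAI_API_KEY from the environment
embeddings = model.encode(["A sentence comfortably under the token limit."])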