fix: add sentence trimming to OpenAIWrapper #1526
base: main
Changes from all commits: 355b3b8, 32fe482, 83ea742, f76e1c3, 21a2937, 0b6a2d9, 43e2463, d58c84b, ec44a02
@@ -15,27 +15,48 @@
 class OpenAIWrapper(Wrapper):
-    def __init__(self, model_name: str, embed_dim: int | None = None, **kwargs) -> None:
+    def __init__(
+        self,
+        model_name: str,
+        max_tokens: int,
+        tokenizer_name: str = "cl100k_base",  # since all models use this tokenizer now
+        embed_dim: int | None = None,
+        **kwargs,
+    ) -> None:
         requires_package(self, "openai", "Openai text embedding")
         from openai import OpenAI

         self._client = OpenAI()
         self._model_name = model_name
         self._embed_dim = embed_dim
+        self._max_tokens = max_tokens
+        self._tokenizer_name = tokenizer_name

     def encode(self, sentences: list[str], **kwargs: Any) -> np.ndarray:
         requires_package(self, "openai", "Openai text embedding")
+        requires_package(self, "tiktoken", "Tiktoken package")
+        import tiktoken
+
         from openai import NotGiven

         if self._model_name == "text-embedding-ada-002" and self._embed_dim is not None:
             logger.warning(
                 "Reducing embedding size available only for text-embedding-3-* models"
             )

+        trimmed_sentences = []
Review comment on this line: Looks good. I would add a docstring to the model explaining this. I could, for example, imagine an alternative of embedding subsequences and averaging the embeddings, and it is nice to be clear about which implementation choices we have made (this is in line with sentence transformers as far as I know).
Author: Do you mean adding a log message explaining the trimming of the sequence length?
Reviewer: No, just a docstring for the class. Something like: "To handle documents larger than XXX we truncate the document to the specified sequence length." (A sketch of such a docstring is given after this hunk.)
+        for sentence in sentences:
+            encoding = tiktoken.get_encoding(self._tokenizer_name)
+            encoded_sentence = encoding.encode(sentence)
+            if len(encoded_sentence) > self._max_tokens:
+                trimmed_sentence = encoding.decode(encoded_sentence[: self._max_tokens])
+                trimmed_sentences.append(trimmed_sentence)
+            else:
+                trimmed_sentences.append(sentence)

         max_batch_size = 2048
         sublists = [
-            sentences[i : i + max_batch_size]
-            for i in range(0, len(sentences), max_batch_size)
+            trimmed_sentences[i : i + max_batch_size]
+            for i in range(0, len(trimmed_sentences), max_batch_size)
         ]

         all_embeddings = []
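Following the reviewer's suggestion above, a minimal sketch of what such a class docstring could look like; the wording is illustrative and not part of the PR, but the parameter names match the new __init__ signature:

    class OpenAIWrapper(Wrapper):
        """Wrapper for OpenAI text embedding models.

        Documents longer than `max_tokens` tokens (counted with the tokenizer named by
        `tokenizer_name`) are truncated to their first `max_tokens` tokens before being
        sent to the embeddings API; long documents are not split and averaged.
        """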
@@ -60,7 +81,12 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
     revision="1",

Review comment on revision="1": Since we changed the implementation, we should probably change the revision for all models (bump it up to "2").
Reply: We could keep the old version, which is nice for reproducibility, but it is not something that we want to maintain, so I wouldn't do it.

     release_date="2024-01-25",
     languages=None,  # supported languages not specified
-    loader=partial(OpenAIWrapper, model_name="text-embedding-3-small"),
+    loader=partial(
+        OpenAIWrapper,
+        model_name="text-embedding-3-small",
+        tokenizer_name="cl100k_base",
+        max_tokens=8192,
+    ),
     max_tokens=8191,
     embed_dim=1536,
     open_weights=False,
@@ -77,7 +103,12 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
     revision="1",
     release_date="2024-01-25",
     languages=None,  # supported languages not specified
-    loader=partial(OpenAIWrapper, model_name="text-embedding-3-large"),
+    loader=partial(
+        OpenAIWrapper,
+        model_name="text-embedding-3-large",
+        tokenizer_name="cl100k_base",
+        max_tokens=8192,
+    ),
     max_tokens=8191,
     embed_dim=3072,
     open_weights=False,
@@ -91,7 +122,12 @@ def _to_numpy(self, embedding_response) -> np.ndarray:
     revision="1",
     release_date="2022-12-15",
     languages=None,  # supported languages not specified
-    loader=partial(OpenAIWrapper, model_name="text-embedding-ada-002"),
+    loader=partial(
+        OpenAIWrapper,
+        model_name="text-embedding-ada-002",
+        tokenizer_name="cl100k_base",
+        max_tokens=8192,
+    ),
     max_tokens=8191,
     embed_dim=1536,
     open_weights=False,
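For reference, a small standalone sketch of what the new trimming step does to an over-length document, assuming tiktoken is installed; the sample text and token limit below are made up for illustration:

    import tiktoken

    encoding = tiktoken.get_encoding("cl100k_base")
    max_tokens = 8192  # the limit passed to OpenAIWrapper in this PR

    text = "lorem ipsum " * 10_000  # far longer than the model's context window
    tokens = encoding.encode(text)

    if len(tokens) > max_tokens:
        # Keep only the first max_tokens tokens and decode back to text,
        # mirroring the loop added in encode().
        text = encoding.decode(tokens[:max_tokens])

    # Re-encoding the trimmed text now yields roughly max_tokens tokens,
    # short enough for the embeddings API to accept.
    print(len(encoding.encode(text)))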
Review comment (on the new tiktoken dependency): We should add this to the pyproject.toml to ensure that versions are kept compatible.
Reviewer: We could probably bump them together so that we have them in an optional dependency set: pip install mteb[openai]. @Samoed what do you think?
Reviewer: (We should probably also change requires_package to take "tiktoken", "tiktoken package", "{dependency group}".)
Author: I agree.
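A rough sketch of the requires_package change discussed above, assuming the goal is to point users at an optional dependency group when a package is missing; the signature and error message here are hypothetical and not mteb's actual implementation:

    import importlib.util

    def requires_package(
        obj, package: str, feature: str, install_group: str | None = None
    ) -> None:
        # Hypothetical helper sketch: `obj` is the calling model instance (kept to
        # mirror the existing call sites, unused here). If `package` is missing,
        # the error points at the relevant extras group, e.g. `pip install mteb[openai]`.
        if importlib.util.find_spec(package) is None:
            hint = (
                f"pip install mteb[{install_group}]"
                if install_group
                else f"pip install {package}"
            )
            raise ImportError(
                f"{feature} requires the `{package}` package, which is not installed. "
                f"Install it with `{hint}`."
            )

With something like this in place, the call in encode() could become requires_package(self, "tiktoken", "Tiktoken package", "openai"), matching the pip install mteb[openai] group proposed above.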