diff --git a/mteb/models/instruct_wrapper.py b/mteb/models/instruct_wrapper.py
index 2ee3a09b5..3feb8b0a2 100644
--- a/mteb/models/instruct_wrapper.py
+++ b/mteb/models/instruct_wrapper.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import torch
+from sentence_transformers import SentenceTransformer
 
 from mteb.encoder_interface import PromptType
 
@@ -78,3 +79,75 @@ def encode(
             return embeddings
 
     return InstructWrapper(model_name_or_path, mode, instruction_template, **kwargs)
+
+
+class InstructSentenceTransformerWrapper(Wrapper):
+    def __init__(
+        self,
+        model_name: str,
+        revision: str,
+        instruction_template: str | Callable[[str], str] | None = None,
+        max_seq_length: int | None = None,
+        apply_instruction_to_passages: bool = True,
+        **kwargs: Any,
+    ):
+        """
+        Instruct Sentence Transformer Wrapper. Passes task instructions to a Sentence Transformer model as prompts.
+        Used for models like gte-Qwen, e5-mistral, etc.
+
+        Arguments:
+            model_name: Model name.
+            revision: Model revision.
+            instruction_template: Template used to format instructions. Should contain the string '{instruction}'.
+            max_seq_length: Maximum sequence length. If None, the maximum sequence length is left unchanged.
+            apply_instruction_to_passages: Whether to apply the instruction template to the passages.
+            **kwargs: Additional keyword arguments passed to the SentenceTransformer constructor.
+        """
+        if (
+            isinstance(instruction_template, str)
+            and "{instruction}" not in instruction_template
+        ):
+            raise ValueError(
+                "Instruction template must contain the string '{instruction}'."
+            )
+        if instruction_template is None:
+            logger.warning(
+                "No instruction template provided. Instructions will be used as-is."
+            )
+
+        self.model_name = model_name
+        self.model = SentenceTransformer(model_name, revision=revision, **kwargs)
+        self.instruction_template = instruction_template
+        self.apply_instruction_to_passages = apply_instruction_to_passages
+        if max_seq_length is not None:
+            self.model.max_seq_length = max_seq_length
+
+    def encode(
+        self,
+        sentences: Sequence[str],
+        *,
+        task_name: str,
+        prompt_type: PromptType | None = None,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        instruction = self.get_task_instruction(task_name, prompt_type)
+
+        # optionally skip applying the instruction to passages
+        if not self.apply_instruction_to_passages and prompt_type == PromptType.passage:
+            instruction = None
+            logger.info(
+                f"No instruction used, because prompt type = {prompt_type}"
+            )
+
+        if instruction:
+            logger.info(f"Using instruction: '{instruction}' for task: '{task_name}'")
+        embeddings = self.model.encode(
+            sentences,
+            prompt=instruction,
+            **kwargs,
+        )
+
+        if isinstance(embeddings, torch.Tensor):
+            # kwargs may request tensor output (e.g. convert_to_tensor=True)
+            embeddings = embeddings.cpu().detach().float().numpy()
+        return embeddings
diff --git a/mteb/models/jasper_models.py b/mteb/models/jasper_models.py
index 60fa4f697..34f09e550 100644
--- a/mteb/models/jasper_models.py
+++ b/mteb/models/jasper_models.py
@@ -1,68 +1,20 @@
 from __future__ import annotations
 
 import logging
-from collections.abc import Sequence
 from functools import partial
-from typing import Any, Callable
 
-import numpy as np
 import torch
-from sentence_transformers import SentenceTransformer
 
-import mteb
-from mteb.encoder_interface import PromptType
 from mteb.model_meta import ModelMeta
 
-from .wrapper import Wrapper
+from .instruct_wrapper import InstructSentenceTransformerWrapper
 
 logger = logging.getLogger(__name__)
 
 
-class JasperWrapper(Wrapper):
-    def __init__(
-        self,
-        model_name: str,
-        revision: str,
-        instruction_template: str | Callable[[str], str] | None = None,
-        max_seq_length: int = 2048,
-        **kwargs: Any,
-    ):
-        self.model_name = model_name
-        self.model = SentenceTransformer(model_name, revision=revision, **kwargs)
-        self.instruction_template = instruction_template
-        self.model.max_seq_length = max_seq_length
-
-    def encode(
-        self,
-        sentences: Sequence[str],
-        *,
-        task_name: str,
-        prompt_type: PromptType | None = None,
-        **kwargs: Any,
-    ) -> np.ndarray:
-        task = mteb.get_task(task_name=task_name)
-        instruction = self.get_task_instruction(task_name, prompt_type)
-
-        # to passage prompts won't be applied to passages
-        if prompt_type == PromptType.passage and task.metadata.type == "s2p":
-            instruction = None
-
-        embeddings = self.model.encode(
-            sentences,
-            normalize_embeddings=True,
-            prompt=instruction,
-            **kwargs,
-        )
-
-        if isinstance(embeddings, torch.Tensor):
-            # sometimes in kwargs can be return_tensors=True
-            embeddings = embeddings.cpu().detach().float().numpy()
-        return embeddings
-
-
 jasper_en_v1 = ModelMeta(
     loader=partial(  # type: ignore
-        JasperWrapper,
+        InstructSentenceTransformerWrapper,
         model_name="infgrad/jasper_en_vision_language_v1",
         revision="d6330ce98f8a0d741e781df845904c9484f00efa",
         config_kwargs={"is_text_encoder": True, "vector_dim": 12288},
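
For reference, a minimal usage sketch of the refactored loader; the task name is illustrative, and mteb.get_model, mteb.get_tasks, and mteb.MTEB are assumed from the library's public API:

    import mteb

    # Resolves jasper_en_v1 through the InstructSentenceTransformerWrapper-based loader above.
    model = mteb.get_model("infgrad/jasper_en_vision_language_v1")

    # Evaluate on an illustrative retrieval task; any registered MTEB task name works here.
    tasks = mteb.get_tasks(tasks=["NFCorpus"])
    results = mteb.MTEB(tasks=tasks).run(model)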