
[Model] Support Qwen2 embeddings and use tags to select model tests #10184

Merged
20 commits merged on Nov 15, 2024

Conversation

@DarkLight1337 (Member) commented Nov 9, 2024

A newer version of #5611 and #6282, since the source repo has been archived.

FIX #5600
FIX #5827
FIX #6015
FIX #9761

Signed-off-by: DarkLight1337 <[email protected]>

github-actions bot commented Nov 9, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the documentation label Nov 9, 2024
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
tests/models/embedding/language/test_embedding.py (outdated review thread, resolved)
Comment on lines +457 to +460
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
Member

Will this be able to be controlled by the pooling args we spoke about offline?

Member Author

Yes - these are the model's default values which can be overridden.
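For illustration, a minimal sketch of such an override from Python, assuming the override_pooler_config argument and PoolerConfig fields end up being exposed under these names (they may differ in the final API):

from vllm import LLM
from vllm.config import PoolerConfig

# Hypothetical override: use mean pooling without normalization instead of
# the model's default LAST pooling + normalize.
llm = LLM(
    model="ssmits/Qwen2-7B-Instruct-embed-base",
    task="embedding",
    override_pooler_config=PoolerConfig(pooling_type="MEAN", normalize=False),
)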

@DarkLight1337 (Member Author) commented Nov 9, 2024

I still need to update the tests since these models don't have sentence-transformers equivalents.

Never mind, I was misled by the warning message. It's intended that ssmits/Qwen2-7B-Instruct-embed-base uses mean pooling instead of last pooling.

Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337 DarkLight1337 added the ready label Nov 9, 2024
Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337 (Member Author)

I have fixed the tests.

Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>

mergify bot commented Nov 11, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @DarkLight1337.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 11, 2024
Signed-off-by: DarkLight1337 <[email protected]>
@mergify mergify bot removed the needs-rebase label Nov 11, 2024
@DarkLight1337 (Member Author)

Any update on this?

@mergify mergify bot added the ci/build label Nov 13, 2024
Signed-off-by: DarkLight1337 <[email protected]>
@DarkLight1337 (Member Author) commented Nov 13, 2024

I have updated the tests so that the Qwen2 embedding models are only tested in the nightly run. I have already tested them locally and confirmed that they pass.
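With the marker-based selection from this PR, a rough sketch of how the subsets can be run locally (assuming the core_model marker used in the test file; the exact invocations may differ from the CI setup):

# Core models only (the small subset run on every PR)
pytest tests/models/embedding/language/test_embedding.py -m core_model

# Everything, including the Qwen2 embedding models reserved for nightly
pytest tests/models/embedding/language/test_embedding.py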

@Ecocytus

Is there any example usage of the Qwen2 embeddings? The embeddings from HF and vLLM do not match when following the official usage doc.
Code to reproduce:

from typing import List, Sequence
import torch
import torch.nn.functional as F
from vllm import LLM
from sentence_transformers import SentenceTransformer

def check_embeddings_close(
    embeddings_0_lst: Sequence[List[float]],
    embeddings_1_lst: Sequence[List[float]],
    name_0: str,
    name_1: str,
    tol: float = 1e-2,
) -> None:
    assert len(embeddings_0_lst) == len(embeddings_1_lst)

    for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
            zip(embeddings_0_lst, embeddings_1_lst)):
        assert len(embeddings_0) == len(embeddings_1), (
            f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")

        sim = F.cosine_similarity(torch.tensor(embeddings_0),
                                torch.tensor(embeddings_1),
                                dim=0)

        fail_msg = (f"Test{prompt_idx}: {sim}")

        assert sim >= 1 - tol, fail_msg

model_name = 'Alibaba-NLP/gte-Qwen2-1.5B-instruct'
prefix_queries = ["hello world"]

model = LLM(model=model_name, task="embedding", dtype="float32", max_model_len=None)
outputs = model.encode(prefix_queries, use_tqdm=False)
query_embeddings1 = [output.outputs.embedding for output in outputs]

sen_model = SentenceTransformer(model_name, trust_remote_code=True, model_kwargs={
            "torch_dtype": "float32",
        }).cuda()

query_embeddings2 = sen_model.encode(prefix_queries, prompt='', show_progress_bar = False)

check_embeddings_close(query_embeddings1, query_embeddings2, "vllm", "sbert")

@DarkLight1337 (Member Author) commented Nov 22, 2024

Is there any example usage of the Qwen2 embeddings? The embeddings from HF and vLLM do not match when following the official usage doc.

After some debugging, I found the problem - you need to set trust_remote_code=True when loading the model in vLLM.
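Concretely, the only change needed in the reproduction script above is on the LLM constructor:

model = LLM(model=model_name, task="embedding", dtype="float32",
            max_model_len=None, trust_remote_code=True)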

@DarkLight1337 (Member Author) commented Nov 22, 2024

Hmm, this is really odd. I found that the EOS/pad token fails to be added to the prompt when trust_remote_code=False, which results from an incorrect padding_side being set in the tokenizer. This can be minimally reproduced directly with transformers. Just run:

>>> from transformers import AutoTokenizer
>>> AutoTokenizer.from_pretrained("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=False)
Qwen2TokenizerFast(name_or_path='Alibaba-NLP/gte-Qwen2-1.5B-instruct', vocab_size=151643, model_max_length=32768, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|endoftext|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
        151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
>>> AutoTokenizer.from_pretrained("Alibaba-NLP/gte-Qwen2-1.5B-instruct", trust_remote_code=True)
Qwen2TokenizerFast(name_or_path='Alibaba-NLP/gte-Qwen2-1.5B-instruct', vocab_size=151643, model_max_length=32768, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'eos_token': '<|endoftext|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
        151643: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        151644: AddedToken("<|im_start|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
        151645: AddedToken("<|im_end|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

@DarkLight1337 (Member Author)

Opened huggingface/transformers#34882

tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
@Diralpo commented Nov 23, 2024

Thanks for your excellent work! Does this PR support the bidirectional attention mechanism in gte-Qwen2? I noticed that even when manually adding EOS to the input prompt, vLLM produces different embeddings than directly loading the original gte model.

@DarkLight1337 (Member Author)

I got the test script by @Ecocytus to pass simply by adding trust_remote_code=True when initializing the vLLM model. Is that not the case for you?

@Diralpo commented Nov 23, 2024

I got the test script by @Ecocytus to pass simply by adding trust_remote_code=True when initializing the vLLM model. Is that not the case for you?

I replaced the model_name in @Ecocytus's script with my own path downloaded from Hugging Face. I found that sim = 1 when I test gte-Qwen2-1.5B; however, it is 0.8498 when I test gte-Qwen2-7B.

@DarkLight1337 (Member Author) commented Nov 23, 2024

From my understanding, vLLM does support a bidirectional attention mask (please correct me if I'm wrong @mgoin). However, since we reuse the same model architecture as regular Qwen2 (Qwen2ForCausalLM), bidirectional attention is not enabled even for the embedding task. We cannot enable it automatically because some embedding models might not use bidirectional attention. Perhaps we need to make this configurable through the CLI?

@Diralpo commented Nov 23, 2024

Gotcha! And I wonder whether bidirectional attention is enabled when loading the gte-Qwen2-7B model, as the authors mentioned:

Integration of bidirectional attention mechanisms, enriching its contextual understanding.

@DarkLight1337 (Member Author) commented Nov 23, 2024

I just checked: Qwen2ForCausalLM uses the decoder attention type by default, which is causal attention. I think we can solve this by changing it to an encoder-only attention mask when the task is embedding.

@DarkLight1337 (Member Author) commented Nov 23, 2024

I just found that ssmits/Qwen2-7B-Instruct-embed-base uses the same architecture but doesn't use bidirectional attention. So we still need to set it through the CLI.

@Diralpo commented Nov 23, 2024

I am a complete vLLM novice. Can I modify a few lines of code to force the loaded model to be encoder-only? I want to confirm whether the attention mask causes this embedding diff.

@DarkLight1337 (Member Author)

Actually, I am a bit suspicious about whether the attention mask is the real issue, since the 1.5B model is also supposed to use a bidirectional attention mask yet works correctly with our decoder attention mask.

@DarkLight1337 (Member Author)

You can try this patch

diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py
index c3f351ef..25cdfc81 100644
--- a/tests/models/embedding/language/test_embedding.py
+++ b/tests/models/embedding/language/test_embedding.py
@@ -19,8 +19,9 @@ from ..utils import check_embeddings_close
                      marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
         pytest.param("BAAI/bge-multilingual-gemma2",
                      marks=[pytest.mark.core_model]),
-        pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
-        pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
+        # pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
+        # pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
+        pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
     ],
 )
 @pytest.mark.parametrize("dtype", ["half"])
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 370cff5f..844e93a6 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -27,7 +27,7 @@ import torch
 from torch import nn
 from transformers import Qwen2Config
 
-from vllm.attention import Attention, AttentionMetadata
+from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -164,11 +164,13 @@ class Qwen2Attention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        attn_type: str = AttentionType.DECODER,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
+                                attn_type=attn_type)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -216,7 +218,8 @@ class Qwen2DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        residual: Optional[torch.Tensor],
+        attn_type: str = AttentionType.DECODER,
+        residual: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         if residual is None:
@@ -230,6 +233,7 @@ class Qwen2DecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
+            attn_type=attn_type,
         )
 
         # Fully Connected
@@ -292,6 +296,12 @@ class Qwen2Model(nn.Module):
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
             self.norm = PPMissingLayer()
+        
+        self._attn_type = {
+            "generate": AttentionType.DECODER,
+            "embedding": AttentionType.ENCODER_ONLY,
+            "draft": AttentionType.DECODER,
+        }[vllm_config.model_config.task]
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -322,7 +332,8 @@ class Qwen2Model(nn.Module):
                 hidden_states,
                 kv_caches[i - self.start_layer],
                 attn_metadata,
-                residual,
+                attn_type=self._attn_type,
+                residual=residual,
             )
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({

and then run the unit test

pytest tests/models/embedding/language/test_embedding.py -k gte-Qwen2-7B-instruct -x

@Diralpo commented Nov 23, 2024

You can try this patch […] and then run the unit test

For convenience, I didn't run the pytest command, but I changed qwen2.py as mentioned and ran the debug code. Now I get sim = 1.

@DarkLight1337 (Member Author) commented Nov 23, 2024

Ok so the 1.5B model wasn't trained using bidirectional mask as advertised, only the 7B model 🤔

In any case, @mgoin how should we make the attention method configurable?

@DarkLight1337 (Member Author)

Ok so the 1.5B model wasn't trained using bidirectional mask as advertised, only the 7B model 🤔

In any case, @mgoin how should we make the attention method configurable?

Also cc @youkaichao. Do you think adding this to model config is acceptable?

@DarkLight1337 (Member Author) commented Nov 25, 2024

We need to support all of these in Qwen2Model:

  • Qwen/Qwen2-7B-Instruct (Qwen2ForCausalLM architecture, generation task, decoder-only attention)
  • Alibaba-NLP/gte-Qwen2-1.5B-instruct (Qwen2ForCausalLM architecture, embedding task, decoder-only attention)
  • Alibaba-NLP/gte-Qwen2-7B-instruct (Qwen2ForCausalLM architecture, embedding task, bidirectional encoder-only attention)
  • ssmits/Qwen2-7B-Instruct-embed-base (Qwen2Model architecture, embedding task, decoder-only attention)
  • jason9693/Qwen2.5-1.5B-apeach (Qwen2ForSequenceClassification architecture, embedding task, decoder-only attention)

The problem now is the attention mask, which can differ even within the same task, so we can hardly set it automatically.

@youkaichao (Member)

My opinion: if it only occurs for Qwen models, we can have an env var for it. If we find that more models need it, we can add it to the CLI args.

@DarkLight1337 (Member Author) commented Nov 25, 2024

OK, I think I found a better way. We can use hf_overrides to add an is_decoder flag, maintaining similar semantics to HF models. If is_decoder=True (assumed as the default for Qwen2), we use the decoder-only mask; otherwise we use the bidirectional (encoder-only) mask.
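From the user side this would look roughly like the sketch below; the is_decoder key follows the naming above and is an assumption until the change lands (the final flag name may differ):

from vllm import LLM

# Hypothetical: mark the checkpoint as non-decoder so the embedding task
# uses the bidirectional (encoder-only) mask.
llm = LLM(
    model="Alibaba-NLP/gte-Qwen2-7B-instruct",
    task="embedding",
    trust_remote_code=True,
    hf_overrides={"is_decoder": False},
)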

@Ecocytus

Is there any example usage of the Qwen2 embeddings? The embeddings from HF and vLLM do not match when following the official usage doc.

After some debugging, I found the problem - you need to set trust_remote_code=True when loading the model in vLLM.

This works for me. Thanks a lot!

@mgoin (Member) commented Nov 25, 2024

Using hf_overrides sounds like the best path, nice find!

@AnthonyX1an

If I load the model Alibaba-NLP/gte-Qwen2-1.5B-instruct locally, can I still set trust_remote_code=True?

@DarkLight1337 (Member Author)

If I load the model Alibaba-NLP/gte-Qwen2-1.5B-instruct locally, can I still set trust_remote_code=True?

Yes.

@AnthonyX1an

Hello! I am a beginner in using LLMs, and I have a question. If I want to obtain the output of the last hidden layer of the qwen2-1.5B-instruct model as an embedding, can I use LLM.encode()? Thanks for your patience!

@DarkLight1337 (Member Author) commented Dec 25, 2024

Hello! I am a beginner in using LLMs, and I have a question. If I want to obtain the output of the last hidden layer of the qwen2-1.5B-instruct model as an embedding, can I use LLM.encode()? Thanks for your patience!

Yes, but you have to set --task embed --override-pooler-config '{"pooling_type": "ALL", "normalize": false}' (or something like that; I wrote this off the top of my head) to return all of the hidden states.
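A rough Python equivalent of that suggestion, with the same caveat that the exact names may differ:

from vllm import LLM
from vllm.config import PoolerConfig

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",
    task="embed",
    override_pooler_config=PoolerConfig(pooling_type="ALL", normalize=False),
)
# With ALL pooling, each request returns one last-layer hidden-state vector
# per input token instead of a single pooled embedding.
outputs = llm.encode(["hello world"])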
