From 7cd7409142ff97aee1a13568753db9263fcf8f6b Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Fri, 13 Dec 2024 09:40:07 +0200
Subject: [PATCH] PaliGemma 2 support (#11142)

---
 docs/source/models/supported_models.rst       |  4 ++--
 examples/offline_inference_vision_language.py | 13 +++++++++++++
 vllm/model_executor/models/paligemma.py       | 11 ++++++++++-
 3 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 35aa3bfdd12b7..cae4a88de1638 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -664,9 +664,9 @@ Text Generation (``--task generate``)
     - ✅︎
     - ✅︎
   * - :code:`PaliGemmaForConditionalGeneration`
-    - PaliGemma
+    - PaliGemma, PaliGemma 2
     - T + I\ :sup:`E`
-    - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
+    - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, :code:`google/paligemma2-3b-ft-docci-448`, etc.
     -
     - ✅︎
     -
diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py
index 5e210126dc8fe..c430f42fdc814 100644
--- a/examples/offline_inference_vision_language.py
+++ b/examples/offline_inference_vision_language.py
@@ -137,6 +137,18 @@ def run_paligemma(question: str, modality: str):
     return llm, prompt, stop_token_ids
 
 
+# PaliGemma 2
+def run_paligemma2(question: str, modality: str):
+    assert modality == "image"
+
+    # PaliGemma 2 has special prompt format for VQA
+    prompt = "caption en"
+    llm = LLM(model="google/paligemma2-3b-ft-docci-448",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 # Chameleon
 def run_chameleon(question: str, modality: str):
     assert modality == "image"
@@ -473,6 +485,7 @@ def run_mantis(question: str, modality: str):
     "fuyu": run_fuyu,
     "phi3_v": run_phi3v,
     "paligemma": run_paligemma,
+    "paligemma2": run_paligemma2,
     "chameleon": run_chameleon,
     "minicpmv": run_minicpmv,
     "blip-2": run_blip2,
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 253e689e50a3b..f9ad0c67adaba 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -105,6 +105,11 @@ def input_processor_for_paligemma(ctx: InputContext,
             orig_prompt_ids.remove(hf_config.image_token_index)
 
     new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
+
+    # The PaliGemma 2 tokenizer does not include a starting BOS token
+    if orig_prompt_ids[0] != hf_config.bos_token_id:
+        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
+
     new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  #newline
 
     # NOTE: Create a defensive copy of the original inputs
@@ -149,7 +154,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             projection_dim=config.vision_config.projection_dim)
 
         self.quant_config = quant_config
-        config.text_config.architectures = ["GemmaForCausalLM"]
+
+        if config.text_config.model_type == "gemma":
+            config.text_config.architectures = ["GemmaForCausalLM"]
+        else:
+            config.text_config.architectures = ["Gemma2ForCausalLM"]
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.text_config,