Add support for Phi-3-Vision auto-captioning model
jhc13 committed Jul 28, 2024
1 parent 7899e96 commit 9eae099
Showing 3 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -13,7 +13,7 @@ torch==2.2.2; platform_system != "Windows"
 https://download.pytorch.org/whl/cu121/torch-2.2.2%2Bcu121-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
 https://download.pytorch.org/whl/cu121/torch-2.2.2%2Bcu121-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
 
-# FlashAttention (Florence-2)
+# FlashAttention (Florence-2, Phi-3-Vision)
 flash-attn==2.6.3; platform_system == "Linux"
 https://github.com/oobabooga/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu122torch2.2.2cxx11abiFALSE-cp311-cp311-win_amd64.whl; platform_system == "Windows" and python_version == "3.11"
 https://github.com/oobabooga/flash-attention/releases/download/v2.6.1/flash_attn-2.6.1+cu122torch2.2.2cxx11abiFALSE-cp310-cp310-win_amd64.whl; platform_system == "Windows" and python_version == "3.10"
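
The trailing "; platform_system == ..." clauses are PEP 508 environment markers, so one requirements file serves every platform: on Linux pip installs flash-attn 2.6.3 from PyPI, while on Windows it fetches a prebuilt wheel (pinned at 2.6.1, presumably the newest Windows build available) matching the Python version. Markers can be evaluated directly with the packaging library, which pip itself vendors; a quick illustration:

from packaging.markers import Marker

# Evaluates against the current interpreter and OS, just as pip does
# at install time.
marker = Marker('platform_system == "Windows" and python_version == "3.11"')
print(marker.evaluate())  # True only on Windows under Python 3.11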
42 changes: 42 additions & 0 deletions taggui/auto_captioning/models/phi_3_vision.py
@@ -0,0 +1,42 @@
import torch
from transformers import AutoModelForCausalLM, BatchFeature

import auto_captioning.captioning_thread as captioning_thread
from auto_captioning.auto_captioning_model import AutoCaptioningModel
from utils.image import Image


class Phi3Vision(AutoCaptioningModel):
    transformers_model_class = AutoModelForCausalLM

    def __init__(self,
                 captioning_thread_: 'captioning_thread.CaptioningThread',
                 caption_settings: dict):
        super().__init__(captioning_thread_, caption_settings)
        # Number of prompt tokens, recorded so the prompt can be stripped
        # from the generated sequence.
        self.input_length = None

    @staticmethod
    def get_default_prompt() -> str:
        return 'Describe the image in one sentence.'

    @staticmethod
    def format_prompt(prompt: str) -> str:
        # Phi-3-Vision chat template: the processor replaces the
        # <|image_1|> placeholder with the image embeddings.
        return f'<|user|>\n<|image_1|>\n{prompt}<|end|>\n<|assistant|>\n'

    def get_input_text(self, image_prompt: str) -> str:
        # Append any user-specified caption start to the formatted prompt.
        return image_prompt + self.caption_start

    def get_model_inputs(self, image_prompt: str,
                         image: Image) -> BatchFeature:
        model_inputs = super().get_model_inputs(image_prompt, image)
        self.input_length = model_inputs['input_ids'].shape[1]
        return model_inputs

    def get_additional_generation_parameters(self) -> dict:
        # Stop generation at the end-of-sequence token instead of running
        # to the token limit.
        return {'eos_token_id': self.tokenizer.eos_token_id}

    def get_caption_from_generated_tokens(
            self, generated_token_ids: torch.Tensor,
            image_prompt: str) -> str:
        # The model echoes the prompt before the caption; keep only the
        # newly generated tokens.
        generated_token_ids = generated_token_ids[:, self.input_length:]
        return super().get_caption_from_generated_tokens(
            generated_token_ids, image_prompt)
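
For context, the class above delegates loading and generation to AutoCaptioningModel. A minimal standalone sketch of the equivalent Hugging Face calls, following the Phi-3-Vision model card (the image path and generation settings here are illustrative placeholders, not taggui's actual configuration):

from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = 'microsoft/Phi-3-vision-128k-instruct'
# flash_attention_2 is why flash-attn was added to requirements.txt above.
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map='cuda', torch_dtype='auto',
    trust_remote_code=True, _attn_implementation='flash_attention_2')
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Same chat template as format_prompt() above.
prompt = ('<|user|>\n<|image_1|>\nDescribe the image in one sentence.'
          '<|end|>\n<|assistant|>\n')
image = Image.open('photo.jpg')  # placeholder path
inputs = processor(prompt, [image], return_tensors='pt').to('cuda')

generated_ids = model.generate(
    **inputs, max_new_tokens=100,
    eos_token_id=processor.tokenizer.eos_token_id)
# Strip the prompt tokens, as get_caption_from_generated_tokens() does.
caption = processor.batch_decode(
    generated_ids[:, inputs['input_ids'].shape[1]:],
    skip_special_tokens=True)[0]
print(caption)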
4 changes: 4 additions & 0 deletions taggui/auto_captioning/models_list.py
@@ -8,6 +8,7 @@
 from auto_captioning.models.llava_next import (LlavaNext34b, LlavaNextMistral,
                                                LlavaNextVicuna)
 from auto_captioning.models.moondream import Moondream1, Moondream2
+from auto_captioning.models.phi_3_vision import Phi3Vision
 from auto_captioning.models.wd_tagger import WdTagger
 from auto_captioning.models.xcomposer2 import Xcomposer2, Xcomposer2_4khd

@@ -25,6 +26,7 @@
     'microsoft/Florence-2-base-ft',
     'microsoft/Florence-2-base',
     'MiaoshouAI/Florence-2-base-PromptGen',
+    'microsoft/Phi-3-vision-128k-instruct',
     'llava-hf/llava-v1.6-mistral-7b-hf',
     'llava-hf/llava-v1.6-vicuna-7b-hf',
     'llava-hf/llava-v1.6-vicuna-13b-hf',
@@ -85,6 +87,8 @@ def get_model_class(model_id: str) -> type[AutoCaptioningModel]:
         return Moondream1
     if 'moondream2' in lowercase_model_id:
         return Moondream2
+    if 'phi-3' in lowercase_model_id:
+        return Phi3Vision
     if 'wd' in lowercase_model_id and 'tagger' in lowercase_model_id:
         return WdTagger
     if 'xcomposer2' in lowercase_model_id:
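
get_model_class resolves a model class by substring matching on the lowercased model id, so any repository id containing 'phi-3' is captioned with Phi3Vision. A quick illustration, with imports following this repository's layout:

from auto_captioning.models.phi_3_vision import Phi3Vision
from auto_captioning.models_list import get_model_class

# The added branch matches 'phi-3' in the lowercased id.
assert get_model_class('microsoft/Phi-3-vision-128k-instruct') is Phi3Vision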
