[Model]: add some tests for aria model (vllm-project#10770)
Signed-off-by: xffxff <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
xffxff and Isotr0py authored Dec 2, 2024
1 parent 995a148 commit ef31eab
Showing 4 changed files with 51 additions and 3 deletions.
6 changes: 5 additions & 1 deletion tests/conftest.py
@@ -656,6 +656,7 @@ def __init__(
model_name: str,
task: TaskOption = "auto",
tokenizer_name: Optional[str] = None,
tokenizer_mode: str = "auto",
# Use smaller max model length, otherwise bigger model cannot run due
# to kv cache size limit.
max_model_len: int = 1024,
@@ -672,6 +673,7 @@ def __init__(
model=model_name,
task=task,
tokenizer=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=True,
dtype=dtype,
swap_space=swap_space,
@@ -842,14 +844,16 @@ def generate_greedy_logprobs(
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
stop: Optional[List[str]] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=num_prompt_logprobs,
stop_token_ids=stop_token_ids)
stop_token_ids=stop_token_ids,
stop=stop)

return self.generate_w_logprobs(prompts,
greedy_logprobs_params,
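For context, a minimal sketch (not part of this commit) of how the new `stop` argument of generate_greedy_logprobs maps onto vLLM's SamplingParams alongside the existing `stop_token_ids`; the values below are purely illustrative.

# Sketch only: the new `stop` kwarg becomes SamplingParams(stop=...),
# complementing stop_token_ids for markers that are not single tokens.
from vllm import SamplingParams

greedy_params = SamplingParams(
    temperature=0.0,        # greedy decoding, as in generate_greedy_logprobs
    max_tokens=64,
    logprobs=5,
    stop_token_ids=[2],     # stop on specific token IDs (e.g. an EOS id)
    stop=["<|im_end|>"],    # stop on literal strings, useful when the stop
                            # marker is not a special token in the tokenizer
)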
30 changes: 30 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -8,6 +8,7 @@
import pytest
import transformers
from transformers import AutoModelForVision2Seq
from transformers.utils import is_flash_attn_2_available

from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, identity
@@ -134,6 +135,35 @@
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
tokenizer_mode="slow",
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
),
dtype="bfloat16",
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
max_model_len=4096,
max_num_seqs=2,
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<vlm_image>Please describe the image shortly.",
"cherry_blossom": "<vlm_image>Please infer the season with reason.",
}),
multi_image_prompt="<vlm_image><vlm_image>Describe the two images shortly.", # noqa: E501
postprocess_inputs=model_utils.get_key_type_post_processor("pixel_values"),
stop_str=["<|im_end|>"],
image_size_factors=[(0.10, 0.15)],
max_tokens=64,
marks=[
pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="Model needs flash-attn for numeric convergence.",
),
large_gpu_mark(min_gb=64),
],
),
"blip2": VLMTestInfo(
models=["Salesforce/blip2-opt-2.7b"],
test_type=VLMTestType.IMAGE,
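To make the aria prompt settings above concrete, here is an illustrative composition for the "stop_sign" single-image case (not part of this commit), assuming the test harness substitutes each <vlm_image> placeholder with img_idx_to_prompt(idx) and then applies prompt_formatter.

# Illustration only: composing the aria prompt from the pieces defined above.
prompt_formatter = lambda img_prompt: (
    f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ")
img_idx_to_prompt = lambda idx: "<fim_prefix><|img|><fim_suffix>\n"

question = "Please describe the image shortly."
full_prompt = prompt_formatter(img_idx_to_prompt(0) + question)
print(full_prompt)
# <|im_start|>user
# <fim_prefix><|img|><fim_suffix>
# Please describe the image shortly.<|im_end|>
# <|im_start|>assistant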
11 changes: 9 additions & 2 deletions tests/models/decoder_only/vision_language/vlm_utils/core.py
@@ -29,6 +29,8 @@ def run_test(
postprocess_inputs: Callable[[BatchEncoding], BatchEncoding],
comparator: Callable[..., None],
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]],
stop_str: Optional[List[str]],
tokenizer_mode: str,
limit_mm_per_prompt: Dict[str, int],
model_kwargs: Optional[Dict[str, Any]],
patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]],
@@ -50,11 +52,14 @@
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
vllm_kwargs = {}
vllm_kwargs: Dict[str, Any] = {}
if get_stop_token_ids is not None:
vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer)
if stop_str:
vllm_kwargs["stop"] = stop_str

with vllm_runner(model,
tokenizer_mode=tokenizer_mode,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
dtype=dtype,
@@ -85,6 +90,8 @@ def run_test(
hf_kwargs = {}
if use_tokenizer_eos:
hf_kwargs["eos_token_id"] = tokenizer.eos_token_id
if stop_str:
hf_kwargs["stop_strings"] = stop_str

with hf_model, torch.no_grad():
for prompts, media in inputs:
@@ -138,4 +145,4 @@ def process_runner_outputs(
def process_outputs(output_processor, model, outputs_per_image):
"""Applies a model specific post-processor function to a runner's output"""
return [[output_processor(res, model) for res in outputs]
for outputs in outputs_per_image]
for outputs in outputs_per_image]
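On the HF side, run_test forwards stop_str as `stop_strings`, which matches literal text during generation rather than token IDs. A rough sketch of that path (not part of this commit; it uses a small placeholder model, and recent transformers versions require passing the tokenizer to generate() when stop_strings is set):

# Rough sketch only: string-based stopping in HF generate(), which is why a
# plain-text stop such as "<|im_end|>" works even when it is not a special
# token in the model's tokenizer.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=32,
    stop_strings=["."],  # stop at the first literal "."
    tokenizer=tok,       # generate() needs the tokenizer to match stop strings
)
print(tok.decode(out[0]))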
7 changes: 7 additions & 0 deletions tests/models/decoder_only/vision_language/vlm_utils/types.py
@@ -97,6 +97,9 @@ class VLMTestInfo(NamedTuple):

# Optional callable which gets a list of token IDs from the model tokenizer
get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None
# Optional list of strings to stop generation, useful when stop tokens are
# not special tokens in the tokenizer
stop_str: Optional[List[str]] = None

# Exposed options for HF runner
model_kwargs: Optional[Dict[str, Any]] = None
@@ -148,6 +151,8 @@ class VLMTestInfo(NamedTuple):

marks: Optional[List[MarkDecorator]] = None

tokenizer_mode: str = "auto"

def get_non_parametrized_runner_kwargs(self):
"""Returns a dictionary of expandable kwargs for items that are used
in all test types, which are NOT used when creating the parametrized
@@ -166,8 +171,10 @@ def get_non_parametrized_runner_kwargs(self):
"postprocess_inputs": self.postprocess_inputs,
"comparator": self.comparator,
"get_stop_token_ids": self.get_stop_token_ids,
"stop_str": self.stop_str,
"model_kwargs": self.model_kwargs,
"patch_hf_runner": self.patch_hf_runner,
"tokenizer_mode": self.tokenizer_mode
}


