From 9611f9d62c24cfdb674071e3238ba8f4bd9837db Mon Sep 17 00:00:00 2001 From: aria-hacker Date: Mon, 30 Sep 2024 09:48:38 +0800 Subject: [PATCH 1/2] Add support for split image, and rename image_max_size to max_image_size --- aria/config.py | 4 + aria/data.py | 5 +- aria/inference.py | 22 ++++- aria/model/processing_aria.py | 35 +++++--- aria/model/vision_processor.py | 125 ++++++++++++++++++++++++++-- aria/train.py | 14 ++-- aria/vllm/aria.py | 132 ++++++++++++++++++++++++++---- examples/nextqa/evaluation.py | 2 +- examples/nlvr2/evaluation.py | 2 +- examples/refcoco/evaluation.py | 2 +- tests/test_apply_chat_template.py | 16 +++- 11 files changed, 311 insertions(+), 48 deletions(-) diff --git a/aria/config.py b/aria/config.py index 1e59a5a..4c0a141 100644 --- a/aria/config.py +++ b/aria/config.py @@ -65,6 +65,10 @@ class AriaModelConfig(ModelConfig): "choices": [490, 980], }, ) + split_image: bool = field( + default=False, + metadata={"help": "Whether to split the image into smaller patches."}, + ) def __post_init__(self): super().__post_init__() diff --git a/aria/data.py b/aria/data.py index 05ecd5f..1b7459f 100644 --- a/aria/data.py +++ b/aria/data.py @@ -19,7 +19,7 @@ import os import warnings -from typing import Dict, List +from typing import Dict, Iterable, List import torch from datasets import DatasetDict, concatenate_datasets, load_dataset @@ -29,6 +29,7 @@ def apply_chat_template_and_tokenize( messages_batch: List[List[Dict]], tokenizer, + num_image_crop: Iterable[torch.Tensor] = iter([]), ): IGNORE_TOKEN_ID = -100 im_start_tokens = tokenizer("<|im_start|>").input_ids @@ -41,7 +42,7 @@ def process_content(content): if content["type"] == "text": return content["text"] elif content["type"] == "image": - return "<|img|>" + return "" + "<|img|>" * next(num_image_crop) + "" else: raise ValueError(f"Unknown content type {content['type']} in message") diff --git a/aria/inference.py b/aria/inference.py index 419af54..24ca67c 100644 --- a/aria/inference.py +++ b/aria/inference.py @@ -42,6 +42,13 @@ def parse_arguments(): help="Maximum size of the image to be processed", default=980, ) + parser.add_argument( + "--split_image", + type=bool, + help="Whether to split the image into patches", + action="store_true", + default=False, + ) return parser.parse_args() @@ -65,7 +72,9 @@ def load_model(base_model_path, peft_model_path=None): return model -def prepare_input(image_path, prompt, processor: AriaProcessor, max_image_size): +def prepare_input( + image_path, prompt, processor: AriaProcessor, max_image_size, split_image +): image = Image.open(image_path) messages = [ @@ -85,6 +94,7 @@ def prepare_input(image_path, prompt, processor: AriaProcessor, max_image_size): images=image, return_tensors="pt", max_image_size=max_image_size, + split_image=split_image, ) return inputs @@ -96,8 +106,9 @@ def inference( model: AriaForConditionalGeneration, processor: AriaProcessor, max_image_size, + split_image, ): - inputs = prepare_input(image_path, prompt, processor, max_image_size) + inputs = prepare_input(image_path, prompt, processor, max_image_size, split_image) inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype) inputs = {k: v.to(model.device) for k, v in inputs.items()} @@ -124,7 +135,12 @@ def main(): model = load_model(args.base_model_path, args.peft_model_path) result = inference( - args.image_path, args.prompt, model, processor, args.max_image_size + args.image_path, + args.prompt, + model, + processor, + args.max_image_size, + args.split_image, ) print(result) diff --git 
a/aria/model/processing_aria.py b/aria/model/processing_aria.py index 9c4922d..f02f08c 100644 --- a/aria/model/processing_aria.py +++ b/aria/model/processing_aria.py @@ -18,6 +18,7 @@ # under the License. import inspect +import re from typing import List, Optional, Union from transformers import AutoTokenizer, BatchFeature @@ -61,7 +62,7 @@ def __init__( super().__init__(chat_template=chat_template) if image_processor is None: - self.image_processor = AriaVisionProcessor(image_max_size=patch_size) + self.image_processor = AriaVisionProcessor(max_image_size=patch_size) else: self.image_processor = image_processor @@ -87,6 +88,7 @@ def __call__( truncation: Union[bool, str, TruncationStrategy] = None, max_length: Optional[int] = None, max_image_size: Optional[int] = 980, + split_image: Optional[bool] = False, return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, ) -> BatchFeature: """ @@ -114,6 +116,8 @@ def __call__( Maximum length of the returned list and optionally padding length (see above). max_image_size (`int`, *optional*): Maximum size of the image to be processed. + split_image (`bool`, *optional*): + Whether to split the image into patches before processing. truncation (`bool`, *optional*): Activates truncation to cut input sequences longer than `max_length` to `max_length`. return_tensors (`str` or [`~utils.TensorType`], *optional*): @@ -134,24 +138,35 @@ def __call__( - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`. """ + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError( + "Invalid input text. Please provide a string, or a list of strings" + ) + if images is not None: image_inputs = self.image_processor( images, return_tensors=return_tensors, max_image_size=max_image_size, + split_image=split_image, ) + # expand the image_token according to the num_crops of image + prompt_strings = [] + crop_iter = iter(image_inputs.pop("num_crops")) + for prompt in text: + prompt_strings.append( + re.sub( + re.escape(self.image_token), + lambda _: next(crop_iter) * self.image_token, + prompt, + ) + ) + else: image_inputs = {} - if isinstance(text, str): - text = [text] - elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError( - "Invalid input text. Please provide a string, or a list of strings" - ) - - prompt_strings = text - text_inputs = self.tokenizer( prompt_strings, return_tensors=return_tensors, diff --git a/aria/model/vision_processor.py b/aria/model/vision_processor.py index fe263dd..adf5eca 100644 --- a/aria/model/vision_processor.py +++ b/aria/model/vision_processor.py @@ -19,12 +19,93 @@ from typing import List, Optional, Union +import numpy as np import torch from PIL import Image, ImageOps from torchvision import transforms from transformers import BaseImageProcessor, BatchFeature, TensorType +def _select_best_resolution( + img_width: int, img_height: int, target_ratios: List[List[int]], patch_size: int +): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + img_width: the original widths of images. + img_height: the original heights of images. + target_ratios (2d numpy array): dimension size (M,2) + patch_size (int): image patch size + + Returns: + tuple: The best fit resolution in the format (width, height). 
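
The placeholder expansion that `AriaProcessor.__call__` now performs is easiest to see in isolation. Below is a minimal, self-contained sketch of the `re.sub` step; the prompt and the plain-int `num_crops` list are illustrative stand-ins for the crop-count tensor the image processor actually returns.

```python
import re

# Minimal sketch of the <|img|> expansion performed in AriaProcessor.__call__.
# `num_crops` is a plain list here; the real code pops a tensor of crop counts
# from the BatchFeature returned by AriaVisionProcessor.
image_token = "<|img|>"
num_crops = [5, 1]  # first image split into a 2x2 grid + the full image, second left whole
prompt = "<|im_start|>user\n<|img|>\n<|img|>\nCompare the two images.<|im_end|>\n"

crop_iter = iter(num_crops)
expanded = re.sub(
    re.escape(image_token),
    lambda _: next(crop_iter) * image_token,  # repeat the token once per crop
    prompt,
)
print(expanded.count(image_token))  # 6 placeholders after expansion
```
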
+ """ + + aspect_ratio = img_width / img_height + best_ratio_diff = float("inf") + best_ratio_w, best_ratio_h = 1, 1 + area = np.int32(img_height) * np.int32(img_height) + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio_w, best_ratio_h = ratio[0], ratio[1] + elif ( + ratio_diff == best_ratio_diff + and area > 0.5 * patch_size * patch_size * ratio[0] * ratio[1] + ): + best_ratio_w, best_ratio_h = ratio[0], ratio[1] + + return best_ratio_w, best_ratio_h + + +def _split_image( + image: Image.Image, + split_image: bool, + split_ratio: List[List[int]], + patch_size: int, +) -> List[Image.Image]: + """ + Split image into multiple patches + + Args: + image (PIL.Image): Input image. + split_image (bool): Whether to split the image into patches. + split_ratio (2d numpy array): dimension size (M,2) + patch_size (int): image patch size + + Returns: + List[PIL.Image]: List of splitted images. + """ + if split_image: + ratio_width, ratio_height = _select_best_resolution( + image.width, image.height, split_ratio, patch_size + ) + resize_width = patch_size * ratio_width + resize_height = patch_size * ratio_height + blocks = ratio_width * ratio_height + resized_img = image.resize((resize_width, resize_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (resize_width // patch_size)) * patch_size, + (i // (resize_width // patch_size)) * patch_size, + ((i % (resize_width // patch_size)) + 1) * patch_size, + ((i // (resize_width // patch_size)) + 1) * patch_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if len(processed_images) != 1: + processed_images.insert(0, image) + return processed_images + else: + return [image] + + def keep_ratio_resize_and_pixel_mask( img: Image.Image, max_size, min_size=336, padding_value=0 ): @@ -122,6 +203,17 @@ def __call__( max_image_size: Optional[int] = 980, min_image_size: Optional[int] = 336, return_tensors: Optional[Union[str, TensorType]] = "pt", + split_image: Optional[bool] = False, + split_ratio: Optional[List[List[int]]] = [ + [1, 1], + [1, 2], + [1, 3], + [1, 4], + [2, 2], + [2, 1], + [3, 1], + [4, 1], + ], ): """ Process a list of images. @@ -130,6 +222,8 @@ def __call__( images (list): List of PIL.Image objects. max_image_size (int, optional): Override the default max image size. Defaults to None. return_tensors (str or TensorType, optional): The type of tensor to return. Defaults to "pt". + split_image (bool, optional): Whether to split the image. Defaults to False. + split_ratio (list, optional): The ratio for splitting the image. Defaults to a list of common split ratios. Returns: BatchFeature: A BatchFeature object containing: - 'pixel_values': Tensor of processed image pixel values. @@ -137,6 +231,7 @@ def __call__( - True (1) values indicate pixels that belong to the original resized image. - False (0) values indicate pixels that are part of the padding. The mask helps distinguish between actual image content and padded areas in subsequent processing steps. + - 'num_crops': Tensor of the number of crops for each image. 
""" max_size = self.max_image_size if max_image_size is None else max_image_size min_size = self.min_image_size if min_image_size is None else min_image_size @@ -149,19 +244,24 @@ def __call__( pixel_values = [] pixel_masks = [] + num_crops = [] for image in images: - img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask( - image, max_size, min_size - ) - img_padded = self.transform(img_padded) - pixel_values.append(img_padded) - pixel_masks.append(pixel_mask) + crop_images = _split_image(image, split_image, split_ratio, max_size) + num_crops.append(torch.tensor(len(crop_images))) + for crop_image in crop_images: + img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask( + crop_image, max_size, min_size + ) + img_padded = self.transform(img_padded) + pixel_values.append(img_padded) + pixel_masks.append(pixel_mask) return BatchFeature( data={ "pixel_values": torch.stack(pixel_values), "pixel_mask": torch.stack(pixel_masks), + "num_crops": torch.stack(num_crops), }, tensor_type=return_tensors, ) @@ -172,10 +272,23 @@ def preprocess( max_image_size=None, min_image_size=None, return_tensors: Optional[Union[str, TensorType]] = None, + split_image: Optional[bool] = False, + split_ratio: Optional[List[List[int]]] = [ + [1, 1], + [1, 2], + [1, 3], + [1, 4], + [2, 2], + [2, 1], + [3, 1], + [4, 1], + ], ): return self.__call__( images, max_image_size=max_image_size, min_image_size=min_image_size, return_tensors=return_tensors, + split_image=split_image, + split_ratio=split_ratio, ) diff --git a/aria/train.py b/aria/train.py index b4972a2..6dad4bf 100644 --- a/aria/train.py +++ b/aria/train.py @@ -83,7 +83,7 @@ def setup_model_and_tokenizer(model_config): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.unk_token - processor = AriaVisionProcessor(image_max_size=model_config.max_image_size) + processor = AriaVisionProcessor(max_image_size=model_config.max_image_size) return model, tokenizer, processor @@ -114,7 +114,7 @@ def setup_peft(model, model_config): return model -def collate_fn(examples, tokenizer, processor): +def collate_fn(examples, tokenizer, processor, split_image: bool = False): images = [] messages = [] for example in examples: @@ -178,13 +178,15 @@ def collate_fn(examples, tokenizer, processor): Image.open(image).convert("RGB") if isinstance(image, str) else image for image in images ] + image_inputs = processor(images, split_image=split_image) batch = apply_chat_template_and_tokenize( messages, tokenizer, + iter(image_inputs.pop("num_crops")), ) - images = processor(images) - batch.update(images) + + batch.update(image_inputs) batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16) @@ -213,7 +215,9 @@ def main(): trainer = SFTTrainer( model=model, args=training_args, - data_collator=lambda examples: collate_fn(examples, tokenizer, processor), + data_collator=lambda examples: collate_fn( + examples, tokenizer, processor, model_config.split_image + ), train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer, diff --git a/aria/vllm/aria.py b/aria/vllm/aria.py index 08972a2..25e13b7 100644 --- a/aria/vllm/aria.py +++ b/aria/vllm/aria.py @@ -20,11 +20,11 @@ import os from typing import Iterable, List, Optional, Tuple +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image -from torch import nn from transformers import LlamaConfig from transformers.utils import logging from vllm.attention import AttentionMetadata @@ -863,6 +863,95 @@ def build_mm_projector(config: AriaConfig): ) +def 
_select_best_resolution( + img_width: int, img_height: int, target_ratios: List[List[int]], patch_size: int +): + """ + Selects the best resolution from a list of possible resolutions based on the original size. + + Args: + img_width: the original widths of images. + img_height: the original heights of images. + target_ratios (2d numpy array): dimension size (M,2) + patch_size (int): image patch size + + Returns: + tuple: The best fit resolution in the format (width, height). + """ + + aspect_ratio = img_width / img_height + best_ratio_diff = float("inf") + best_ratio_w, best_ratio_h = 1, 1 + area = np.int32(img_height) * np.int32(img_height) + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio_w, best_ratio_h = ratio[0], ratio[1] + elif ( + ratio_diff == best_ratio_diff + and area > 0.5 * patch_size * patch_size * ratio[0] * ratio[1] + ): + best_ratio_w, best_ratio_h = ratio[0], ratio[1] + + return best_ratio_w, best_ratio_h + + +def split_image( + image: Image.Image, + split_image: bool, + split_ratio: List[List[int]] = [ + [1, 1], + [1, 2], + [1, 3], + [1, 4], + [2, 2], + [2, 1], + [3, 1], + [4, 1], + ], + patch_size: int = 980, +) -> List[Image.Image]: + """ + Split image into multiple patches + + Args: + image (PIL.Image): Input image. + split_image (bool): Whether to split the image into patches. + split_ratio (2d numpy array): dimension size (M,2) + patch_size (int): image patch size + + Returns: + List[PIL.Image]: List of splitted images. + """ + if split_image: + ratio_width, ratio_height = _select_best_resolution( + image.width, image.height, split_ratio, patch_size + ) + resize_width = patch_size * ratio_width + resize_height = patch_size * ratio_height + blocks = ratio_width * ratio_height + resized_img = image.resize((resize_width, resize_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (resize_width // patch_size)) * patch_size, + (i // (resize_width // patch_size)) * patch_size, + ((i % (resize_width // patch_size)) + 1) * patch_size, + ((i // (resize_width // patch_size)) + 1) * patch_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if len(processed_images) != 1: + processed_images.insert(0, image) + return processed_images + else: + return [image] + + def get_max_multimodal_tokens(ctx): return max(ctx.model_config.hf_config.image_size2tokens.values()) @@ -876,7 +965,7 @@ def input_mapper_for_aria(ctx, data): The only different is we would like to support runtime max_image_size adjustment. 
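
The grid-selection logic is easier to follow with a few worked calls. A small sketch, assuming `aria.vllm.aria` imports cleanly (i.e. vLLM is installed) and using hypothetical image sizes:

```python
from aria.vllm.aria import _select_best_resolution

# Worked examples of the grid selection used by split_image, with the default
# split ratios and patch_size=980 (image sizes are hypothetical).
split_ratio = [[1, 1], [1, 2], [1, 3], [1, 4], [2, 2], [2, 1], [3, 1], [4, 1]]

print(_select_best_resolution(3000, 1000, split_ratio, 980))  # (3, 1): aspect ratio 3.0
print(_select_best_resolution(800, 3200, split_ratio, 980))   # (1, 4): aspect ratio 0.25
print(_select_best_resolution(1960, 980, split_ratio, 980))   # (2, 1): 2 patches + full image = 3 crops
```

Each returned grid (w, h) yields w*h patches of patch_size x patch_size, and split_image prepends the original image whenever more than one patch is produced.
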
""" model_config = ctx.model_config - image_max_size = getattr(model_config.multimodal_config, "max_image_size", 980) + max_image_size = getattr(model_config.multimodal_config, "max_image_size", 980) # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): @@ -889,8 +978,9 @@ def input_mapper_for_aria(ctx, data): ) try: batch_data = image_processor.preprocess( - data, image_max_size=image_max_size, return_tensors="pt" + data, max_image_size=max_image_size, return_tensors="pt" ).data + batch_data.pop("num_crops") except Exception: logger.error("Failed to process image (%s)", data) raise @@ -916,26 +1006,38 @@ def input_processor(ctx, llm_inputs): hf_config = model_config.hf_config # prepare image tokens, the max_image_size is used to determine the number of patch_size for every image - image_max_size = multi_modal_data.pop("max_image_size", 980) - if isinstance(image_max_size, int) or isinstance(image_max_size, float): - num_images = ( - len(multi_modal_data["image"]) - if isinstance(multi_modal_data["image"], list) - else 1 - ) - image_max_size = [image_max_size] * num_images + max_image_size = multi_modal_data.pop("max_image_size", 980) + _split_image = multi_modal_data.pop("split_image", False) + + assert isinstance(max_image_size, int) or isinstance( + max_image_size, float + ), "max_image_size should be float or int" + images = ( + multi_modal_data["image"] + if isinstance(multi_modal_data["image"], list) + else [multi_modal_data["image"]] + ) + num_crops = [] + splitted_images = [] + for image in images: + splitted_image = split_image(image, _split_image, patch_size=max_image_size) + splitted_images.extend(splitted_image) + num_crops.append(len(splitted_image)) + max_image_size = [max_image_size] * len(images) + # reassign the image because we might split them into mini-patches + multi_modal_data["image"] = splitted_images # Mapping the image patch size to the corresponding number of tokens for each image image_feature_sizes = [] - for image_size in image_max_size: + for image_size, num_crop in zip(max_image_size, num_crops): assert ( image_size in hf_config.image_size2tokens ), f"Invalid image size: {image_size}, available options: {list(hf_config.image_size2tokens.keys())}" - image_feature_sizes.append(hf_config.image_size2tokens[image_size]) + image_feature_sizes.append(hf_config.image_size2tokens[image_size] * num_crop) - # Set up the image_max_size in the RuntimeContext for the image processor + # Set up the max_image_size and split_image in the RuntimeContext for the image processor # TODO: Supports dynamic image size support - setattr(model_config.multimodal_config, "max_image_size", max(image_max_size)) + setattr(model_config.multimodal_config, "max_image_size", max(max_image_size)) new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, diff --git a/examples/nextqa/evaluation.py b/examples/nextqa/evaluation.py index e69e9ac..2c3a01f 100644 --- a/examples/nextqa/evaluation.py +++ b/examples/nextqa/evaluation.py @@ -60,7 +60,7 @@ def __getitem__(self, idx): def load_model_and_tokenizer(args): - processor = AriaVisionProcessor(image_max_size=args.image_size) + processor = AriaVisionProcessor(max_image_size=args.image_size) tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_path, use_fast=False, padding_side="left" ) diff --git a/examples/nlvr2/evaluation.py b/examples/nlvr2/evaluation.py index a7622e8..637df50 100644 --- a/examples/nlvr2/evaluation.py +++ b/examples/nlvr2/evaluation.py @@ -62,7 +62,7 @@ def __getitem__(self, idx): 
def load_model_and_tokenizer(args): - processor = AriaVisionProcessor(image_max_size=args.image_size) + processor = AriaVisionProcessor(max_image_size=args.image_size) tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_path, use_fast=False, padding_side="left" ) diff --git a/examples/refcoco/evaluation.py b/examples/refcoco/evaluation.py index a94e842..682fa46 100644 --- a/examples/refcoco/evaluation.py +++ b/examples/refcoco/evaluation.py @@ -64,7 +64,7 @@ def __getitem__(self, idx): def load_model_and_tokenizer(args): - processor = AriaVisionProcessor(image_max_size=args.image_size) + processor = AriaVisionProcessor(max_image_size=args.image_size) tokenizer = AutoTokenizer.from_pretrained( args.tokenizer_path, use_fast=False, padding_side="left" ) diff --git a/tests/test_apply_chat_template.py b/tests/test_apply_chat_template.py index 6ff48e0..b948d59 100644 --- a/tests/test_apply_chat_template.py +++ b/tests/test_apply_chat_template.py @@ -26,7 +26,9 @@ def test_apply_chat_template_single_user_message(tokenizer): } ] expected_output = "<|im_start|>user\nWho wrote this book?\n<|img|><|im_end|>\n" - res = apply_chat_template_and_tokenize([messages], tokenizer=tokenizer) + res = apply_chat_template_and_tokenize( + [messages], num_image_crop=iter([1]), tokenizer=tokenizer + ) input_ids = res["input_ids"] input_str = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] assert input_str == expected_output @@ -64,7 +66,9 @@ def test_apply_chat_template_multiple_messages(tokenizer): }, ] expected_output = "<|im_start|>user\nWho wrote this book?\n<|img|><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n" - res = apply_chat_template_and_tokenize([messages], tokenizer=tokenizer) + res = apply_chat_template_and_tokenize( + [messages], num_image_crop=iter([1]), tokenizer=tokenizer + ) input_ids = res["input_ids"] input_str = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] assert input_str == expected_output @@ -118,7 +122,9 @@ def test_apply_chat_template_multi_round_messages(tokenizer): }, ] expected_output = "<|im_start|>user\nWho wrote this book?\n<|img|><|im_end|>\n<|im_start|>assistant\nSylvie Covey<|im_end|>\n<|im_start|>user\nWhat is the title of this book?<|im_end|>\n<|im_start|>assistant\nModern Printmaking: A Guide to Traditional and Digital Techniques<|im_end|>\n" - res = apply_chat_template_and_tokenize([messages], tokenizer=tokenizer) + res = apply_chat_template_and_tokenize( + [messages], num_image_crop=iter([1]), tokenizer=tokenizer + ) input_ids = res["input_ids"] input_str = tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0] assert input_str == expected_output @@ -172,7 +178,9 @@ def test_apply_chat_template_batch_messages(tokenizer): ], ] - res = apply_chat_template_and_tokenize(messages_batch, tokenizer=tokenizer) + res = apply_chat_template_and_tokenize( + messages_batch, num_image_crop=iter([1, 1]), tokenizer=tokenizer + ) input_ids = res["input_ids"] expected_output = [ From 8f203a88236b69f4772e81af36fe79467c3609ad Mon Sep 17 00:00:00 2001 From: aria-hacker Date: Mon, 30 Sep 2024 11:12:57 +0800 Subject: [PATCH 2/2] refactor the vllm inference.md --- docs/inference.md | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/docs/inference.md b/docs/inference.md index 0fbabc2..9b1de1d 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -85,6 +85,7 @@ pip install -e .[vllm] ### How to Use: ```python from PIL import Image +from transformers import AutoTokenizer from vllm import 
LLM, ModelRegistry, SamplingParams from vllm.model_executor.models import _MULTIMODAL_MODELS @@ -102,33 +103,52 @@ _MULTIMODAL_MODELS["AriaForConditionalGeneration"] = ( def main(): llm = LLM( model="rhymes-ai/Aria", - tokenizer="rhymes-ai/Aria", dtype="bfloat16", limit_mm_per_prompt={"image": 256}, enforce_eager=True, - tokenizer_mode="slow", trust_remote_code=True, + skip_tokenizer_init=True, ) - prompt = "Question: Compare Image 1 and image 2, tell me about the differences between image 1 and image 2.\nImage 1\n<|img|>\nImage 2\n<|img|> Answer: " + tokenizer = AutoTokenizer.from_pretrained( + "rhymes-ai/Aria", trust_remote_code=True, use_fast=False + ) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Compare Image 1 and image 2, tell me about the differences between image 1 and image 2.\nImage 1\n", + }, + {"type": "image"}, + {"type": "text", "text": "\nImage 2\n"}, + {"type": "image"}, + ], + } + ] + + message = tokenizer.apply_chat_template(messages, add_generation_prompt=True) outputs = llm.generate( { - "prompt": prompt, + "prompt_token_ids": message, "multi_modal_data": { "image": [ Image.open("assets/princess1.jpg"), Image.open("assets/princess2.jpg"), ], - "max_image_size": 980, + "max_image_size": 980, # [Optional] The max image patch size, default `980` + "split_image": True, # [Optional] whether to split the images, default `False` }, }, sampling_params=SamplingParams(max_tokens=200, top_k=1), ) for o in outputs: - generated_text = o.outputs[0].text - print(generated_text) + generated_tokens = o.outputs[0].token_ids + print(tokenizer.decode(generated_tokens)) if __name__ == "__main__":