Skip to content

Commit

Permalink
Merge pull request #2 from rhymes-ai/split_image
Browse files Browse the repository at this point in the history
support spliting image
  • Loading branch information
aria-hacker authored Sep 30, 2024
2 parents b8e0a93 + 8f203a8 commit 4d70b52
Show file tree
Hide file tree
Showing 12 changed files with 338 additions and 55 deletions.
4 changes: 4 additions & 0 deletions aria/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class AriaModelConfig(ModelConfig):
"choices": [490, 980],
},
)
split_image: bool = field(
default=False,
metadata={"help": "Whether to split the image into smaller patches."},
)

def __post_init__(self):
super().__post_init__()
Expand Down
5 changes: 3 additions & 2 deletions aria/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import os
import warnings
from typing import Dict, List
from typing import Dict, Iterable, List

import torch
from datasets import DatasetDict, concatenate_datasets, load_dataset
Expand All @@ -29,6 +29,7 @@
def apply_chat_template_and_tokenize(
messages_batch: List[List[Dict]],
tokenizer,
num_image_crop: Iterable[torch.Tensor] = iter([]),
):
IGNORE_TOKEN_ID = -100
im_start_tokens = tokenizer("<|im_start|>").input_ids
Expand All @@ -41,7 +42,7 @@ def process_content(content):
if content["type"] == "text":
return content["text"]
elif content["type"] == "image":
return "<fim_prefix><|img|><fim_suffix>"
return "<fim_prefix>" + "<|img|>" * next(num_image_crop) + "<fim_suffix>"
else:
raise ValueError(f"Unknown content type {content['type']} in message")

Expand Down
22 changes: 19 additions & 3 deletions aria/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def parse_arguments():
help="Maximum size of the image to be processed",
default=980,
)
parser.add_argument(
"--split_image",
type=bool,
help="Whether to split the image into patches",
action="store_true",
default=False,
)
return parser.parse_args()


Expand All @@ -65,7 +72,9 @@ def load_model(base_model_path, peft_model_path=None):
return model


def prepare_input(image_path, prompt, processor: AriaProcessor, max_image_size):
def prepare_input(
image_path, prompt, processor: AriaProcessor, max_image_size, split_image
):
image = Image.open(image_path)

messages = [
Expand All @@ -85,6 +94,7 @@ def prepare_input(image_path, prompt, processor: AriaProcessor, max_image_size):
images=image,
return_tensors="pt",
max_image_size=max_image_size,
split_image=split_image,
)

return inputs
Expand All @@ -96,8 +106,9 @@ def inference(
model: AriaForConditionalGeneration,
processor: AriaProcessor,
max_image_size,
split_image,
):
inputs = prepare_input(image_path, prompt, processor, max_image_size)
inputs = prepare_input(image_path, prompt, processor, max_image_size, split_image)
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

Expand Down Expand Up @@ -126,7 +137,12 @@ def main():
model = load_model(args.base_model_path, args.peft_model_path)

result = inference(
args.image_path, args.prompt, model, processor, args.max_image_size
args.image_path,
args.prompt,
model,
processor,
args.max_image_size,
args.split_image,
)
print(result)

Expand Down
35 changes: 25 additions & 10 deletions aria/model/processing_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# under the License.

import inspect
import re
from typing import List, Optional, Union

from transformers import AutoTokenizer, BatchFeature
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(
super().__init__(chat_template=chat_template)

if image_processor is None:
self.image_processor = AriaVisionProcessor(image_max_size=patch_size)
self.image_processor = AriaVisionProcessor(max_image_size=patch_size)
else:
self.image_processor = image_processor

Expand All @@ -87,6 +88,7 @@ def __call__(
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
max_image_size: Optional[int] = 980,
split_image: Optional[bool] = False,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
) -> BatchFeature:
"""
Expand Down Expand Up @@ -114,6 +116,8 @@ def __call__(
Maximum length of the returned list and optionally padding length (see above).
max_image_size (`int`, *optional*):
Maximum size of the image to be processed.
split_image (`bool`, *optional*):
Whether to split the image into patches before processing.
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
Expand All @@ -134,24 +138,35 @@ def __call__(
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
"""
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError(
"Invalid input text. Please provide a string, or a list of strings"
)

if images is not None:
image_inputs = self.image_processor(
images,
return_tensors=return_tensors,
max_image_size=max_image_size,
split_image=split_image,
)
# expand the image_token according to the num_crops of image
prompt_strings = []
crop_iter = iter(image_inputs.pop("num_crops"))
for prompt in text:
prompt_strings.append(
re.sub(
re.escape(self.image_token),
lambda _: next(crop_iter) * self.image_token,
prompt,
)
)

else:
image_inputs = {}

if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError(
"Invalid input text. Please provide a string, or a list of strings"
)

prompt_strings = text

text_inputs = self.tokenizer(
prompt_strings,
return_tensors=return_tensors,
Expand Down
125 changes: 119 additions & 6 deletions aria/model/vision_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,93 @@

from typing import List, Optional, Union

import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision import transforms
from transformers import BaseImageProcessor, BatchFeature, TensorType


def _select_best_resolution(
img_width: int, img_height: int, target_ratios: List[List[int]], patch_size: int
):
"""
Selects the best resolution from a list of possible resolutions based on the original size.
Args:
img_width: the original widths of images.
img_height: the original heights of images.
target_ratios (2d numpy array): dimension size (M,2)
patch_size (int): image patch size
Returns:
tuple: The best fit resolution in the format (width, height).
"""

aspect_ratio = img_width / img_height
best_ratio_diff = float("inf")
best_ratio_w, best_ratio_h = 1, 1
area = np.int32(img_height) * np.int32(img_height)
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio_w, best_ratio_h = ratio[0], ratio[1]
elif (
ratio_diff == best_ratio_diff
and area > 0.5 * patch_size * patch_size * ratio[0] * ratio[1]
):
best_ratio_w, best_ratio_h = ratio[0], ratio[1]

return best_ratio_w, best_ratio_h


def _split_image(
image: Image.Image,
split_image: bool,
split_ratio: List[List[int]],
patch_size: int,
) -> List[Image.Image]:
"""
Split image into multiple patches
Args:
image (PIL.Image): Input image.
split_image (bool): Whether to split the image into patches.
split_ratio (2d numpy array): dimension size (M,2)
patch_size (int): image patch size
Returns:
List[PIL.Image]: List of splitted images.
"""
if split_image:
ratio_width, ratio_height = _select_best_resolution(
image.width, image.height, split_ratio, patch_size
)
resize_width = patch_size * ratio_width
resize_height = patch_size * ratio_height
blocks = ratio_width * ratio_height
resized_img = image.resize((resize_width, resize_height))
processed_images = []
for i in range(blocks):
box = (
(i % (resize_width // patch_size)) * patch_size,
(i // (resize_width // patch_size)) * patch_size,
((i % (resize_width // patch_size)) + 1) * patch_size,
((i // (resize_width // patch_size)) + 1) * patch_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if len(processed_images) != 1:
processed_images.insert(0, image)
return processed_images
else:
return [image]


def keep_ratio_resize_and_pixel_mask(
img: Image.Image, max_size, min_size=336, padding_value=0
):
Expand Down Expand Up @@ -127,6 +208,17 @@ def __call__(
max_image_size: Optional[int] = 980,
min_image_size: Optional[int] = 336,
return_tensors: Optional[Union[str, TensorType]] = "pt",
split_image: Optional[bool] = False,
split_ratio: Optional[List[List[int]]] = [
[1, 1],
[1, 2],
[1, 3],
[1, 4],
[2, 2],
[2, 1],
[3, 1],
[4, 1],
],
):
"""
Process a list of images.
Expand All @@ -135,13 +227,16 @@ def __call__(
images (list): List of PIL.Image objects.
max_image_size (int, optional): Override the default max image size. Defaults to None.
return_tensors (str or TensorType, optional): The type of tensor to return. Defaults to "pt".
split_image (bool, optional): Whether to split the image. Defaults to False.
split_ratio (list, optional): The ratio for splitting the image. Defaults to a list of common split ratios.
Returns:
BatchFeature: A BatchFeature object containing:
- 'pixel_values': Tensor of processed image pixel values.
- 'pixel_mask': Boolean pixel mask. This mask is a 2D tensor of shape (max_size, max_size) where:
- True (1) values indicate pixels that belong to the original resized image.
- False (0) values indicate pixels that are part of the padding.
The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
- 'num_crops': Tensor of the number of crops for each image.
"""
max_size = self.max_image_size if max_image_size is None else max_image_size
min_size = self.min_image_size if min_image_size is None else min_image_size
Expand All @@ -154,19 +249,24 @@ def __call__(

pixel_values = []
pixel_masks = []
num_crops = []

for image in images:
img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask(
image, max_size, min_size
)
img_padded = self.transform(img_padded)
pixel_values.append(img_padded)
pixel_masks.append(pixel_mask)
crop_images = _split_image(image, split_image, split_ratio, max_size)
num_crops.append(torch.tensor(len(crop_images)))
for crop_image in crop_images:
img_padded, pixel_mask = keep_ratio_resize_and_pixel_mask(
crop_image, max_size, min_size
)
img_padded = self.transform(img_padded)
pixel_values.append(img_padded)
pixel_masks.append(pixel_mask)

return BatchFeature(
data={
"pixel_values": torch.stack(pixel_values),
"pixel_mask": torch.stack(pixel_masks),
"num_crops": torch.stack(num_crops),
},
tensor_type=return_tensors,
)
Expand All @@ -177,10 +277,23 @@ def preprocess(
max_image_size=None,
min_image_size=None,
return_tensors: Optional[Union[str, TensorType]] = None,
split_image: Optional[bool] = False,
split_ratio: Optional[List[List[int]]] = [
[1, 1],
[1, 2],
[1, 3],
[1, 4],
[2, 2],
[2, 1],
[3, 1],
[4, 1],
],
):
return self.__call__(
images,
max_image_size=max_image_size,
min_image_size=min_image_size,
return_tensors=return_tensors,
split_image=split_image,
split_ratio=split_ratio,
)
Loading

0 comments on commit 4d70b52

Please sign in to comment.