Skip to content

Commit

Permalink
Separate out profiling-related logic
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 committed Jan 5, 2025
1 parent 4068f4b commit 09bbb1e
Show file tree
Hide file tree
Showing 17 changed files with 1,039 additions and 745 deletions.
7 changes: 4 additions & 3 deletions tests/multimodal/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,17 +586,18 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
)

processor = processor_factory(ctx, cache=None)
profiler = processor.profiling_info

mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
processor.get_supported_mm_limits = mock_supported_mm_limits
profiler.get_supported_mm_limits = mock_supported_mm_limits

if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="this model only supports")

with exc_ctx:
processor._get_and_validate_dummy_mm_counts()
profiler.get_mm_limits()


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
Expand Down Expand Up @@ -723,7 +724,7 @@ def _test_processing_cache_correctness(
}

mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor._get_dummy_processor_inputs(
prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text
Expand Down
85 changes: 50 additions & 35 deletions vllm/model_executor/models/aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
AriaVisionConfig)
Expand Down Expand Up @@ -444,17 +445,57 @@ def build_mm_projector(config: PretrainedConfig):
)


class AriaMultiModalProcessor(BaseMultiModalProcessor):
class AriaProcessingMixin(ProcessingMixin):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def _get_hf_config(self):
return self.ctx.get_hf_config()

def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config()
def _get_vision_config(self) -> AriaVisionConfig:
return self._get_hf_config().vision_config

def get_num_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())


class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}

def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
vision_config = self._get_vision_config()

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token # type: ignore

return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)


class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return AriaProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
Expand All @@ -472,10 +513,10 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config()
hf_config = self._get_hf_config()
image_token_id = hf_config.image_token_index

num_image_tokens = self._get_num_image_tokens()
num_image_tokens = self.get_num_image_tokens()

return [
PromptReplacement(
Expand All @@ -485,32 +526,6 @@ def _get_prompt_replacements(
)
]

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config()
vision_config: AriaVisionConfig = hf_config.vision_config

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token # type: ignore

return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)


@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
Expand Down
82 changes: 46 additions & 36 deletions vllm/model_executor/models/blip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import torch
import torch.nn as nn
from transformers import (BatchFeature, Blip2Config, Blip2Processor,
Blip2QFormerConfig, apply_chunking_to_forward)
from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
apply_chunking_to_forward)

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
Expand All @@ -18,8 +18,9 @@
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .blip import BlipVisionModel
Expand Down Expand Up @@ -396,20 +397,52 @@ def forward(
return sequence_output


class Blip2MultiModalProcessor(BaseMultiModalProcessor):
class Blip2ProcessingMixin(ProcessingMixin):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_hf_config(self):
return self.ctx.get_hf_config(Blip2Config)

def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config(Blip2Config)
def get_num_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return hf_config.num_query_tokens


class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}

def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self._get_hf_config()
vision_config = hf_config.vision_config

def _get_hf_processor(self) -> Blip2Processor:
return self.ctx.get_hf_processor(Blip2Processor)
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)


class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return Blip2ProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
Expand All @@ -427,13 +460,13 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
max_image_tokens = self._get_num_image_tokens()
num_image_tokens = self.get_num_image_tokens()

return [
PromptReplacement(
modality="image",
target="</s>",
replacement="<image>" * max_image_tokens + "</s>",
replacement="<image>" * num_image_tokens + "</s>",
)
]

Expand All @@ -457,29 +490,6 @@ def apply(

return result

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config

max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)


@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
Expand Down
78 changes: 46 additions & 32 deletions vllm/model_executor/models/chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once

Expand All @@ -48,20 +49,55 @@ class ChameleonImagePixelInputs(TypedDict):
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""


class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
class ChameleonProcessingMixin(ProcessingMixin):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_hf_config(self):
return self.ctx.get_hf_config(ChameleonConfig)

def _get_hf_processor(self):
return self.ctx.get_hf_processor(ChameleonProcessor)

def _get_num_image_tokens(self) -> int:
def get_num_image_tokens(self) -> int:
processor = self._get_hf_processor()
return processor.image_seq_length


class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}

def _get_hf_processor(self) -> ChameleonProcessor:
return self.ctx.get_hf_processor(ChameleonProcessor)
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self._get_hf_config()

width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=width,
height=height,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)


class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
BaseMultiModalProcessor):

def _get_profiling_info(self) -> BaseProfilingInfo:
return ChameleonProfilingInfo(self.ctx)

def _get_mm_fields_config(
self,
Expand All @@ -76,42 +112,20 @@ def _get_prompt_replacements(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor()
processor = self._get_hf_processor(**hf_processor_mm_kwargs)

return [
PromptReplacement(
modality="image",
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * self._get_num_image_tokens(),
processor.image_token * self.get_num_image_tokens(),
processor.image_end_token,
]),
)
]

def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self.ctx.get_hf_config(ChameleonConfig)

width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)

mm_data = {
"image":
self._get_dummy_images(width=width,
height=height,
num_images=num_images)
}

return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)

def apply(
self,
prompt_text: str,
Expand Down
Loading

0 comments on commit 09bbb1e

Please sign in to comment.