Skip to content

Commit

Permalink
Fix LLaVA-NeXT feature size calculation (for real)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 committed Jan 6, 2025
1 parent 905ef01 commit 629c3e7
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def processor_for_llava_next():

@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183), (198, 176), (176, 198)])
(488, 183), (198, 176), (176, 198),
(161, 184), (184, 161)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
processor_for_llava_next,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def processor_for_llava_onevision():
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183), (198, 176), (176, 198)])
(488, 183), (198, 176), (176, 198),
(161, 184), (184, 161)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
processor_for_llava_onevision,
Expand Down
25 changes: 12 additions & 13 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,30 +121,29 @@ def _get_num_unpadded_features(
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width

# NOTE: Use float32 to remain consistent with HF output
original_aspect_ratio = np.array(original_width / original_height,
dtype=np.float32)
current_aspect_ratio = np.array(current_width / current_height,
dtype=np.float32)
current_height = np.float32(npatches * num_patch_height)
current_width = np.float32(npatches * num_patch_width)

original_width = np.float32(original_width) # type: ignore
original_height = np.float32(original_height) # type: ignore

original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height

if original_aspect_ratio > current_aspect_ratio:
scale_factor = np.array(current_width / original_width,
dtype=np.float32)
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
current_height -= 2 * padding
else:
scale_factor = np.array(current_height / original_height,
dtype=np.float32)
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
current_width -= 2 * padding

unpadded_features = current_height * current_width
newline_features = current_height
unpadded_features = int(current_height * current_width)
newline_features = int(current_height)

return (unpadded_features, newline_features)

Expand Down
25 changes: 12 additions & 13 deletions vllm/model_executor/models/llava_onevision.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,30 +104,29 @@ def _get_num_unpadded_features(
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width

# NOTE: Use float32 to remain consistent with HF output
original_aspect_ratio = np.array(original_width / original_height,
dtype=np.float32)
current_aspect_ratio = np.array(current_width / current_height,
dtype=np.float32)
current_height = np.float32(npatches * num_patch_height)
current_width = np.float32(npatches * num_patch_width)

original_width = np.float32(original_width) # type: ignore
original_height = np.float32(original_height) # type: ignore

original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height

if original_aspect_ratio > current_aspect_ratio:
scale_factor = np.array(current_width / original_width,
dtype=np.float32)
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
current_height -= 2 * padding
else:
scale_factor = np.array(current_height / original_height,
dtype=np.float32)
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
current_width -= 2 * padding

unpadded_features = current_height * current_width
newline_features = current_height
unpadded_features = int(current_height * current_width)
newline_features = int(current_height)

ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
if ratio > 1.1:
Expand Down

0 comments on commit 629c3e7

Please sign in to comment.