[core] gemma2 full context length support (vllm-project#10584)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
youkaichao authored and mfournioux committed Nov 28, 2024
1 parent 5fb7c3e commit b55bd44
Showing 4 changed files with 55 additions and 24 deletions.
25 changes: 18 additions & 7 deletions tests/basic_correctness/test_basic_correctness.py
@@ -14,11 +14,12 @@
from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

from ..conftest import VllmRunner
from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test

MODELS = [
"facebook/opt-125m",
"google/gemma-2-2b-it",
"meta-llama/Llama-3.2-1B",
]

@@ -42,8 +43,6 @@ def test_vllm_gc_ed():
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
backend: str,
dtype: str,
@@ -54,15 +53,27 @@ def test_models(
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")

if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
pytest.skip(
"XFORMERS does not support gemma2 with full context length.")

os.environ["VLLM_ATTENTION_BACKEND"] = backend

# 5042 tokens for gemma2
# gemma2 has alternating sliding window size of 4096
# we need a prompt with more than 4096 tokens to test the sliding window
prompt = "The following numbers of the sequence " + ", ".join(
str(i) for i in range(1024)) + " are:"
example_prompts = [prompt]

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7) as vllm_model:
with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
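Note (editor's sketch, not part of the commit): the updated test drives gemma2 past its 4096-token sliding window through the test harness. A minimal standalone sketch of the same scenario with vLLM's offline LLM API, assuming the google/gemma-2-2b-it weights are available locally:

from vllm import LLM, SamplingParams

# Build a prompt well past gemma2's 4096-token alternating sliding window,
# mirroring the prompt constructed in the test above.
prompt = "The following numbers of the sequence " + ", ".join(
    str(i) for i in range(1024)) + " are:"

# max_model_len=8192 covers the full context rather than capping at the window.
llm = LLM(model="google/gemma-2-2b-it",
          max_model_len=8192,
          gpu_memory_utilization=0.7)
params = SamplingParams(temperature=0.0, max_tokens=64)

for output in llm.generate([prompt], params):
    print(output.outputs[0].text)
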
12 changes: 10 additions & 2 deletions vllm/attention/layer.py
@@ -40,18 +40,26 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
elif cache_config is not None:
# model-level sliding window
sliding_window = cache_config.sliding_window
else:
sliding_window = None

if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
sliding_window = cache_config.sliding_window
is_attention_free = cache_config.is_attention_free
else:
kv_cache_dtype = "auto"
block_size = 16
sliding_window = None
is_attention_free = False
if num_kv_heads is None:
num_kv_heads = num_heads
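Note (editor's sketch, not part of the commit): the new constructor argument gives a per-layer sliding window precedence over the model-level value from the cache config. A simplified restatement of that resolution logic, with hypothetical names:

from typing import Optional

def resolve_sliding_window(
        per_layer_sliding_window: Optional[int],
        cache_config_sliding_window: Optional[int]) -> Optional[int]:
    if per_layer_sliding_window is not None:
        # per-layer sliding window (e.g. gemma2's windowed layers)
        return per_layer_sliding_window
    # fall back to the model-level sliding window, or full attention if unset
    return cache_config_sliding_window

assert resolve_sliding_window(4096, None) == 4096   # per-layer value wins
assert resolve_sliding_window(None, 2048) == 2048   # model-level fallback
assert resolve_sliding_window(None, None) is None   # full attention
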
29 changes: 20 additions & 9 deletions vllm/config.py
@@ -233,15 +233,26 @@ def __init__(
(self.hf_text_config.model_type in ["gemma2"]))

if (not self.disable_sliding_window and has_interleaved_attention):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)

print_warning_once(
f"{self.hf_text_config.model_type} has interleaved attention, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({sliding_window_len_min}).")
self.disable_sliding_window = True
if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)

print_warning_once(
f"{self.hf_text_config.model_type} has interleaved "
"attention, which is currently not supported by the "
"XFORMERS backend. Disabling sliding window and capping "
"the max length to the sliding window size "
f"({sliding_window_len_min}).")
self.disable_sliding_window = True
else:
# for a model with interleaved attention,
# the scheduler and the model treat it as full attention
# (i.e., not dropping any tokens outside the window).
# only the attention layer itself is aware of the sliding
# window, and uses the window size to compute the attention.
self.hf_text_config.interleaved_sliding_window = sliding_window
delattr(self.hf_text_config, "sliding_window")
sliding_window = None

self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
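Note (editor's sketch, not part of the commit): the config change splits interleaved-attention handling by backend. A toy illustration of the two branches, with simplified names standing in for the real config objects:

from types import SimpleNamespace

def handle_interleaved_attention(hf_text_config, attention_backend):
    window = hf_text_config.sliding_window
    if attention_backend == "XFORMERS":
        # XFORMERS path: keep the old behaviour, i.e. disable the sliding
        # window and cap the max length to the window size.
        return {"disable_sliding_window": True, "cap_len_to": window}
    # Other backends: the scheduler and KV cache treat the model as full
    # attention; only the attention layers themselves see the window.
    hf_text_config.interleaved_sliding_window = window
    del hf_text_config.sliding_window
    return {"disable_sliding_window": False, "cap_len_to": None}

cfg = SimpleNamespace(sliding_window=4096)
print(handle_interleaved_attention(cfg, "FLASH_ATTN"))
print(cfg.interleaved_sliding_window)  # 4096, consumed per layer in gemma2.py
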
13 changes: 7 additions & 6 deletions vllm/model_executor/models/gemma2.py
@@ -143,19 +143,20 @@ def __init__(self,
is_neox_style=True,
)

# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# odd layer, vLLM currently ignores it and uses global attention for
# all layers.
use_sliding_window = (layer_idx % 2 == 1
and config.sliding_window is not None)
del use_sliding_window # Unused.
# reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
use_sliding_window = (layer_idx % 2 == 0 and
config.interleaved_sliding_window is not None)
sliding_window = config.interleaved_sliding_window if \
use_sliding_window else None
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn")

def forward(
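Note (editor's sketch, not part of the commit): with the corrected parity, even-indexed gemma2 layers use the sliding window and odd-indexed layers attend globally, matching the HF reference linked in the diff. A quick illustration of the resulting pattern:

interleaved_sliding_window = 4096  # gemma2's window size

for layer_idx in range(6):
    use_sliding_window = (layer_idx % 2 == 0
                          and interleaved_sliding_window is not None)
    sliding_window = interleaved_sliding_window if use_sliding_window else None
    print(f"layer {layer_idx}: sliding_window={sliding_window}")
# layer 0: sliding_window=4096
# layer 1: sliding_window=None
# layer 2: sliding_window=4096
# ...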
