From 1e10c287cff389e3e3c7d339bf2ef0ac9f5ba942 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sat, 23 Nov 2024 11:57:10 +0100
Subject: [PATCH 1/5] Up

---
 vllm/model_executor/models/llama.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 66b29e72cfa89..afc8f1203fa07 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -167,6 +167,16 @@ def __init__(
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
         )
+
+        layer_idx: int = int(prefix.split(".")[0])
+        if isinstance(config.interleaved_sliding_window, int):
+            sliding_window = config.interleaved_sliding_window
+        elif isinstance(config.interleaved_sliding_window, list):
+            sw_idx = layer_idx % len(sliding_window)
+            sliding_window = config.interleaved_sliding_window[sw_idx]
+        else:
+            None
+
         self.attn = Attention(
             self.num_heads,
             self.head_dim,
@@ -174,6 +184,7 @@ def __init__(
             num_kv_heads=self.num_kv_heads,
             cache_config=cache_config,
             quant_config=quant_config,
+            per_layer_sliding_window=sliding_window,
             prefix=f"{prefix}.attn",
         )

From f7d561b0bb9e5bed0cef20b71e285ea3f5f03f40 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 29 Nov 2024 13:52:16 +0000
Subject: [PATCH 2/5] WIP

---
 vllm/model_executor/models/llama.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index afc8f1203fa07..f7eb0c976daf5 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -110,6 +110,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         cache_config: Optional[CacheConfig] = None,
+        layer_idx: Optional[int] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -168,14 +169,17 @@ def __init__(
             is_neox_style=is_neox_style,
         )

-        layer_idx: int = int(prefix.split(".")[0])
-        if isinstance(config.interleaved_sliding_window, int):
-            sliding_window = config.interleaved_sliding_window
-        elif isinstance(config.interleaved_sliding_window, list):
-            sw_idx = layer_idx % len(sliding_window)
-            sliding_window = config.interleaved_sliding_window[sw_idx]
+        if hasattr(config, "interleaved_sliding_window"):
+            if isinstance(config.interleaved_sliding_window, int):
+                sliding_window = config.interleaved_sliding_window
+            elif isinstance(config.interleaved_sliding_window, list):
+                assert layer_idx is not None
+                sw_idx = layer_idx % len(config.interleaved_sliding_window)
+                sliding_window = config.interleaved_sliding_window[sw_idx]
+            else:
+                raise ValueError(f"{type(sliding_window)} is not suuported.")
         else:
-            None
+            sliding_window = None

         self.attn = Attention(
             self.num_heads,
@@ -226,6 +230,8 @@ def __init__(
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(
             config, "bias", False)
+        layer_idx: int = int(prefix.split(".")[-1])
+
         self.self_attn = LlamaAttention(
             config=config,
@@ -238,6 +244,7 @@ def __init__(
             quant_config=quant_config,
             bias=attention_bias,
             cache_config=cache_config,
+            layer_idx=layer_idx,
             prefix=f"{prefix}.self_attn",
         )
         self.mlp = LlamaMLP(

From 9724f0199524e3d14b05a0bf52e906b82b1b733c Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 29 Nov 2024 14:10:11 +0000
Subject: [PATCH 3/5] WIP

---
 vllm/model_executor/models/llama.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index f7eb0c976daf5..ada17c35f2b91 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -177,7 +177,7 @@ def __init__(
                 sw_idx = layer_idx % len(config.interleaved_sliding_window)
                 sliding_window = config.interleaved_sliding_window[sw_idx]
             else:
-                raise ValueError(f"{type(sliding_window)} is not suuported.")
+                raise ValueError(f"{type(sliding_window)} is not supported.")
         else:
             sliding_window = None

From 628f56e2696e93ebf0d528599e3ecd2e27d494ec Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 29 Nov 2024 21:55:17 -0800
Subject: [PATCH 4/5] use extract_layer_index

Signed-off-by: youkaichao
---
 vllm/model_executor/models/llama.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 84bef797db99c..04db8b7888e5f 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -54,7 +54,7 @@
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
-                    is_pp_missing_parameter,
+                    extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -111,10 +111,10 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         cache_config: Optional[CacheConfig] = None,
-        layer_idx: Optional[int] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        layer_idx = extract_layer_index(prefix)
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
@@ -174,7 +174,6 @@ def __init__(
             if isinstance(config.interleaved_sliding_window, int):
                 sliding_window = config.interleaved_sliding_window
             elif isinstance(config.interleaved_sliding_window, list):
-                assert layer_idx is not None
                 sw_idx = layer_idx % len(config.interleaved_sliding_window)
                 sliding_window = config.interleaved_sliding_window[sw_idx]
             else:
@@ -231,7 +230,6 @@ def __init__(
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(
             config, "bias", False)
-        layer_idx: int = int(prefix.split(".")[-1])

         self.self_attn = LlamaAttention(
             config=config,
@@ -245,7 +243,6 @@ def __init__(
             quant_config=quant_config,
             bias=attention_bias,
             cache_config=cache_config,
-            layer_idx=layer_idx,
             prefix=f"{prefix}.self_attn",
         )
         self.mlp = LlamaMLP(

From 1b62f3674003f31f037eb734bd5fd0d769c28144 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 29 Nov 2024 21:56:31 -0800
Subject: [PATCH 5/5] minimize change

Signed-off-by: youkaichao
---
 vllm/model_executor/models/llama.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 04db8b7888e5f..ff0ab011a9158 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -230,7 +230,6 @@ def __init__(
         # Support internlm/internlm-7b with bias
         attention_bias = getattr(config, "attention_bias", False) or getattr(
             config, "bias", False)
-
         self.self_attn = LlamaAttention(
             config=config,
             hidden_size=self.hidden_size,
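
For readers following the series, here is a minimal, self-contained sketch of the per-layer window selection the patches converge on. The `extract_layer_index` function below is a simplified stand-in for the helper imported from `.utils` in PATCH 4/5, `resolve_sliding_window` is a hypothetical free function (the real code runs inline in `LlamaAttention.__init__`), and the `[4096, None]` pattern is a made-up example value for `interleaved_sliding_window`, not something taken from the diff.

```python
# Sketch only: mirrors the final selection logic of this series under the
# assumptions stated above; not vLLM's actual API surface.
from typing import Optional, Union


def extract_layer_index(prefix: str) -> int:
    # Simplified: take the unique integer component of a module prefix
    # such as "model.layers.7.self_attn".
    indices = [int(p) for p in prefix.split(".") if p.isdigit()]
    assert len(indices) == 1, f"no unique layer index in {prefix!r}"
    return indices[0]


def resolve_sliding_window(
    interleaved_sliding_window: Union[int, list, None],
    prefix: str,
) -> Optional[int]:
    """Pick the sliding-window size for one attention layer."""
    if interleaved_sliding_window is None:
        # Config has no interleaved sliding window: full attention everywhere.
        return None
    if isinstance(interleaved_sliding_window, int):
        # Same window for every layer.
        return interleaved_sliding_window
    if isinstance(interleaved_sliding_window, list):
        # The pattern repeats every len(list) layers, e.g. [4096, None]
        # alternates sliding-window and full attention.
        layer_idx = extract_layer_index(prefix)
        sw_idx = layer_idx % len(interleaved_sliding_window)
        return interleaved_sliding_window[sw_idx]
    raise ValueError(f"{type(interleaved_sliding_window)} is not supported.")


# Example: layers 0-3 with an alternating [4096, None] pattern.
for i in range(4):
    prefix = f"model.layers.{i}.self_attn"
    print(prefix, "->", resolve_sliding_window([4096, None], prefix))
```

The resolved value is what the final revision passes to `Attention` as `per_layer_sliding_window`; deriving `layer_idx` from the prefix inside `LlamaAttention` (rather than threading it through the decoder layer) is what PATCH 4/5 and 5/5 simplify.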