[Bugfix] Add kv cache scales to gemma2.py (vllm-project#11269)
mgoin authored Dec 23, 2024
1 parent 63afbe9 commit 60fb4f3
Showing 1 changed file with 17 additions and 1 deletion.
vllm/model_executor/models/gemma2.py (17 additions, 1 deletion)
@@ -31,11 +31,14 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

@@ -326,6 +329,15 @@ def load_weights(self, weights: Iterable[Tuple[str,
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+            if scale_name := get_compressed_tensors_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = loaded_weight[0]
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
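
For context on the new branch: get_compressed_tensors_cache_scale checks whether a checkpoint tensor name is a compressed-tensors kv cache scale and, if so, returns the parameter name vLLM registers for it; the loaded_weight[0] indexing then unwraps the single-element tensor the checkpoint stores for each scale. A minimal sketch of the assumed name mapping follows (the real helper lives in vllm/model_executor/layers/quantization/compressed_tensors/utils.py and may cover more cases):

from typing import Optional

def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    # Sketch: compressed-tensors checkpoints store kv cache scales as the
    # output_scale of the k/v projections, while vLLM registers them on the
    # Attention module, e.g.
    #   "model.layers.0.self_attn.k_proj.output_scale"
    #     -> "model.layers.0.self_attn.attn.k_scale"
    if name.endswith(".k_proj.output_scale"):
        return name.replace(".k_proj.output_scale", ".attn.k_scale")
    if name.endswith(".v_proj.output_scale"):
        return name.replace(".v_proj.output_scale", ".attn.v_scale")
    return None  # not a cache scale; caller falls through to normal loading
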
@@ -343,6 +355,10 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remap the names of FP8 kv-scales.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
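
maybe_remap_kv_scale_name plays the same role for FP8 checkpoints, whose scales are stored under names like "...self_attn.k_scale". A rough sketch of its contract, assuming the simplified behavior below (the real helper in vllm/model_executor/model_loader/weight_utils.py also handles the deprecated ".kv_scale" name and logs a warning before skipping):

from typing import Dict, Optional

import torch

def maybe_remap_kv_scale_name(name: str,
                              params_dict: Dict[str, torch.Tensor]
                              ) -> Optional[str]:
    # Sketch: remap "...self_attn.k_scale" / "...self_attn.v_scale" to the
    # "...self_attn.attn.k_scale" / "...attn.v_scale" names the model
    # registers; returning None tells the caller to skip the tensor.
    for suffix in (".k_scale", ".v_scale"):
        if name.endswith(suffix):
            remapped = name.replace(suffix, f".attn{suffix}")
            if remapped not in params_dict:
                return None  # model has no matching parameter; skip it
            return remapped
    return name  # not a kv scale; name passes through unchanged

Because non-scale names pass through unchanged, the call can sit unconditionally in the load loop without affecting ordinary weights.
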