diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 4664aa53ea092..f4530e4771960 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -31,11 +31,14 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors @@ -326,6 +329,15 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue for (param_name, shard_name, shard_id) in stacked_params_mapping: if shard_name not in name: continue @@ -343,6 +355,10 @@ def load_weights(self, weights: Iterable[Tuple[str, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue if is_pp_missing_parameter(name, self): continue param = params_dict[name]