From 165212d23a1201090835a083994b4c31bbd55051 Mon Sep 17 00:00:00 2001
From: bjmsong
Date: Thu, 2 Jan 2025 22:45:51 +0800
Subject: [PATCH 1/2] Change kv scaling factor by param json on nvidia gpu

Signed-off-by: bjmsong
---
 vllm/model_executor/models/exaone.py  | 5 +++--
 vllm/model_executor/models/granite.py | 5 +++--
 vllm/model_executor/models/llama.py   | 5 +++--
 vllm/model_executor/models/solar.py   | 5 +++--
 vllm/worker/model_runner.py           | 2 +-
 5 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index 0398f0943a70a..8324a563edd64 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -606,8 +606,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index f9e0443b9a508..a91ed4158a73f 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -545,8 +545,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 2902e6999c2fd..8623da99574bb 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -452,8 +452,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index caae0b65d7d10..a7cf65a0e36e4 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -565,8 +565,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 2b545d1b28bd2..3fc3b0647c5c8 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1136,7 +1136,7 @@ def load_model(self) -> None:
                 self.prompt_adapter_manager.create_prompt_adapter_manager(
                     self.model))
 
-        if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
+        if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm() or current_platform.is_cuda()):
             # Currently only ROCm accepts kv-cache scaling factors
             # via quantization_param_path and this will be deprecated
             # in the future.

From baa5b0abdc9f3aa833dfe6977b56050fff8d1344 Mon Sep 17 00:00:00 2001
From: bjmsong
Date: Thu, 2 Jan 2025 23:02:07 +0800
Subject: [PATCH 2/2] format

Signed-off-by: bjmsong
---
 vllm/worker/model_runner.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 3fc3b0647c5c8..637fba23611f4 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1136,7 +1136,8 @@ def load_model(self) -> None:
                 self.prompt_adapter_manager.create_prompt_adapter_manager(
                     self.model))
 
-        if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm() or current_platform.is_cuda()):
+        if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
+                                             or current_platform.is_cuda()):
             # Currently only ROCm accepts kv-cache scaling factors
             # via quantization_param_path and this will be deprecated
             # in the future.
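
For context, a minimal sketch of how the CUDA path enabled by this series might be exercised. The JSON layout (per-TP-rank, per-layer scaling factors under "kv_cache"/"scaling_factor") and the LLM keyword arguments are assumptions based on vLLM's existing FP8 KV-cache scales support, not something defined by this patch; treat the model name, layer count, and scale values as purely illustrative.

    # Hypothetical usage sketch; field names, model, and layer count are assumed.
    import json

    from vllm import LLM

    num_layers = 32  # assumed hidden-layer count of the illustrative model
    scales = {
        "model_type": "llama",
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            # tp_rank -> {layer_idx -> scaling_factor}, one entry per layer
            "scaling_factor": {"0": {str(i): 0.02 for i in range(num_layers)}},
        },
    }
    with open("kv_cache_scales.json", "w") as f:
        json.dump(scales, f)

    # With this series applied, the fp8 + quantization_param_path combination
    # should be honored on CUDA as well as ROCm.
    llm = LLM(model="meta-llama/Llama-2-7b-hf",
              kv_cache_dtype="fp8",
              quantization_param_path="kv_cache_scales.json")
    print(llm.generate("Hello, my name is")[0].outputs[0].text)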