Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Change kv scaling factor by param json on nvidia gpu #11688

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions vllm/model_executor/models/exaone.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
5 changes: 3 additions & 2 deletions vllm/model_executor/models/granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
5 changes: 3 additions & 2 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
Expand Down
5 changes: 3 additions & 2 deletions vllm/model_executor/models/solar.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,9 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.attn._kv_scale = scaling_factor
if hasattr(layer_self_attn.attn, "_k_scale"):
layer_self_attn.attn._k_scale = scaling_factor
layer_self_attn.attn._v_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")
3 changes: 2 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,8 @@ def load_model(self) -> None:
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))

if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
or current_platform.is_cuda()):
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
Expand Down
Loading