From 7908e9b60f59564a02b225932da1472808dd5c51 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Thu, 19 Dec 2024 22:09:00 +0000
Subject: [PATCH] format

---
 vllm/envs.py                                        | 6 ++++--
 vllm/model_executor/layers/quantization/kv_cache.py | 5 ++---
 vllm/model_executor/models/llama.py                 | 2 +-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/envs.py b/vllm/envs.py
index a0d4817c36ae3..70b189be974ab 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -534,11 +534,13 @@ def get_default_config_root():
     "Q_SCALE_CONSTANT":
     lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")),
-    # Divisor for dynamic key scale factor calculation for FP8 KV Cache and attention
+    # Divisor for dynamic key scale factor calculation
+    # for FP8 KV Cache and attention
     "K_SCALE_CONSTANT":
     lambda: int(os.getenv("K_SCALE_CONSTANT", "20")),
-    # Divisor for dynamic value scale factor calculation for FP8 KV Cache and attention
+    # Divisor for dynamic value scale factor calculation
+    # for FP8 KV Cache and attention
     "V_SCALE_CONSTANT":
     lambda: int(os.getenv("V_SCALE_CONSTANT", "10")),

diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index c96192b250d2c..bcb28193a4056 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -68,10 +68,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 v_scale *= 2
             layer.calculate_kv_scales = False
-        if not isinstance(k_scale, float) or not isinstance(
-                v_scale, float):
+        if not isinstance(k_scale, float) or not isinstance(v_scale, float):
             raise ValueError("Only support per-tensor scaling factor "
-                             "for fp8 KV cache")
+                             "for fp8 KV cache")

         # These are used in the final Attention.forward()
         layer._k_scale.copy_(k_scale)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 9e7bf84ae912e..69d63216d379d 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -201,7 +201,7 @@ def __init__(
         self.attn_fp8 = envs.VLLM_USE_ROCM_FP8_ATTN \
             and current_platform.is_rocm() \
             and not is_navi() \
-            and isinstance(quant_config, Fp8Config)
+            and isinstance(quant_config, Fp8Config)

         self.attn = Attention(
             self.num_heads,