
Commit a283f40
q_range as an env
gshtras committed Dec 19, 2024
1 parent 37f37d1 commit a283f40
Showing 2 changed files with 13 additions and 7 deletions.
3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -128,6 +128,7 @@ def __init__(
         compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix

+        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
         self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
         self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
@@ -176,7 +177,7 @@ def forward(
                 self.layer_name)

     def calc_kv_scales(self, query, key, value):
-        self._q_scale.copy_(torch.abs(query).max() / self.k_range)
+        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
         self._k_scale.copy_(torch.abs(key).max() / self.k_range)
         self._v_scale.copy_(torch.abs(value).max() / self.v_range)
         # We only calculate the scales once
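The change in calc_kv_scales fixes _q_scale being divided by self.k_range; it now uses the new self.q_range driven by Q_SCALE_CONSTANT. Each scale is the tensor's absolute maximum divided by its range constant, so a larger constant yields a smaller scale factor. Below is a minimal standalone sketch of that arithmetic with the new defaults and made-up tensors; calc_scales is an illustrative helper, not vLLM's API.

import torch

# Illustrative defaults mirroring the new env values in this commit.
Q_SCALE_CONSTANT = 20
K_SCALE_CONSTANT = 20
V_SCALE_CONSTANT = 10

def calc_scales(query, key, value):
    # Same arithmetic as Attention.calc_kv_scales: abs-max over the
    # whole tensor, divided by the per-tensor range constant.
    q_scale = query.abs().max() / Q_SCALE_CONSTANT
    k_scale = key.abs().max() / K_SCALE_CONSTANT
    v_scale = value.abs().max() / V_SCALE_CONSTANT
    return q_scale, k_scale, v_scale

q = torch.randn(2, 8, 64) * 40.0  # abs-max somewhere near 40
k = torch.randn(2, 8, 64) * 10.0
v = torch.randn(2, 8, 64) * 5.0
print(calc_scales(q, k, v))  # e.g. q_scale lands near 40 / 20 = 2.0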
17 changes: 11 additions & 6 deletions vllm/envs.py
@@ -83,8 +83,9 @@
     VLLM_FP8_PADDING: bool = True
     VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
-    K_SCALE_CONSTANT: int = 200
-    V_SCALE_CONSTANT: int = 100
+    Q_SCALE_CONSTANT: int = 20
+    K_SCALE_CONSTANT: int = 20
+    V_SCALE_CONSTANT: int = 10


 def get_default_cache_root():
@@ -529,13 +530,17 @@ def get_default_config_root():
     "VLLM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))),

-    # Divisor for dynamic key scale factor calculation for FP8 KV Cache
+    # Divisor for dynamic query scale factor calculation for FP8 attention
+    "Q_SCALE_CONSTANT":
+    lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")),
+
+    # Divisor for dynamic key scale factor calculation for FP8 KV Cache and attention
     "K_SCALE_CONSTANT":
-    lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
+    lambda: int(os.getenv("K_SCALE_CONSTANT", "20")),

-    # Divisor for dynamic value scale factor calculation for FP8 KV Cache
+    # Divisor for dynamic value scale factor calculation for FP8 KV Cache and attention
     "V_SCALE_CONSTANT":
-    lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
+    lambda: int(os.getenv("V_SCALE_CONSTANT", "10")),

     # If set, enable multiprocessing in LLM for the V1 code path.
     "VLLM_ENABLE_V1_MULTIPROCESSING":
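Because each entry above is a lambda over os.getenv, the constant is resolved when the attribute is read, so it can be overridden per process without code changes. A hedged usage sketch (the value 40 is arbitrary, not a recommendation):

import os

# Override must be in the environment before the attribute is first
# read from vllm.envs.
os.environ["Q_SCALE_CONSTANT"] = "40"

import vllm.envs as envs

print(envs.Q_SCALE_CONSTANT)  # 40, from the override above
print(envs.K_SCALE_CONSTANT)  # 20, the new default from this commit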
