
Commit

deepseek overflow fix
Concurrensee committed Jan 6, 2025
1 parent a264693 commit 114f006
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions vllm/model_executor/models/deepseek_v2.py
@@ -149,9 +149,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            router_logits=router_logits) * self.routed_scaling_factor
+            router_logits=router_logits)
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            final_hidden_states = final_hidden_states + shared_output * (1. / self.routed_scaling_factor)

Check failure on line 154 (GitHub Actions / ruff 3.12):
vllm/model_executor/models/deepseek_v2.py:154:81: E501 Line too long (105 > 80)
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
@@ -375,6 +375,7 @@ def __init__(
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
+        self.routed_scaling_factor = config.routed_scaling_factor

     def forward(
         self,
@@ -399,9 +400,14 @@ def forward(
         )

         # Fully Connected
+        if isinstance(self.mlp, DeepseekV2MoE):
+            hidden_states *= 1. / self.mlp.routed_scaling_factor
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
+        if isinstance(self.mlp, DeepseekV2MLP):
+            hidden_states *= 1. / self.routed_scaling_factor
+            residual *= 1. / self.routed_scaling_factor
         return hidden_states, residual
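
Note on the first hunk: why reordering the scale avoids the overflow. float16 tops out at 65504, and the old code multiplied the routed experts' output by routed_scaling_factor before the shared-expert add and the tensor-parallel all-reduce, so a moderately large expert activation could already overflow to inf at that point. The new code never scales anything up; it scales the shared-expert output down by the reciprocal instead. A minimal sketch of the failure mode with made-up magnitudes (s stands in for routed_scaling_factor and is not necessarily any checkpoint's real value):

# Minimal sketch of the fp16 failure mode (made-up magnitudes; s is a
# placeholder for routed_scaling_factor, not a real config value).
import torch

s = 16.0
expert_out = torch.full((4,), 5000.0, dtype=torch.float16)
shared_out = torch.full((4,), 300.0, dtype=torch.float16)

# Old ordering: scale the routed output up before adding the shared output.
old = expert_out * s + shared_out          # 5000 * 16 = 80000 > 65504 -> inf

# New ordering: scale the shared output down by the reciprocal instead.
new = expert_out + shared_out * (1. / s)   # about 5019, comfortably finite

print(old)  # tensor([inf, inf, inf, inf], dtype=torch.float16)
print(new)  # tensor([5020., 5020., 5020., 5020.], dtype=torch.float16)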


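Note on the decoder-layer hunks: they extend the same idea to the whole residual stream. Before the post-attention layernorm of an MoE layer, the attention branch is pre-scaled by 1 / routed_scaling_factor, and after a dense DeepseekV2MLP layer both the hidden states and the residual are scaled the same way, so every tensor crossing a layer boundary carries the 1/s factor. This appears safe because RMSNorm is invariant to uniform positive scaling of its input: each block still sees the same normalized activations, only the raw residual magnitudes shrink, and the model's final RMSNorm cancels the global factor before the LM head. A quick check with a hand-rolled RMSNorm (a simplified stand-in for vLLM's kernel, learned weight omitted):

# RMSNorm(x / s) matches RMSNorm(x) for positive s, up to the eps term.
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Simplified RMSNorm with the learned weight omitted; folding the
    # weight back in would not change the scale invariance.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(8)
s = 16.0
print(torch.allclose(rms_norm(x), rms_norm(x / s), atol=1e-2))  # True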
