fix(gptfast): align the implementation and config of Aria HF version
xffxff committed Nov 13, 2024
1 parent a2838e8 commit 17ab774
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions gptfast/model.py
@@ -45,7 +45,7 @@ class ModelArgs:
     intermediate_size: int = 1664
     n_local_heads: int = -1
     head_dim: int = 64
-    rope_base: float = 10000
+    rope_base: float = 5000000
     norm_eps: float = 1e-5
     use_scaled_rope: bool = False
     num_experts: int = 64
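
Note on the first hunk: `rope_base` is the base of the rotary position embedding frequencies, and raising it from 10000 to 5000000 stretches the rotation periods to match the Aria HF config. A minimal sketch of how such a base typically enters the frequency table in gpt-fast-style code (the function name and shapes below are illustrative, not taken from this repository):

```python
import torch

def precompute_rope_freqs(head_dim: int, seq_len: int, base: float = 5000000.0) -> torch.Tensor:
    """Standard RoPE frequency table: inv_freq[i] = base ** (-2i / head_dim).

    A larger base gives slower-rotating (longer-period) dimensions, which is
    why long-context configs such as Aria's use a larger rope_base.
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # [head_dim / 2]
    t = torch.arange(seq_len).float()                                             # [seq_len]
    freqs = torch.outer(t, inv_freq)                                              # [seq_len, head_dim / 2]
    return torch.polar(torch.ones_like(freqs), freqs)                             # complex e^{i * theta}
```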
@@ -357,11 +357,10 @@ def forward(self, x: Tensor) -> Tensor:
         # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
         # x: [T, D]
         scores = self.gate(x)  # [T, E]
-        expert_weights = F.softmax(scores, dim=-1)
         expert_weights, expert_indices = torch.topk(
-            expert_weights, self.num_activated_experts, dim=-1
+            scores, self.num_activated_experts, dim=-1
         )  # [T, A], [T, A]
-        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)  # [T, A]
+        expert_weights = F.softmax(expert_weights, dim=-1)
         expert_outs = self.cond_ffn(x, expert_indices, expert_weights)
         shared_outs = self.shared_ffn(x)
         return expert_outs + shared_outs
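
Note on the second hunk: the new order takes top-k over the raw gate logits and applies softmax only to the selected logits, matching the HF Aria implementation. Up to floating-point error this yields the same weights as the old softmax-then-top-k-then-renormalize path, since softmax preserves ordering and a full softmax restricted to the top-k entries and renormalized equals a softmax over just those k logits. A small standalone sketch of the two orderings, with made-up sizes (`T` and `A` are illustrative, `E` matches `num_experts` above, and the random `scores` stand in for `self.gate(x)`):

```python
import torch
import torch.nn.functional as F

T, E, A = 8, 64, 6          # tokens, experts, activated experts (T and A are illustrative)
scores = torch.randn(T, E)  # stand-in for self.gate(x): raw router logits

# New order (this commit): top-k on raw logits, then softmax over the k selected logits.
topk_scores, expert_indices = torch.topk(scores, A, dim=-1)  # [T, A], [T, A]
expert_weights = F.softmax(topk_scores, dim=-1)              # [T, A]

# Old order: softmax over all experts, top-k, then renormalize the selected weights.
full = F.softmax(scores, dim=-1)
old_weights, old_indices = torch.topk(full, A, dim=-1)
old_weights = old_weights / old_weights.sum(dim=-1, keepdim=True)

print(torch.equal(expert_indices, old_indices))                # True: softmax is monotone
print(torch.allclose(expert_weights, old_weights, atol=1e-6))  # True up to float error
```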