diff --git a/gptfast/model.py b/gptfast/model.py index f803f15..7fa029c 100644 --- a/gptfast/model.py +++ b/gptfast/model.py @@ -45,7 +45,7 @@ class ModelArgs: intermediate_size: int = 1664 n_local_heads: int = -1 head_dim: int = 64 - rope_base: float = 10000 + rope_base: float = 5000000 norm_eps: float = 1e-5 use_scaled_rope: bool = False num_experts: int = 64 @@ -357,11 +357,10 @@ def forward(self, x: Tensor) -> Tensor: # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts # x: [T, D] scores = self.gate(x) # [T, E] - expert_weights = F.softmax(scores, dim=-1) expert_weights, expert_indices = torch.topk( - expert_weights, self.num_activated_experts, dim=-1 + scores, self.num_activated_experts, dim=-1 ) # [T, A], [T, A] - expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] + expert_weights = F.softmax(expert_weights, dim=-1) expert_outs = self.cond_ffn(x, expert_indices, expert_weights) shared_outs = self.shared_ffn(x) return expert_outs + shared_outs