Commit 87fc97c

simplify the router

xffxff committed Nov 12, 2024
1 parent 4d598d1 commit 87fc97c
Showing 1 changed file with 20 additions and 221 deletions.

241 changes: 20 additions & 221 deletions aria/vllm/aria.py
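Overview (an editorial sketch, not part of the commit): before this change, a standalone TopKRouter module computed gating logits that the MoE layer passed into Experts; after it, Experts owns the router weight and computes the logits itself. A minimal, runnable approximation of the new routing path, with toy dimensions (SimplifiedExperts and all sizes here are invented for illustration; the real code hands the result to vllm's fused MoE kernel):

# Illustrative only: parameter names follow the diff, everything else is simplified.
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedExperts(nn.Module):
    def __init__(self, num_experts: int, hidden_size: int, topk: int):
        super().__init__()
        self.topk = topk
        # Same layout as the diff: [num_experts, hidden_size].
        self.router_weight = nn.Parameter(torch.randn(num_experts, hidden_size))

    def forward(self, hidden_states: torch.Tensor):
        # Router logits are now computed inside Experts, not by a TopKRouter.
        router_output = F.linear(hidden_states, self.router_weight)
        top_logits, top_indices = torch.topk(router_output, k=self.topk, dim=1)
        scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
        return scores, top_indices.to(torch.int32)

x = torch.randn(4, 16)                      # [num_tokens, hidden_size]
scores, indices = SimplifiedExperts(8, 16, topk=2)(x)
print(scores.shape, indices.shape)          # torch.Size([4, 2]) torch.Size([4, 2])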
@@ -117,221 +117,14 @@ def __init__(
         self.moe_num_shared_experts = moe_num_shared_experts
 
 
-# copied from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/moe_utils.py#L101-L142
-class MoEAuxLossAutoScaler(torch.autograd.Function):
-    """An AutoScaler that computes and scales the grad for the auxiliary loss."""
-
-    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
-
-    @staticmethod
-    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
-        """Preserve the aux_loss by storing it in the context to avoid garbage collection.
-        Args:
-            output (torch.Tensor): The output tensor.
-            aux_loss (torch.Tensor): The auxiliary loss tensor.
-        Returns:
-            torch.Tensor: The output tensor.
-        """
-        ctx.save_for_backward(aux_loss)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output: torch.Tensor):
-        """Compute and scale the gradient for the auxiliary loss.
-        Args:
-            grad_output (torch.Tensor): The gradient of the output.
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient.
-        """
-        (aux_loss,) = ctx.saved_tensors
-        aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
-        scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
-        return grad_output, scaled_aux_loss_grad
-
-    @staticmethod
-    def set_loss_scale(scale: torch.Tensor):
-        """Set the scale of the aux loss.
-        Args:
-            scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
-        """
-        MoEAuxLossAutoScaler.main_loss_backward_scale = scale
-
-
-def z_loss_func(logits, z_loss_coeff):
-    """Encourages the router's logits to remain small to enhance stability.
-    Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
-    Args:
-        logits (torch.Tensor): The logits of the router.
-    Returns:
-        torch.Tensor: The computed z-loss.
-    """
-
-    z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
-    return z_loss
-
-
-def switch_load_balancing_loss_func(
-    probs: torch.Tensor,
-    tokens_per_expert: torch.Tensor,
-    topk: int,
-    moe_aux_loss_coeff: float,
-):
-    """Calculate the auxiliary loss for better load balancing.
-    Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.
-    Args:
-        probs (torch.Tensor): The softmax probs output by the router for each token. [num_tokens, num_experts]
-        tokens_per_expert (torch.Tensor): The number of assigned tokens for each expert. [num_experts]
-    Returns:
-        torch.Tensor: The auxiliary loss for load balancing.
-    """
-    num_tokens = probs.shape[0] * topk
-    num_experts = probs.shape[1]
-
-    probs_mean_per_expert = probs.mean(dim=0)
-    aux_loss = torch.sum(probs_mean_per_expert * tokens_per_expert) * (
-        num_experts / num_tokens * moe_aux_loss_coeff
-    )
-    return aux_loss
-
-
-# adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/router.py#L96-L304
-class TopKRouter(nn.Module):
-    """
-    Top-K Router for Mixture of Experts (MoE) models.
-    This router determines which experts should process each token based on the top-k scoring experts.
-    It also applies auxiliary losses to encourage load balancing among experts.
-    Args:
-        config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
-    """
-
+class Experts(nn.Module):
     def __init__(self, config: AriaMoELMConfig):
         super().__init__()
         self.config = config
 
-        self.weight = nn.Parameter(
+        self.router_weight = nn.Parameter(
             torch.empty((self.config.moe_num_experts, self.config.hidden_size))
         )
-        # FIXME: initialize the weight
-
-    def gating(self, input: torch.Tensor) -> torch.Tensor:
-        """
-        Compute the gating logits for each token-expert pair.
-        Args:
-            input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].
-        Returns:
-            torch.Tensor: Logits tensor of shape [batch_size * seq_len, num_experts].
-        """
-        logits = torch.nn.functional.linear(input, self.weight)
-        return logits
-
-    def apply_z_loss(self, logits: torch.Tensor) -> torch.Tensor:
-        """
-        Apply z-loss to encourage router logits to remain small for enhanced stability.
-        Args:
-            logits (torch.Tensor): Router logits.
-        Returns:
-            torch.Tensor: Logits with z-loss applied.
-        """
-        z_loss = z_loss_func(logits, self.config.moe_z_loss_coeff)
-        logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
-        return logits
-
-    def apply_aux_loss(
-        self,
-        logits: torch.Tensor,
-        tokens_per_expert: torch.Tensor,
-        activation: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Apply auxiliary loss for load balancing among experts.
-        Args:
-            logits (torch.Tensor): Router logits.
-            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
-            activation (torch.Tensor): Activation values.
-        Returns:
-            torch.Tensor: Activation with auxiliary loss applied.
-        """
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        aux_loss = switch_load_balancing_loss_func(
-            probs,
-            tokens_per_expert,
-            self.config.moe_topk,
-            self.config.moe_aux_loss_coeff,
-        )
-        return MoEAuxLossAutoScaler.apply(activation, aux_loss)
-
-    def routing(
-        self, logits: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Perform the routing operation to determine expert assignments.
-        Args:
-            logits (torch.Tensor): Router logits.
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-                - scores: Softmax probabilities for top-k experts.
-                - top_indices: Indices of top-k experts for each token.
-                - tokens_per_expert: Number of tokens assigned to each expert.
-        """
-        logits = self.apply_z_loss(logits)
-
-        top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
-        scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
-
-        tokens_per_expert = torch.histc(
-            top_indices.flatten(),
-            bins=self.config.moe_num_experts,
-            min=0,
-            max=self.config.moe_num_experts - 1,
-        )
-
-        scores = self.apply_aux_loss(logits, tokens_per_expert, scores)
-        return scores, top_indices, tokens_per_expert
-
-    def forward(
-        self, input: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Forward pass of the TopKRouter.
-        Args:
-            input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-                - scores: Softmax probabilities for top-k experts.
-                - top_indices: Indices of top-k experts for each token.
-                - tokens_per_expert: Number of tokens assigned to each expert.
-        """
-        logits = self.gating(input)
-        logits = logits.view(-1, self.config.moe_num_experts)
-        scores, top_indices, tokens_per_expert = self.routing(logits)
-        return scores, top_indices, tokens_per_expert
-
-
-class Experts(nn.Module):
-    def __init__(self, config: AriaMoELMConfig):
-        super().__init__()
-        self.config = config
 
         self.w1 = nn.Parameter(
             torch.empty(
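The deleted helpers above are training-time machinery: MoEAuxLossAutoScaler wires auxiliary losses into the backward pass, z_loss_func penalizes large router logits, and switch_load_balancing_loss_func discourages expert collapse. None of them changes the forward outputs, which is presumably why an inference-only vllm port can drop them. For intuition, a toy evaluation of the load-balancing term using the same formula as above (all numbers invented):

import torch

probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.6, 0.3, 0.1],
                      [0.5, 0.4, 0.1],
                      [0.8, 0.1, 0.1]])         # [num_tokens=4, num_experts=3]
tokens_per_expert = torch.tensor([4., 0., 0.])  # top-1 picked expert 0 every time
topk, coeff = 1, 1e-2

num_tokens = probs.shape[0] * topk
num_experts = probs.shape[1]
aux_loss = torch.sum(probs.mean(dim=0) * tokens_per_expert) * (
    num_experts / num_tokens * coeff
)
print(aux_loss)  # ~0.0195, larger than the ~0.01 a perfectly uniform router would give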
@@ -351,16 +144,24 @@ def __init__(self, config: AriaMoELMConfig):
                 )
             )
         )
+        set_weight_attrs(self.router_weight, {"weight_loader": self.weight_loader})
         set_weight_attrs(self.w1, {"weight_loader": self.weight_loader})
         set_weight_attrs(self.w2, {"weight_loader": self.weight_loader})
 
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
-        param.data.copy_(loaded_weight.transpose(1, 2).contiguous())
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
+    ):
+        if shard_id == "router":
+            param.data.copy_(loaded_weight)
+        else:
+            param.data.copy_(loaded_weight.transpose(1, 2).contiguous())
 
-    def forward(self, hidden_states, gating_output):
-        def custom_routing_function(hidden_states, gating_output, topk, renormalize):
+    def forward(self, hidden_states):
+        router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
+
+        def custom_routing_function(hidden_states, router_output, topk, renormalize):
             top_logits, top_indices = torch.topk(
-                gating_output, k=self.config.moe_topk, dim=1
+                router_output, k=self.config.moe_topk, dim=1
             )
             scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
             return scores, top_indices.to(torch.int32)
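The single weight_loader now dispatches on a shard_id: the router weight is copied as-is, while the stacked expert weights are transposed into the layout the fused kernel consumes. A standalone sketch of that dispatch (the shapes and the in/out-layout comments are assumptions for illustration, not taken from the file):

import torch
import torch.nn as nn

def weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str):
    if shard_id == "router":
        # Router weight already has the expected [num_experts, hidden_size] layout.
        param.data.copy_(loaded_weight)
    else:
        # Expert weights presumably arrive as [num_experts, in, out] in the
        # checkpoint, while the fused kernel wants [num_experts, out, in].
        param.data.copy_(loaded_weight.transpose(1, 2).contiguous())

router_param = nn.Parameter(torch.empty(8, 16))
weight_loader(router_param, torch.randn(8, 16), "router")

w1_param = nn.Parameter(torch.empty(8, 64, 16))          # [experts, out, in]
weight_loader(w1_param, torch.randn(8, 16, 64), "w1")    # checkpoint: [experts, in, out]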
@@ -371,7 +172,7 @@ def custom_routing_function(hidden_states, gating_output, topk, renormalize):
             hidden_states,
             self.w1,
             self.w2,
-            gating_output,
+            router_output,
             self.config.moe_topk,
             False,
             inplace=True,
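Worth noting: custom_routing_function takes the top-k of the raw logits and then softmaxes only the surviving entries, which is numerically the same as a full softmax followed by top-k selection and renormalization; that is presumably why the False (renormalize) argument above stays off. A quick check with toy logits:

import torch

logits = torch.randn(4, 8)          # [num_tokens, num_experts]
topk = 2

# Route as in the diff: top-k first, softmax after.
top_logits, top_ids = torch.topk(logits, k=topk, dim=1)
scores_a = torch.softmax(top_logits, dim=-1, dtype=torch.float32)

# Full softmax, then gather the same experts and renormalize.
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
top_probs = probs.gather(1, top_ids)
scores_b = top_probs / top_probs.sum(dim=-1, keepdim=True)

print(torch.allclose(scores_a, scores_b, atol=1e-6))  # True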
@@ -402,7 +203,6 @@ def __init__(
         super().__init__()
         self.config = config
 
-        self.router = TopKRouter(config)
         self.experts = Experts(config)
         self.shared_experts = LlamaMLP(
             config.hidden_size,
@@ -422,10 +222,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             torch.Tensor: Output tensor after passing through the MoE layer.
         """
 
-        gating_output = self.router.gating(hidden_states)
-
         shared_expert_output = self.shared_experts(hidden_states)
-        sparse_expert_output = self.experts(hidden_states, gating_output)
+        sparse_expert_output = self.experts(hidden_states)
 
         return sparse_expert_output + shared_expert_output
@@ -985,8 +783,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            ("experts.w1", "experts.fc1.weight", None),
-            ("experts.w2", "experts.fc2.weight", None),
+            ("experts.router_weight", "router.weight", "router"),
+            ("experts.w1", "experts.fc1.weight", "w1"),
+            ("experts.w2", "experts.fc2.weight", "w2"),
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
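With shard ids in the mapping, the router and expert weights flow through the same stacked-parameter path during loading. A simplified sketch of how such a (param_name, weight_name, shard_id) table is typically consumed; this mirrors the usual vllm load_weights loop in spirit, but it is not the file's actual code (the real loop also handles fall-through parameters, quantization, and more):

import torch.nn as nn

stacked_params_mapping = [
    ("experts.router_weight", "router.weight", "router"),
    ("experts.w1", "experts.fc1.weight", "w1"),
    ("experts.w2", "experts.fc2.weight", "w2"),
]

def load_weights(model: nn.Module, weights):
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # Rename the checkpoint key to the module's parameter name.
            param = params_dict[name.replace(weight_name, param_name)]
            # Each parameter carries the shard-aware loader attached via
            # set_weight_attrs in Experts.__init__.
            param.weight_loader(param, loaded_weight, shard_id)
            break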
