diff --git a/aria/vllm/aria.py b/aria/vllm/aria.py
index 2c7786f..ee669c0 100644
--- a/aria/vllm/aria.py
+++ b/aria/vllm/aria.py
@@ -117,221 +117,14 @@ def __init__(
         self.moe_num_shared_experts = moe_num_shared_experts
 
 
-# copied from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/moe_utils.py#L101-L142
-class MoEAuxLossAutoScaler(torch.autograd.Function):
-    """An AutoScaler that compute and scales the grad for auxiliary loss."""
-
-    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
-
-    @staticmethod
-    def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor):
-        """Preserve the aux_loss by storing it in the context to avoid garbage collection.
-
-        Args:
-            output (torch.Tensor): The output tensor.
-            aux_loss (torch.Tensor): The auxiliary loss tensor.
-
-        Returns:
-            torch.Tensor: The output tensor.
-        """
-        ctx.save_for_backward(aux_loss)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output: torch.Tensor):
-        """Compute and scale the gradient for auxiliary loss..
-
-        Args:
-            grad_output (torch.Tensor): The gradient of the output.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient.
-        """
-        (aux_loss,) = ctx.saved_tensors
-        aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale
-        scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale
-        return grad_output, scaled_aux_loss_grad
-
-    @staticmethod
-    def set_loss_scale(scale: torch.Tensor):
-        """set the scale of the aux loss.
-
-        Args:
-            scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss.
-        """
-        MoEAuxLossAutoScaler.main_loss_backward_scale = scale
-
-
-def z_loss_func(logits, z_loss_coeff):
-    """Encourages the router's logits to remain small to enhance stability.
-    Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
-
-    Args:
-        logits (torch.Tensor): The logits of the router.
-
-    Returns:
-        torch.Tensor: The logits after applying the z-loss.
-    """
-
-    z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff
-    return z_loss
-
-
-def switch_load_balancing_loss_func(
-    probs: torch.Tensor,
-    tokens_per_expert: torch.Tensor,
-    topk: int,
-    moe_aux_loss_coeff: float,
-):
-    """Calculate the auxiliary loss for better load balacing.
-    Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.
-
-    Args:
-        probs (torch.Tensor): The softmax probs output by the router for each token. [num_tokens, num_experts]
-        tokens_per_expert (torch.Tensor): The number of assigned tokens for each expert. [num_experts]
-
-    Returns:
-        torch.Tensor: The auxiliary loss for load balancing.
-    """
-    num_tokens = probs.shape[0] * topk
-    num_experts = probs.shape[1]
-
-    probs_mean_per_expert = probs.mean(dim=0)
-    aux_loss = torch.sum(probs_mean_per_expert * tokens_per_expert) * (
-        num_experts / num_tokens * moe_aux_loss_coeff
-    )
-    return aux_loss
-
-
-# adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/router.py#L96-L304
-class TopKRouter(nn.Module):
-    """
-    Top-K Router for Mixture of Experts (MoE) models.
-
-    This router determines which experts should process each token based on the top-k scoring experts.
-    It also applies auxiliary losses to encourage load balancing among experts.
-
-    Args:
-        config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
-    """
-
+class Experts(nn.Module):
     def __init__(self, config: AriaMoELMConfig):
         super().__init__()
         self.config = config
 
-        self.weight = nn.Parameter(
+        self.router_weight = nn.Parameter(
             torch.empty((self.config.moe_num_experts, self.config.hidden_size))
         )
-        # FIXME: initialize the weight
-
-    def gating(self, input: torch.Tensor) -> torch.Tensor:
-        """
-        Compute the gating logits for each token-expert pair.
-
-        Args:
-            input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].
-
-        Returns:
-            torch.Tensor: Logits tensor of shape [batch_size * seq_len, num_experts].
-        """
-        logits = torch.nn.functional.linear(input, self.weight)
-        return logits
-
-    def apply_z_loss(self, logits: torch.Tensor) -> torch.Tensor:
-        """
-        Apply z-loss to encourage router logits to remain small for enhanced stability.
-
-        Args:
-            logits (torch.Tensor): Router logits.
-
-        Returns:
-            torch.Tensor: Logits with z-loss applied.
-        """
-        z_loss = z_loss_func(logits, self.config.moe_z_loss_coeff)
-        logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
-        return logits
-
-    def apply_aux_loss(
-        self,
-        logits: torch.Tensor,
-        tokens_per_expert: torch.Tensor,
-        activation: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Apply auxiliary loss for load balancing among experts.
-
-        Args:
-            logits (torch.Tensor): Router logits.
-            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
-            activation (torch.Tensor): Activation values.
-
-        Returns:
-            torch.Tensor: Activation with auxiliary loss applied.
-        """
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        aux_loss = switch_load_balancing_loss_func(
-            probs,
-            tokens_per_expert,
-            self.config.moe_topk,
-            self.config.moe_aux_loss_coeff,
-        )
-        return MoEAuxLossAutoScaler.apply(activation, aux_loss)
-
-    def routing(
-        self, logits: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Perform the routing operation to determine expert assignments.
-
-        Args:
-            logits (torch.Tensor): Router logits.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-                - scores: Softmax probabilities for top-k experts.
-                - top_indices: Indices of top-k experts for each token.
-                - tokens_per_expert: Number of tokens assigned to each expert.
-        """
-        logits = self.apply_z_loss(logits)
-
-        top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
-        scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).type_as(logits)
-
-        tokens_per_expert = torch.histc(
-            top_indices.flatten(),
-            bins=self.config.moe_num_experts,
-            min=0,
-            max=self.config.moe_num_experts - 1,
-        )
-
-        scores = self.apply_aux_loss(logits, tokens_per_expert, scores)
-        return scores, top_indices, tokens_per_expert
-
-    def forward(
-        self, input: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Forward pass of the TopKRouter.
-
-        Args:
-            input (torch.Tensor): Input tensor of shape [batch_size * seq_len, hidden_size].
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-                - scores: Softmax probabilities for top-k experts.
-                - top_indices: Indices of top-k experts for each token.
-                - tokens_per_expert: Number of tokens assigned to each expert.
- """ - logits = self.gating(input) - logits = logits.view(-1, self.config.moe_num_experts) - scores, top_indices, tokens_per_expert = self.routing(logits) - return scores, top_indices, tokens_per_expert - - -class Experts(nn.Module): - def __init__(self, config: AriaMoELMConfig): - super().__init__() - self.config = config self.w1 = nn.Parameter( torch.empty( @@ -351,16 +144,24 @@ def __init__(self, config: AriaMoELMConfig): ) ) ) + set_weight_attrs(self.router_weight, {"weight_loader": self.weight_loader}) set_weight_attrs(self.w1, {"weight_loader": self.weight_loader}) set_weight_attrs(self.w2, {"weight_loader": self.weight_loader}) - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): - param.data.copy_(loaded_weight.transpose(1, 2).contiguous()) + def weight_loader( + self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str + ): + if shard_id == "router": + param.data.copy_(loaded_weight) + else: + param.data.copy_(loaded_weight.transpose(1, 2).contiguous()) + + def forward(self, hidden_states): + router_output = torch.nn.functional.linear(hidden_states, self.router_weight) - def forward(self, hidden_states, gating_output): - def custom_routing_function(hidden_states, gating_output, topk, renormalize): + def custom_routing_function(hidden_states, router_output, topk, renormalize): top_logits, top_indices = torch.topk( - gating_output, k=self.config.moe_topk, dim=1 + router_output, k=self.config.moe_topk, dim=1 ) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32) return scores, top_indices.to(torch.int32) @@ -371,7 +172,7 @@ def custom_routing_function(hidden_states, gating_output, topk, renormalize): hidden_states, self.w1, self.w2, - gating_output, + router_output, self.config.moe_topk, False, inplace=True, @@ -402,7 +203,6 @@ def __init__( super().__init__() self.config = config - self.router = TopKRouter(config) self.experts = Experts(config) self.shared_experts = LlamaMLP( config.hidden_size, @@ -422,10 +222,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: torch.Tensor: Output tensor after passing through the MoE layer. """ - gating_output = self.router.gating(hidden_states) - shared_expert_output = self.shared_experts(hidden_states) - sparse_expert_output = self.experts(hidden_states, gating_output) + sparse_expert_output = self.experts(hidden_states) return sparse_expert_output + shared_expert_output @@ -985,8 +783,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), - ("experts.w1", "experts.fc1.weight", None), - ("experts.w2", "experts.fc2.weight", None), + ("experts.router_weight", "router.weight", "router"), + ("experts.w1", "experts.fc1.weight", "w1"), + ("experts.w2", "experts.fc2.weight", "w2"), ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: