diff --git a/aria/vllm/aria.py b/aria/vllm/aria.py
index 86d14aa..4f68431 100644
--- a/aria/vllm/aria.py
+++ b/aria/vllm/aria.py
@@ -17,7 +17,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import math
-import os
 from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
@@ -330,78 +329,6 @@ def forward(
         return scores, top_indices, tokens_per_expert
 
 
-# adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
-class TokenDispatcher:
-    """
-    Handles the dispatching and gathering of tokens to and from experts.
-
-    This class is responsible for permuting tokens based on expert assignments and
-    unpermuting them after expert processing.
-
-    Args:
-        config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
-    """
-
-    def __init__(self, config: AriaMoELMConfig):
-        self.config = config
-        self.hidden_states_shape = None
-        self.reversed_input_permutation_mapping = None
-
-    def token_permutation(
-        self, hidden_states: torch.Tensor, indices: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Permute tokens based on expert assignments.
-
-        Args:
-            hidden_states (torch.Tensor): Input hidden states.
-            indices (torch.Tensor): Expert assignment indices.
-
-        Returns:
-            torch.Tensor: Permuted tokens.
-        """
-        self.hidden_states_shape = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
-        flatten_indices = indices.flatten()
-        sorted_indices = torch.argsort(flatten_indices, stable=True)
-        permuted_tokens = hidden_states.index_select(
-            0, sorted_indices // self.config.moe_topk
-        )
-        self.reversed_input_permutation_mapping = sorted_indices
-        return permuted_tokens
-
-    def token_unpermutation(
-        self, permuted_tokens: torch.Tensor, scores: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Unpermute tokens and combine expert outputs.
-
-        Args:
-            permuted_tokens (torch.Tensor): Tokens after expert processing.
-            scores (torch.Tensor): Expert assignment scores.
-
-        Returns:
-            torch.Tensor: Unpermuted and combined output.
-        """
-        num_unpermuted_tokens = scores.numel()
-        unpermuted_tokens = torch.zeros(
-            (num_unpermuted_tokens, permuted_tokens.size(1)),
-            dtype=permuted_tokens.dtype,
-            device=permuted_tokens.device,
-        )
-        unpermuted_tokens.index_copy_(
-            0, self.reversed_input_permutation_mapping, permuted_tokens
-        )
-        unpermuted_tokens = unpermuted_tokens.reshape(
-            -1, self.config.moe_topk, permuted_tokens.size(1)
-        )
-
-        unpermuted_tokens = unpermuted_tokens * scores.unsqueeze(-1)
-        unpermuted_tokens = unpermuted_tokens.sum(dim=1).type_as(permuted_tokens)
-        output = unpermuted_tokens.view(self.hidden_states_shape)
-        return output
-
-
 def sequential_gemm(input, weight, tokens_per_expert):
     """
     Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
@@ -435,21 +362,6 @@ def sequential_gemm(input, weight, tokens_per_expert):
     return output
 
 
-try:
-    from grouped_gemm.ops import gmm as experts_gemm
-
-    if os.environ.get("USE_GROUPED_GEMM", "1") == "0":
-        logger.warning(
-            "environment variable USE_GROUPED_GEMM is set to 0, using sequential GEMM instead."
-        )
-        experts_gemm = sequential_gemm
-except ImportError:
-    logger.warning(
-        "`grouped_gemm` is not installed, using sequential GEMM, which is slower."
-    )
-    experts_gemm = sequential_gemm
-
-
 class GroupedGEMM(nn.Module):
     """
     Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
@@ -589,7 +501,6 @@ def __init__(
         self.config = config
 
         self.router = TopKRouter(config)
-        self.token_dispatcher = TokenDispatcher(config)
         self.experts = GroupedMLP(config)
         self.shared_experts = LlamaMLP(
            config.hidden_size,
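
For reference, the following is a minimal standalone sketch of the permute/unpermute round-trip that the removed TokenDispatcher implemented, condensed from its token_permutation and token_unpermutation methods shown above. The tensor shapes and the moe_topk value are illustrative assumptions, not values taken from the Aria configuration.

import torch

moe_topk = 2                                   # assumed top-k, not Aria's actual setting
hidden_states = torch.randn(4, 8)              # 4 tokens, hidden size 8 (illustrative)
scores = torch.rand(4, moe_topk)               # router weight per (token, selected expert)
indices = torch.randint(0, 3, (4, moe_topk))   # expert id per (token, selected expert)

# Permute: sort token copies by expert id so each expert sees a contiguous block.
flat_indices = indices.flatten()
sorted_indices = torch.argsort(flat_indices, stable=True)
permuted = hidden_states.index_select(0, sorted_indices // moe_topk)

# ... the expert MLPs would process `permuted` here ...

# Unpermute: scatter rows back to their pre-sort positions, weight each copy by
# its router score, and sum the top-k copies of every token.
unpermuted = torch.zeros_like(permuted)
unpermuted.index_copy_(0, sorted_indices, permuted)
unpermuted = unpermuted.reshape(-1, moe_topk, hidden_states.size(-1))
output = (unpermuted * scores.unsqueeze(-1)).sum(dim=1)
assert output.shape == hidden_states.shape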