
Commit

remove unused code
xffxff committed Nov 12, 2024
1 parent 3acb805 commit 627d24c
Showing 1 changed file with 0 additions and 89 deletions: aria/vllm/aria.py
@@ -17,7 +17,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import math
-import os
 from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
@@ -330,78 +329,6 @@ def forward(
         return scores, top_indices, tokens_per_expert
 
 
-# adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
-class TokenDispatcher:
-    """
-    Handles the dispatching and gathering of tokens to and from experts.
-    This class is responsible for permuting tokens based on expert assignments and
-    unpermuting them after expert processing.
-    Args:
-        config (AriaMoELMConfig): Configuration object containing MoE-related parameters.
-    """
-
-    def __init__(self, config: AriaMoELMConfig):
-        self.config = config
-        self.hidden_states_shape = None
-        self.reversed_input_permutation_mapping = None
-
-    def token_permutation(
-        self, hidden_states: torch.Tensor, indices: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Permute tokens based on expert assignments.
-        Args:
-            hidden_states (torch.Tensor): Input hidden states.
-            indices (torch.Tensor): Expert assignment indices.
-        Returns:
-            torch.Tensor: Permuted tokens.
-        """
-        self.hidden_states_shape = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
-        flatten_indices = indices.flatten()
-        sorted_indices = torch.argsort(flatten_indices, stable=True)
-        permuted_tokens = hidden_states.index_select(
-            0, sorted_indices // self.config.moe_topk
-        )
-        self.reversed_input_permutation_mapping = sorted_indices
-        return permuted_tokens
-
-    def token_unpermutation(
-        self, permuted_tokens: torch.Tensor, scores: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Unpermute tokens and combine expert outputs.
-        Args:
-            permuted_tokens (torch.Tensor): Tokens after expert processing.
-            scores (torch.Tensor): Expert assignment scores.
-        Returns:
-            torch.Tensor: Unpermuted and combined output.
-        """
-        num_unpermuted_tokens = scores.numel()
-        unpermuted_tokens = torch.zeros(
-            (num_unpermuted_tokens, permuted_tokens.size(1)),
-            dtype=permuted_tokens.dtype,
-            device=permuted_tokens.device,
-        )
-        unpermuted_tokens.index_copy_(
-            0, self.reversed_input_permutation_mapping, permuted_tokens
-        )
-        unpermuted_tokens = unpermuted_tokens.reshape(
-            -1, self.config.moe_topk, permuted_tokens.size(1)
-        )
-
-        unpermuted_tokens = unpermuted_tokens * scores.unsqueeze(-1)
-        unpermuted_tokens = unpermuted_tokens.sum(dim=1).type_as(permuted_tokens)
-        output = unpermuted_tokens.view(self.hidden_states_shape)
-        return output
-
-
 def sequential_gemm(input, weight, tokens_per_expert):
     """
     Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
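
For context on what the deleted TokenDispatcher did: token_permutation grouped every token's top-k copies by expert id with a stable argsort, and token_unpermutation scattered the expert outputs back to their (token, k) slots and combined them with the routing scores. Below is a minimal, self-contained sketch of that round trip; the sizes and names (num_tokens, hidden, num_experts, topk) are illustrative assumptions, not values from the repository.

import torch

# Toy dimensions (assumed for illustration only).
num_tokens, hidden, num_experts, topk = 4, 8, 3, 2
hidden_states = torch.randn(num_tokens, hidden)
indices = torch.randint(0, num_experts, (num_tokens, topk))    # expert id per (token, k)
scores = torch.softmax(torch.randn(num_tokens, topk), dim=-1)  # routing weight per (token, k)

# Permutation: replicate each token topk times and group the copies by expert id.
sorted_indices = torch.argsort(indices.flatten(), stable=True)
permuted = hidden_states.index_select(0, sorted_indices // topk)

# ... the per-expert MLPs would run on `permuted` here ...

# Unpermutation: scatter rows back to their (token, k) slots, weight by scores, sum over k.
unpermuted = torch.zeros_like(permuted)
unpermuted.index_copy_(0, sorted_indices, permuted)
output = (unpermuted.view(num_tokens, topk, hidden) * scores.unsqueeze(-1)).sum(dim=1)

# With no expert computation in between, each token is recombined with weights that
# sum to 1, so the round trip returns the input unchanged.
assert torch.allclose(output, hidden_states, atol=1e-6)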
@@ -435,21 +362,6 @@ def sequential_gemm(input, weight, tokens_per_expert):
     return output
 
 
-try:
-    from grouped_gemm.ops import gmm as experts_gemm
-
-    if os.environ.get("USE_GROUPED_GEMM", "1") == "0":
-        logger.warning(
-            "environment variable USE_GROUPED_GEMM is set to 0, using sequential GEMM instead."
-        )
-        experts_gemm = sequential_gemm
-except ImportError:
-    logger.warning(
-        "`grouped_gemm` is not installed, using sequential GEMM, which is slower."
-    )
-    experts_gemm = sequential_gemm
-
-
 class GroupedGEMM(nn.Module):
     """
     Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
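
The removed try/except block above chose the experts_gemm backend: it used grouped_gemm.ops.gmm when the optional grouped_gemm package was importable and the USE_GROUPED_GEMM environment variable was not set to 0, and otherwise fell back to sequential_gemm, whose body is elided in this diff. As a rough idea of what a per-expert sequential GEMM looks like, here is a hedged sketch with assumed shapes; it is not the file's actual implementation.

import torch

def sequential_gemm_sketch(tokens, weights, tokens_per_expert):
    # tokens: [total_tokens, in_dim], rows already grouped (sorted) by expert
    # weights: [num_experts, in_dim, out_dim]
    # tokens_per_expert: [num_experts] row count assigned to each expert
    output = torch.zeros(
        tokens.size(0), weights.size(-1), dtype=tokens.dtype, device=tokens.device
    )
    start = 0
    for expert_id, count in enumerate(tokens_per_expert.tolist()):
        end = start + count
        if count > 0:
            # One ordinary matmul per expert; this is the slow path the warnings refer to.
            output[start:end] = tokens[start:end] @ weights[expert_id]
        start = end
    return output

A grouped kernel such as grouped_gemm's gmm replaces this Python-level loop with a single batched call, which is why the sequential path was kept only as a fallback.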
@@ -589,7 +501,6 @@ def __init__(
         self.config = config
 
         self.router = TopKRouter(config)
-        self.token_dispatcher = TokenDispatcher(config)
         self.experts = GroupedMLP(config)
         self.shared_experts = LlamaMLP(
             config.hidden_size,
