Commit

refactor(vllm): remove grouped_gemm part
xffxff committed Nov 12, 2024
1 parent 627d24c commit 4d598d1
Showing 1 changed file with 50 additions and 167 deletions.
217 changes: 50 additions & 167 deletions aria/vllm/aria.py
@@ -22,13 +22,12 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import LlamaConfig
from transformers.utils import logging
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import divide, get_pp_group, tensor_model_parallel_all_reduce
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, LLMInputs
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -329,154 +328,57 @@ def forward(
return scores, top_indices, tokens_per_expert


def sequential_gemm(input, weight, tokens_per_expert):
"""
Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
Args:
input (torch.Tensor): Input tensor of shape (num_tokens, in_features).
weight (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
Returns:
torch.Tensor: Output tensor of shape (num_tokens, out_features).
"""
num_tokens = input.shape[0]
out_features = weight.shape[-1]
output = torch.zeros(
num_tokens, out_features, dtype=input.dtype, device=input.device
)

cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
    # Insert a zero at the beginning for convenient offset indexing
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))

for expert_num in range(weight.shape[0]):
start = cumsum_num_tokens[expert_num]
end = cumsum_num_tokens[expert_num + 1]
tokens = input[start:end]

out = torch.matmul(tokens, weight[expert_num])
output[start:end] = out
return output
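
For illustration only (not part of the commit): a minimal toy invocation of sequential_gemm, assuming the rows of the input have already been permuted so that each expert's tokens are contiguous.

import torch

num_experts, in_features, out_features = 3, 8, 4
tokens_per_expert = torch.tensor([2, 1, 3])          # 6 tokens in total
permuted_tokens = torch.randn(6, in_features)        # (num_tokens, in_features)
expert_weights = torch.randn(num_experts, in_features, out_features)

out = sequential_gemm(permuted_tokens, expert_weights, tokens_per_expert)
print(out.shape)  # torch.Size([6, 4])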


class GroupedGEMM(nn.Module):
"""
Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
for optimized performance. If the grouped_gemm library is not installed, it gracefully
falls back to a sequential GEMM implementation, which may be slower but ensures
functionality.
    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        groups (int): Number of expert groups.
        tp_dim (str): Tensor parallel sharding mode for the expert weights, either "row" or "col".
    """

def __init__(self, in_features, out_features, groups, tp_dim):
super().__init__()
# self.tp_size = get_tensor_model_parallel_world_size()
# self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = 1
self.tp_rank = 0
self.tp_dim = tp_dim
self.in_features = in_features
self.out_features = out_features
self.groups = groups
self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
if self.tp_size > 1:
if self.tp_dim == "row":
shard_size = self.in_features
start_idx = self.tp_rank * shard_size
param.data.copy_(loaded_weight.narrow(1, start_idx, shard_size))
else:
ups, gates = [], []
for g in range(self.groups):
up, gate = loaded_weight[g].chunk(2, -1)
ups.append(up.chunk(self.tp_size, -1)[self.tp_rank])
gates.append(gate.chunk(self.tp_size, -1)[self.tp_rank])
ups, gates = torch.stack(ups), torch.stack(gates)
weights = torch.cat([ups, gates], dim=-1)
param.data.copy_(weights)
else:
param.data.copy_(loaded_weight.transpose(1, 2).contiguous())

def forward(self, input, tokens_per_expert):
"""
Perform grouped matrix multiplication.
Args:
input (torch.Tensor): Input tensor of shape (num_tokens, in_features).
tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
Returns:
torch.Tensor: Output tensor of shape (num_tokens, out_features).
"""
return experts_gemm(input, self.weight, tokens_per_expert)


class GroupedMLP(nn.Module):
"""
Grouped MLP module for Mixture of Experts.
Args:
config (AriaMoELMConfig): Configuration object for the model.
"""

def __init__(self, config: AriaMoELMConfig) -> None:
        super().__init__()
        # tp_size = get_tensor_model_parallel_world_size()
        tp_size = 1
        self.config = config
        self.fc1 = GroupedGEMM(
            divide(config.moe_intermediate_size * 2, tp_size),
            config.hidden_size,
            config.moe_num_experts,
            "col",
        )
        self.fc2 = GroupedGEMM(
            config.hidden_size,
            divide(config.moe_intermediate_size, tp_size),
            config.moe_num_experts,
            "row",
        )

        def glu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        # Manually hook the forward pass to perform tensor model parallel all-reduce
        if tp_size > 1:
            self.register_forward_hook(
                lambda _, __, output_parallel: tensor_model_parallel_all_reduce(
                    output_parallel
                )
            )

        self.activation_func = glu

    def forward(self, permuted_tokens, tokens_per_expert):
        """
        Forward pass of the Grouped MLP.

        Args:
            permuted_tokens (torch.Tensor): Permuted input tokens.
            tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.

        Returns:
            torch.Tensor: Output tensor after passing through the MLP.
        """
        tokens_per_expert = tokens_per_expert.cpu()
        fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
        fc1_output = self.activation_func(fc1_output)
        fc2_output = self.fc2(fc1_output, tokens_per_expert)
        return fc2_output


class Experts(nn.Module):
    def __init__(self, config: AriaMoELMConfig):
        super().__init__()
        self.config = config

        self.w1 = nn.Parameter(
            torch.empty(
                (
                    config.moe_num_experts,
                    config.moe_intermediate_size * 2,
                    config.hidden_size,
                )
            )
        )
        self.w2 = nn.Parameter(
            torch.empty(
                (
                    config.moe_num_experts,
                    config.hidden_size,
                    config.moe_intermediate_size,
                )
            )
        )
        set_weight_attrs(self.w1, {"weight_loader": self.weight_loader})
        set_weight_attrs(self.w2, {"weight_loader": self.weight_loader})

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param.data.copy_(loaded_weight.transpose(1, 2).contiguous())

    def forward(self, hidden_states, gating_output):
        def custom_routing_function(hidden_states, gating_output, topk, renormalize):
            top_logits, top_indices = torch.topk(
                gating_output, k=self.config.moe_topk, dim=1
            )
            scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
            return scores, top_indices.to(torch.int32)

        hidden_states_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
        final_hidden_states = fused_moe(
            hidden_states,
            self.w1,
            self.w2,
            gating_output,
            self.config.moe_topk,
            False,
            inplace=True,
            custom_routing_function=custom_routing_function,
        )
        final_hidden_states = final_hidden_states.view(hidden_states_shape)
        return final_hidden_states
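
For readers unfamiliar with vLLM's fused_moe, the sketch below is a rough, unoptimized reference for the computation the call in Experts.forward performs. It is illustrative only (the real kernel fuses routing, gathering and the expert GEMMs), the name moe_reference is hypothetical, and it assumes the gate/up halves of w1 follow the same split-then-SiLU convention as the removed glu helper.

import torch
import torch.nn.functional as F

def moe_reference(hidden_states, w1, w2, gating_output, topk):
    # hidden_states: (T, H); w1: (E, 2*I, H); w2: (E, H, I); gating_output: (T, E)
    top_logits, top_idx = torch.topk(gating_output, k=topk, dim=1)
    scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.size(0)):
        for k in range(topk):
            e = top_idx[t, k]
            x = hidden_states[t]                      # (H,)
            gate, up = (w1[e] @ x).chunk(2, dim=-1)   # each (I,)
            expert_out = w2[e] @ (F.silu(gate) * up)  # (H,)
            out[t] += scores[t, k].to(out.dtype) * expert_out
    return out

Given the same w1/w2 layout as above, this per-token loop should approximate the fused kernel's output up to numerical differences.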


class MoELayer(nn.Module):
@@ -501,7 +403,7 @@ def __init__(
self.config = config

self.router = TopKRouter(config)
self.experts = GroupedMLP(config)
self.experts = Experts(config)
self.shared_experts = LlamaMLP(
config.hidden_size,
config.moe_intermediate_size * config.moe_num_shared_experts,
@@ -520,33 +422,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
torch.Tensor: Output tensor after passing through the MoE layer.
"""

        def custom_routing_function(hidden_states, gating_output, topk, renormalize):
            top_logits, top_indices = torch.topk(
                gating_output, k=self.config.moe_topk, dim=1
            )
            scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
            return scores, top_indices.to(torch.int32)

        shared_expert_output = self.shared_experts(hidden_states)

        hidden_states_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_states.size(-1))
        router_logits = self.router.gating(hidden_states)
        w1 = self.experts.fc1.weight
        w2 = self.experts.fc2.weight
        final_hidden_states = fused_moe(
            hidden_states,
            w1,
            w2,
            router_logits,
            self.config.moe_topk,
            False,
            inplace=True,
            custom_routing_function=custom_routing_function,
        )
        final_hidden_states = final_hidden_states.view(hidden_states_shape)
        final_hidden_states += shared_expert_output
        return final_hidden_states

        gating_output = self.router.gating(hidden_states)

        shared_expert_output = self.shared_experts(hidden_states)
        sparse_expert_output = self.experts(hidden_states, gating_output)

        return sparse_expert_output + shared_expert_output
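
As a concrete illustration of the top-k routing that custom_routing_function implements (toy numbers, not taken from the source), with moe_topk = 2 and four experts:

import torch

gating_output = torch.tensor([[2.0, 0.5, 1.0, -1.0]])   # one token, four experts
top_logits, top_indices = torch.topk(gating_output, k=2, dim=1)
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
print(top_indices)  # tensor([[0, 2]])
print(scores)       # tensor([[0.7311, 0.2689]]) -> expert 0 gets ~73% of the weight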


class MoEDecoderLayer(LlamaDecoderLayer):
@@ -1104,6 +985,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("experts.w1", "experts.fc1.weight", None),
("experts.w2", "experts.fc2.weight", None),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
