remove unused code
xffxff committed Nov 20, 2024
1 parent: d525446 · commit: 471a510
Showing 1 changed file with 0 additions and 27 deletions.
aria/vllm/aria.py
@@ -177,33 +177,6 @@ def _weight_loader_for_w2(self, param: nn.Parameter, loaded_weight: torch.Tensor
         else:
             param.data.copy_(loaded_weight.transpose(1, 2))
 
-    def weight_loader(
-        self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
-    ):
-        if shard_id == "router":
-            param.data.copy_(loaded_weight)
-        elif shard_id == "w1":
-            if self.tp_size > 1:
-                # the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size)
-                up, gate = loaded_weight.chunk(2, dim=-1)
-                up_current_rank = up.chunk(self.tp_size, dim=-1)[self.tp_rank]
-                gate_current_rank = gate.chunk(self.tp_size, dim=-1)[self.tp_rank]
-                up_and_gate = torch.cat(
-                    [up_current_rank, gate_current_rank], dim=-1
-                ).transpose(1, 2)
-                param.data.copy_(up_and_gate)
-            else:
-                param.data.copy_(loaded_weight.transpose(1, 2))
-        else:
-            if self.tp_size > 1:
-                # the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size)
-                down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[
-                    self.tp_rank
-                ]
-                param.data.copy_(down_current_rank.transpose(1, 2))
-            else:
-                param.data.copy_(loaded_weight.transpose(1, 2))
-
     def forward(self, hidden_states):
         router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
 
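For reference, the deleted weight_loader sharded each expert's weights across tensor-parallel ranks before copying them into the module parameter: the fused up/gate projection ("w1") was split column-wise per rank and re-fused, while the down projection was split row-wise, with both transposed into the parameter's expected layout. Below is a minimal standalone sketch of that sharding logic; the tensor sizes and variable names are illustrative, not values from the Aria config.

# Standalone sketch of the tensor-parallel sharding done by the removed
# weight_loader. Shapes follow the comments in the diff; the concrete
# sizes below are illustrative assumptions.
import torch

num_experts, hidden_size, moe_intermediate_size = 4, 8, 6
tp_size, tp_rank = 2, 0

# "w1" fuses the up and gate projections along the last dimension:
# (num_experts, hidden_size, 2 * moe_intermediate_size).
w1 = torch.randn(num_experts, hidden_size, 2 * moe_intermediate_size)

# Split the fused tensor into its up and gate halves, take this rank's
# column slice of each, re-fuse them, and transpose into the layout the
# parameter expects.
up, gate = w1.chunk(2, dim=-1)
up_rank = up.chunk(tp_size, dim=-1)[tp_rank]
gate_rank = gate.chunk(tp_size, dim=-1)[tp_rank]
w1_shard = torch.cat([up_rank, gate_rank], dim=-1).transpose(1, 2)
assert w1_shard.shape == (num_experts, 2 * moe_intermediate_size // tp_size, hidden_size)

# The down projection is sharded row-wise along its input dimension:
# (num_experts, moe_intermediate_size, hidden_size).
w2 = torch.randn(num_experts, moe_intermediate_size, hidden_size)
w2_shard = w2.chunk(tp_size, dim=1)[tp_rank].transpose(1, 2)
assert w2_shard.shape == (num_experts, hidden_size, moe_intermediate_size // tp_size)

The "router" branch copied the weight through unchanged, presumably because the router is replicated on every rank rather than sharded.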
