Skip to content

Commit

Permalink
Add fp8 for dbrx (ROCm#231)
Browse files Browse the repository at this point in the history
* add fp8 for dbrx

* linting
  • Loading branch information
charlifu authored Oct 14, 2024
1 parent 1ec8aaf commit 0e0e968
Showing 1 changed file with 43 additions and 27 deletions.
70 changes: 43 additions & 27 deletions vllm/model_executor/models/dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dbrx import DbrxConfig
Expand Down Expand Up @@ -82,33 +83,45 @@ def __init__(

# Define custom weight loader for dbrx model
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str):
weight_name: str, param_name: str):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
# DBRX uses GLU for each experts.
# GLU has 3 linear layers: w1, v1 and w2.
if weight_name.endswith("w1."):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
if weight_name.endswith("v1."):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:,
shard_size:2 * shard_size, :] = loaded_weight[:,
shard, :]
if weight_name.endswith("w2."):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]
if weight_name.endswith("w1"):
if param_name.endswith("weight"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
elif param_name.endswith("weight_scale"):
param_data[:, 0] = loaded_weight
else:
param_data = loaded_weight
if weight_name.endswith("v1"):
if param_name.endswith("weight"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, shard_size:2 *
shard_size, :] = loaded_weight[:, shard, :]
elif param_name.endswith("weight_scale"):
param_data[:, 1] = loaded_weight
else:
param_data[:] = loaded_weight
if weight_name.endswith("w2"):
if param_name.endswith("weight"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]
else:
param_data[:] = loaded_weight


class DbrxMoE(nn.Module):
Expand Down Expand Up @@ -409,13 +422,13 @@ def sample(
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

expert_params_mapping = [(
"w13_" if weight_name in ["w1", "v1"] else "w2_",
f"mlp.{weight_name}.",
"w13" if weight_name in ["w1", "v1"] else "w2",
f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if name.endswith(("w1", "v1", "w2")):
name = name + ".weight"
if name.endswith(("w1", "w2", "v1")):
name = name + "_weight"
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
Expand All @@ -424,11 +437,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, weight_name)
weight_loader(param, loaded_weight, weight_name, name)
break
else:
if is_pp_missing_parameter(name, self):
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Expand Down

0 comments on commit 0e0e968

Please sign in to comment.