Skip to content

Commit

Permalink
[bugfix] fix aria model and add torch.compile (#10645)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 26, 2024
1 parent 6e9ff05 commit 45ac4ff
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 28 deletions.
26 changes: 4 additions & 22 deletions vllm/model_executor/models/aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
LlamaModel)
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
is_pp_missing_parameter,
make_layers, maybe_prefix,
maybe_prefix,
merge_multimodal_embeddings)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
Expand Down Expand Up @@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel):
"""

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

# FIXME: this is a hack to disable the compilation of the model
self.do_not_compile = True

self.layers = None

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MoEDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=MoEDecoderLayer)

# Adapted from LlamaModel.load_weights with the modification of adding
# the expert weights mapping to `stacked_params_mapping`
Expand Down
16 changes: 10 additions & 6 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import torch
from torch import nn
Expand Down Expand Up @@ -273,7 +273,11 @@ def forward(
@support_torch_compile
class LlamaModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer):
super().__init__()

config = vllm_config.model_config.hf_config
Expand All @@ -299,10 +303,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
lambda prefix: layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
Expand Down

0 comments on commit 45ac4ff

Please sign in to comment.