[Model] Add support for OLMo architecture (#3046)
This PR adds support for the OLMo architecture.

Additionally, it adds support for clip-qkv.

Test: already tested on Android (Pixel 4) and CUDA (with tensor_parallel_shards=2).
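
For reference, OLMo's clip-qkv option clamps the query/key/value projection outputs to the range [-clip_qkv, clip_qkv] before attention. A minimal numpy sketch of the idea (the function name and values below are illustrative, not the MLC implementation):

import numpy as np

def apply_clip_qkv(q, k, v, clip_qkv):
    # Clamp the Q/K/V projection outputs to [-clip_qkv, clip_qkv]; a no-op when disabled.
    if clip_qkv is None:
        return q, k, v
    return (
        np.clip(q, -clip_qkv, clip_qkv),
        np.clip(k, -clip_qkv, clip_qkv),
        np.clip(v, -clip_qkv, clip_qkv),
    )

q = np.array([-10.0, 0.5, 12.0])
print(apply_clip_qkv(q, q, q, clip_qkv=8.0)[0])  # [-8.   0.5  8. ]
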
Lanssi authored Dec 14, 2024
1 parent 86cf3f7 commit 385cef2
Showing 8 changed files with 902 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/mlc_llm/conversation_template/__init__.py
@@ -20,6 +20,7 @@
    llava,
    mistral,
    oasst,
    olmo,
    orion,
    phi,
    qwen2,
28 changes: 28 additions & 0 deletions python/mlc_llm/conversation_template/olmo.py
@@ -0,0 +1,28 @@
"""OLMo default templates"""

from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders

from .registry import ConvTemplateRegistry

# Note that the eos_token id is 50279 in both the Allenai and AMD versions,
# so the numeric id is used instead of the token text.
# Allenai version chat_template and eos_token:
# https://huggingface.co/allenai/OLMo-7B-Instruct/blob/main/tokenizer_config.json
# AMD version chat_template and eos_token:
# https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO/blob/main/tokenizer_config.json
ConvTemplateRegistry.register_conv_template(
    Conversation(
        name="olmo",
        system_template=f"{MessagePlaceholders.SYSTEM.value}",
        system_message="",
        system_prefix_token_ids=[50279],
        roles={
            "user": "<|user|>",
            "assistant": "<|assistant|>",
        },
        seps=["\n"],
        role_content_sep="\n",
        role_empty_sep="\n",
        stop_token_ids=[50279],
    )
)
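
For illustration, a single-turn prompt rendered under this template would look roughly as follows; the string assembly below is a sketch based on the fields above, not the actual Conversation rendering code:

# Illustrative only: manual assembly of a single-turn prompt from the template fields.
roles = {"user": "<|user|>", "assistant": "<|assistant|>"}
sep = "\n"  # seps, role_content_sep and role_empty_sep are all "\n"

user_message = "What is the capital of France?"
prompt = (
    roles["user"] + sep + user_message + sep   # "<|user|>\n...\n"
    + roles["assistant"] + sep                 # "<|assistant|>\n" -- the model continues here
)
print(repr(prompt))
# Generation stops when the model emits token id 50279, per stop_token_ids above.
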
1 change: 1 addition & 0 deletions python/mlc_llm/interface/gen_config.py
@@ -306,4 +306,5 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
    "aya-23",
    "deepseek_v2",
    "deepseek",
    "olmo",
}
18 changes: 18 additions & 0 deletions python/mlc_llm/model/model.py
@@ -28,6 +28,7 @@
from .minicpm import minicpm_loader, minicpm_model, minicpm_quantization
from .mistral import mistral_loader, mistral_model, mistral_quantization
from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization
from .olmo import olmo_loader, olmo_model, olmo_quantization
from .orion import orion_loader, orion_model, orion_quantization
from .phi import phi_loader, phi_model, phi_quantization
from .phi3 import phi3_loader, phi3_model, phi3_quantization
@@ -532,4 +533,21 @@ class Model:
            "ft-quant": deepseek_quantization.ft_quant,
        },
    ),
    "olmo": Model(
        name="olmo",
        model=olmo_model.OLMoForCausalLM,
        config=olmo_model.OLMoConfig,
        source={
            "huggingface-torch": olmo_loader.huggingface,
            "huggingface-safetensor": olmo_loader.huggingface,
            "awq": olmo_loader.awq,
        },
        quantize={
            "no-quant": olmo_quantization.no_quant,
            "group-quant": olmo_quantization.group_quant,
            "ft-quant": olmo_quantization.ft_quant,
            "awq": olmo_quantization.awq_quant,
            "per-tensor-quant": olmo_quantization.per_tensor_quant,
        },
    ),
}
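
With this entry in place, the rest of MLC LLM can resolve the architecture by name from the MODELS registry that this hunk extends. A rough sketch of how the record might be inspected (purely illustrative; only names visible in this diff are used, and the import path is assumed from the file being edited):

from mlc_llm.model.model import MODELS

olmo = MODELS["olmo"]         # the Model record registered above
print(olmo.name)              # "olmo"
print(sorted(olmo.source))    # ['awq', 'huggingface-safetensor', 'huggingface-torch']
print(sorted(olmo.quantize))  # ['awq', 'ft-quant', 'group-quant', 'no-quant', 'per-tensor-quant']

# Weight conversion would pick a loader by source format, e.g.:
loader_fn = olmo.source["huggingface-safetensor"]   # olmo_loader.huggingface
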
python/mlc_llm/model/olmo/__init__.py
Empty file.
172 changes: 172 additions & 0 deletions python/mlc_llm/model/olmo/olmo_loader.py
@@ -0,0 +1,172 @@
"""
This file specifies how MLC's OLMo parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .olmo_model import OLMoConfig, OLMoForCausalLM
from .olmo_quantization import awq_quant


def huggingface(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : OLMoConfig
        The configuration of the OLMo model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = OLMoForCausalLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
        # Add QKV in self attention
        attn = f"model.layers.{i}.self_attn"
        mlc_name = f"{attn}.qkv_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{attn}.q_proj.weight",
                f"{attn}.k_proj.weight",
                f"{attn}.v_proj.weight",
            ],
            functools.partial(
                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )
        # Add gates in MLP
        mlp = f"model.layers.{i}.mlp"
        mlc_name = f"{mlp}.gate_up_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{mlp}.gate_proj.weight",
                f"{mlp}.up_proj.weight",
            ],
            functools.partial(
                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )
        # inv_freq is not used in the model
        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(
                    lambda x, dtype: x.astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )
    return mapping
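
The mappings above fuse q_proj/k_proj/v_proj into qkv_proj.weight (and gate_proj/up_proj into gate_up_proj.weight) by concatenating along axis 0, i.e. the output dimension of the projection. A toy numpy sketch of the shape arithmetic, with made-up dimensions:

import numpy as np

hidden, head_dim, num_q_heads, num_kv_heads = 8, 4, 2, 2   # made-up sizes
q = np.zeros((num_q_heads * head_dim, hidden), dtype="float32")   # (8, 8)
k = np.zeros((num_kv_heads * head_dim, hidden), dtype="float32")  # (8, 8)
v = np.zeros((num_kv_heads * head_dim, hidden), dtype="float32")  # (8, 8)

# axis=0 stacks output rows, matching the fused qkv_proj layout expected above.
qkv = np.concatenate([q, k, v], axis=0).astype("float16")
print(qkv.shape)  # (24, 8) == ((num_q_heads + 2 * num_kv_heads) * head_dim, hidden)
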


def awq(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of AWQ parameters.

    Parameters
    ----------
    model_config : OLMoConfig
        The configuration of the OLMo model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to AWQ.
    """
    model, _ = awq_quant(model_config, quantization)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),  # type: ignore[attr-defined]
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    for i in range(model_config.num_hidden_layers):
        # Add QKV in self attention
        attn = f"model.layers.{i}.self_attn"
        for quantize_suffix in ["qweight", "qzeros", "scales"]:
            mlc_name = f"{attn}.qkv_proj.{quantize_suffix}"
            assert mlc_name in named_parameters
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{attn}.q_proj.{quantize_suffix}",
                    f"{attn}.k_proj.{quantize_suffix}",
                    f"{attn}.v_proj.{quantize_suffix}",
                ],
                functools.partial(
                    lambda q, k, v, dtype: np.concatenate(
                        [q, k, v],
                        axis=1,  # AWQ GEMM would transpose the weight
                    ).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )

        # Concat gate and up in MLP
        mlp = f"model.layers.{i}.mlp"
        for quantize_suffix in ["qweight", "qzeros", "scales"]:
            mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}"
            assert mlc_name in named_parameters
            mlc_param = named_parameters[mlc_name]
            mapping.add_mapping(
                mlc_name,
                [
                    f"{mlp}.gate_proj.{quantize_suffix}",
                    f"{mlp}.up_proj.{quantize_suffix}",
                ],
                functools.partial(
                    lambda gate, up, dtype: np.concatenate(
                        [gate, up],
                        axis=1,  # AWQ GEMM would transpose the weight
                    ).astype(dtype),
                    dtype=mlc_param.dtype,
                ),
            )

        # inv_freq is not used in the model
        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),
            )
    return mapping
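
The only structural difference from the float loader above is the concatenation axis: AWQ's packed qweight/qzeros/scales keep the input dimension on axis 0 (the GEMM kernel transposes the weight), so the fused tensors are built along axis 1. A toy numpy sketch with made-up packed shapes:

import numpy as np

in_features = 8   # made-up sizes; real AWQ tensors pack several int4 values per int32
q_packed = np.zeros((in_features, 2), dtype="int32")
k_packed = np.zeros((in_features, 2), dtype="int32")
v_packed = np.zeros((in_features, 2), dtype="int32")

# axis=1 appends packed output columns, since the AWQ layout keeps in_features on axis 0.
qkv_packed = np.concatenate([q_packed, k_packed, v_packed], axis=1)
print(qkv_packed.shape)  # (8, 6)
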