[Model] Add support for OLMo architecture #3046

Merged · 8 commits · Dec 14, 2024
1 change: 1 addition & 0 deletions python/mlc_llm/conversation_template/__init__.py
@@ -20,6 +20,7 @@
llava,
mistral,
oasst,
olmo,
orion,
phi,
qwen2,
28 changes: 28 additions & 0 deletions python/mlc_llm/conversation_template/olmo.py
@@ -0,0 +1,28 @@
"""OLMo default templates"""

from mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders

from .registry import ConvTemplateRegistry

# Note that the eos_token id is 50279 in both the Allenai and AMD versions,
# so the numeric token id is used here instead of the token text.
# Allenai version chat_template and eos_token:
# https://huggingface.co/allenai/OLMo-7B-Instruct/blob/main/tokenizer_config.json
# AMD version chat_template and eos_token:
# https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO/blob/main/tokenizer_config.json
ConvTemplateRegistry.register_conv_template(
Conversation(
name="olmo",
system_template=f"{MessagePlaceholders.SYSTEM.value}",
system_message="",
system_prefix_token_ids=[50279],
roles={
"user": "<|user|>",
"assistant": "<|assistant|>",
},
seps=["\n"],
role_content_sep="\n",
role_empty_sep="\n",
stop_token_ids=[50279],
)
)
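For illustration, a minimal sketch (not part of the diff, and not the MLC Conversation API) of roughly how this template lays out a single-turn prompt. The token id 50279 (OLMo's eos token, used here as a bos-like prefix) is prepended as a raw token id via system_prefix_token_ids rather than as text; the user message below is hypothetical.

# Illustrative sketch only: prompt layout implied by the template fields above.
user_msg = "What is the capital of France?"  # hypothetical input

prompt_text = (
    "<|user|>" + "\n" + user_msg + "\n"  # role_content_sep and seps[0] are both "\n"
    + "<|assistant|>" + "\n"             # role_empty_sep opens the empty assistant turn
)
prompt_tokens = [50279]  # + tokenizer.encode(prompt_text), with a real OLMo tokenizer
print(prompt_text)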
1 change: 1 addition & 0 deletions python/mlc_llm/interface/gen_config.py
@@ -306,4 +306,5 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
"aya-23",
"deepseek_v2",
"deepseek",
"olmo",
}
18 changes: 18 additions & 0 deletions python/mlc_llm/model/model.py
@@ -28,6 +28,7 @@
from .minicpm import minicpm_loader, minicpm_model, minicpm_quantization
from .mistral import mistral_loader, mistral_model, mistral_quantization
from .mixtral import mixtral_loader, mixtral_model, mixtral_quantization
from .olmo import olmo_loader, olmo_model, olmo_quantization
from .orion import orion_loader, orion_model, orion_quantization
from .phi import phi_loader, phi_model, phi_quantization
from .phi3 import phi3_loader, phi3_model, phi3_quantization
@@ -532,4 +533,21 @@ class Model:
"ft-quant": deepseek_quantization.ft_quant,
},
),
"olmo": Model(
name="olmo",
model=olmo_model.OLMoForCausalLM,
config=olmo_model.OLMoConfig,
source={
"huggingface-torch": olmo_loader.huggingface,
"huggingface-safetensor": olmo_loader.huggingface,
"awq": olmo_loader.awq,
},
quantize={
"no-quant": olmo_quantization.no_quant,
"group-quant": olmo_quantization.group_quant,
"ft-quant": olmo_quantization.ft_quant,
"awq": olmo_quantization.awq_quant,
"per-tensor-quant": olmo_quantization.per_tensor_quant,
},
),
}
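For reference, a minimal sketch of how the new registry entry could be inspected, assuming the dict literal above is the module-level MODELS registry used for the other architectures in model.py (names outside this diff are assumptions):

from mlc_llm.model.model import MODELS  # assumes the dict above is named MODELS

olmo = MODELS["olmo"]
print(olmo.model)             # olmo_model.OLMoForCausalLM
print(sorted(olmo.source))    # ['awq', 'huggingface-safetensor', 'huggingface-torch']
print(sorted(olmo.quantize))  # ['awq', 'ft-quant', 'group-quant', 'no-quant', 'per-tensor-quant']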
Empty file.
172 changes: 172 additions & 0 deletions python/mlc_llm/model/olmo/olmo_loader.py
@@ -0,0 +1,172 @@
"""
This file specifies how MLC LLM's OLMo parameters map from other formats, for example
HuggingFace PyTorch and HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .olmo_model import OLMoConfig, OLMoForCausalLM
from .olmo_quantization import awq_quant


def huggingface(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping:
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
the names of HuggingFace PyTorch parameters.

Parameters
----------
model_config : OLMoConfig
The configuration of the OLMo model.

quantization : Quantization
The quantization configuration.

Returns
-------
param_map : ExternMapping
The parameter mapping from MLC to HuggingFace PyTorch.
"""
model = OLMoForCausalLM(model_config)
if quantization is not None:
model.to(quantization.model_dtype)
_, _named_params, _ = model.export_tvm( # type: ignore[misc]
spec=model.get_default_spec(),
allow_extern=True,
)
named_parameters = dict(_named_params)

mapping = ExternMapping()

for i in range(model_config.num_hidden_layers):
# Add QKV in self attention
attn = f"model.layers.{i}.self_attn"
mlc_name = f"{attn}.qkv_proj.weight"
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[
f"{attn}.q_proj.weight",
f"{attn}.k_proj.weight",
f"{attn}.v_proj.weight",
],
functools.partial(
lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
dtype=mlc_param.dtype,
),
)
# Add gates in MLP
mlp = f"model.layers.{i}.mlp"
mlc_name = f"{mlp}.gate_up_proj.weight"
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[
f"{mlp}.gate_proj.weight",
f"{mlp}.up_proj.weight",
],
functools.partial(
lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
dtype=mlc_param.dtype,
),
)
# inv_freq is not used in the model
mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

for mlc_name, mlc_param in named_parameters.items():
if mlc_name not in mapping.param_map:
mapping.add_mapping(
mlc_name,
[mlc_name],
functools.partial(
lambda x, dtype: x.astype(dtype),
dtype=mlc_param.dtype,
),
)
return mapping
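A self-contained numpy sketch of the fused-QKV mapping above: the three HuggingFace projection weights are stacked along axis 0 into the single qkv_proj.weight tensor that MLC expects. The shapes are hypothetical, chosen only to make the stacking visible.

import numpy as np

# Hypothetical shapes: hidden_size=8, query rows followed by smaller k/v rows.
q = np.random.rand(8, 8).astype("float32")  # q_proj.weight
k = np.random.rand(4, 8).astype("float32")  # k_proj.weight
v = np.random.rand(4, 8).astype("float32")  # v_proj.weight

qkv = np.concatenate([q, k, v], axis=0).astype("float16")
print(qkv.shape)  # (16, 8): q, k, v rows stacked, as in the lambda passed to add_mapping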


def awq(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping:
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
the names of AWQ parameters.

Parameters
----------
model_config : OLMoConfig
The configuration of the OLMo model.

quantization : Quantization
The quantization configuration.

Returns
-------
param_map : ExternMapping
The parameter mapping from MLC to AWQ.
"""
model, _ = awq_quant(model_config, quantization)
_, _named_params, _ = model.export_tvm( # type: ignore[misc]
spec=model.get_default_spec(), # type: ignore[attr-defined]
allow_extern=True,
)
named_parameters = dict(_named_params)

mapping = ExternMapping()

for i in range(model_config.num_hidden_layers):
# Add QKV in self attention
attn = f"model.layers.{i}.self_attn"
for quantize_suffix in ["qweight", "qzeros", "scales"]:
mlc_name = f"{attn}.qkv_proj.{quantize_suffix}"
assert mlc_name in named_parameters
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[
f"{attn}.q_proj.{quantize_suffix}",
f"{attn}.k_proj.{quantize_suffix}",
f"{attn}.v_proj.{quantize_suffix}",
],
functools.partial(
lambda q, k, v, dtype: np.concatenate(
[q, k, v],
axis=1, # AWQ GEMM would transpose the weight
).astype(dtype),
dtype=mlc_param.dtype,
),
)

# Concat gate and up in MLP
mlp = f"model.layers.{i}.mlp"
for quantize_suffix in ["qweight", "qzeros", "scales"]:
mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}"
assert mlc_name in named_parameters
mlc_param = named_parameters[mlc_name]
mapping.add_mapping(
mlc_name,
[
f"{mlp}.gate_proj.{quantize_suffix}",
f"{mlp}.up_proj.{quantize_suffix}",
],
functools.partial(
lambda gate, up, dtype: np.concatenate(
[gate, up],
axis=1, # AWQ GEMM would transpose the weight
).astype(dtype),
dtype=mlc_param.dtype,
),
)

# inv_freq is not used in the model
mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

for mlc_name, mlc_param in named_parameters.items():
if mlc_name not in mapping.param_map:
mapping.add_mapping(
mlc_name,
[mlc_name],
functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),
)
return mapping
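The AWQ mapping differs only in the concatenation axis: because the AWQ GEMM layout stores each projection transposed (input features as rows, packed output features as columns), the per-projection tensors are joined along axis 1. A hypothetical-shape sketch:

import numpy as np

# Hypothetical packed AWQ shapes (input features x packed output features).
q = np.zeros((8, 2), dtype="int32")  # q_proj.qweight
k = np.zeros((8, 1), dtype="int32")  # k_proj.qweight
v = np.zeros((8, 1), dtype="int32")  # v_proj.qweight

qkv = np.concatenate([q, k, v], axis=1)
print(qkv.shape)  # (8, 4): columns joined because the AWQ layout is transposed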